naga/back/spv/
block.rs

1/*!
2Implementations for `BlockContext` methods.
3*/
4
5use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11    helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12    BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13    ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16    arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17    proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21    match *type_inner {
22        crate::TypeInner::Scalar(_) => Dimension::Scalar,
23        crate::TypeInner::Vector { .. } => Dimension::Vector,
24        crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25        crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26        _ => unreachable!(),
27    }
28}
29
30/// How to derive the type of `OpAccessChain` instructions from Naga IR.
31///
32/// Most of the time, we compile Naga IR to SPIR-V instructions whose result
33/// types are simply the direct SPIR-V analog of the Naga IR's. But in some
34/// cases, the Naga IR and SPIR-V types need to diverge.
35///
36/// This enum specifies how [`BlockContext::write_access_chain`] should
37/// choose a SPIR-V result type for the `OpAccessChain` it generates, based on
38/// the type of the given Naga IR [`Expression`] it's generating code for.
39///
40/// [`Expression`]: crate::Expression
41#[derive(Copy, Clone)]
42enum AccessTypeAdjustment {
43    /// No adjustment needed: the SPIR-V type should be the direct
44    /// analog of the Naga IR expression type.
45    ///
46    /// For most access chains, this is the right thing: the Naga IR access
47    /// expression produces a [`Pointer`] to the element / component, and the
48    /// SPIR-V `OpAccessChain` instruction does the same.
49    ///
50    /// [`Pointer`]: crate::TypeInner::Pointer
51    None,
52
53    /// The SPIR-V type should be an `OpPointer` to the direct analog of the
54    /// Naga IR expression's type.
55    ///
56    /// This is necessary for indexing binding arrays in the [`Handle`] address
57    /// space:
58    ///
59    /// - In Naga IR, referencing a binding array [`GlobalVariable`] in the
60    ///   [`Handle`] address space produces a value of type [`BindingArray`],
61    ///   not a pointer to such. And [`Access`] and [`AccessIndex`] expressions
62    ///   operate on handle binding arrays by value, and produce handle values,
63    ///   not pointers.
64    ///
65    /// - In SPIR-V, a binding array `OpVariable` produces a pointer to an
66    ///   array, and `OpAccessChain` instructions operate on pointers,
67    ///   regardless of whether the elements are opaque types or not.
68    ///
69    /// See also the documentation for [`BindingArray`].
70    ///
71    /// [`Handle`]: crate::AddressSpace::Handle
72    /// [`GlobalVariable`]: crate::GlobalVariable
73    /// [`BindingArray`]: crate::TypeInner::BindingArray
74    /// [`Access`]: crate::Expression::Access
75    /// [`AccessIndex`]: crate::Expression::AccessIndex
76    IntroducePointer(spirv::StorageClass),
77
78    /// The SPIR-V type should be an `OpPointer` to the std140 layout
79    /// compatible variant of the Naga IR expression's base type.
80    ///
81    /// This is used when accessing a type through an [`AddressSpace::Uniform`]
82    /// pointer in cases where the original type is incompatible with std140
83    /// layout requirements and we have therefore declared the uniform to be of
84    /// an alternative std140 compliant type.
85    ///
86    /// [`AddressSpace::Uniform`]: crate::AddressSpace::Uniform
87    UseStd140CompatType,
88}
89
90/// The results of emitting code for a left-hand-side expression.
91///
92/// On success, `write_access_chain` returns one of these.
93enum ExpressionPointer {
94    /// The pointer to the expression's value is available, as the value of the
95    /// expression with the given id.
96    Ready { pointer_id: Word },
97
98    /// The access expression must be conditional on the value of `condition`, a boolean
99    /// expression that is true if all indices are in bounds. If `condition` is true, then
100    /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
101    /// expression's value. If `condition` is false, then executing `access` would be
102    /// undefined behavior.
103    Conditional {
104        condition: Word,
105        access: Instruction,
106    },
107}
108
109/// The termination statement to be added to the end of the block
110enum BlockExit {
111    /// Generates an OpReturn (void return)
112    Return,
113    /// Generates an OpBranch to the specified block
114    Branch {
115        /// The branch target block
116        target: Word,
117    },
118    /// Translates a loop `break if` into an `OpBranchConditional` to the
119    /// merge block if true (the merge block is passed through [`LoopContext::break_id`]
120    /// or else to the loop header (passed through [`preamble_id`])
121    ///
122    /// [`preamble_id`]: Self::BreakIf::preamble_id
123    BreakIf {
124        /// The condition of the `break if`
125        condition: Handle<crate::Expression>,
126        /// The loop header block id
127        preamble_id: Word,
128    },
129}
130
/// Reports what code generation did with a [`BlockExit`] it was handed.
///
/// Any function taking a [`BlockExit`] argument should return one of these,
/// so the caller can tell whether the generated code actually ended with the
/// provided exit or bypassed it with some other non-local exit (for example
/// [`Break`] or [`Continue`]). Some callers need this to decide whether the
/// target block has to be generated at all.
///
/// [`Break`]: Statement::Break
/// [`Continue`]: Statement::Continue
#[must_use]
enum BlockExitDisposition {
    /// The provided `BlockExit` was used by the generated code. If it named a
    /// block label, the caller must make sure the block it refers to is
    /// actually emitted.
    Used,

    /// The provided `BlockExit` was not used by the generated code. If it
    /// named a block label, the caller need not emit the block it refers to,
    /// unless the caller knows that block is required for another reason.
    Discarded,
}
154
155#[derive(Clone, Copy, Default)]
156struct LoopContext {
157    continuing_id: Option<Word>,
158    break_id: Option<Word>,
159}
160
161#[derive(Debug)]
162pub(crate) struct DebugInfoInner<'a> {
163    pub source_code: &'a str,
164    pub source_file_id: Word,
165}
166
167impl Writer {
168    // Flip Y coordinate to adjust for coordinate space difference
169    // between SPIR-V and our IR.
170    // The `position_id` argument is a pointer to a `vecN<f32>`,
171    // whose `y` component we will negate.
172    fn write_epilogue_position_y_flip(
173        &mut self,
174        position_id: Word,
175        body: &mut Vec<Instruction>,
176    ) -> Result<(), Error> {
177        let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178        let index_y_id = self.get_index_constant(1);
179        let access_id = self.id_gen.next();
180        body.push(Instruction::access_chain(
181            float_ptr_type_id,
182            access_id,
183            position_id,
184            &[index_y_id],
185        ));
186
187        let float_type_id = self.get_f32_type_id();
188        let load_id = self.id_gen.next();
189        body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191        let neg_id = self.id_gen.next();
192        body.push(Instruction::unary(
193            spirv::Op::FNegate,
194            float_type_id,
195            neg_id,
196            load_id,
197        ));
198
199        body.push(Instruction::store(access_id, neg_id, None));
200        Ok(())
201    }
202
203    // Clamp fragment depth between 0 and 1.
204    fn write_epilogue_frag_depth_clamp(
205        &mut self,
206        frag_depth_id: Word,
207        body: &mut Vec<Instruction>,
208    ) -> Result<(), Error> {
209        let float_type_id = self.get_f32_type_id();
210        let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211        let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213        let original_id = self.id_gen.next();
214        body.push(Instruction::load(
215            float_type_id,
216            original_id,
217            frag_depth_id,
218            None,
219        ));
220
221        let clamp_id = self.id_gen.next();
222        body.push(Instruction::ext_inst_gl_op(
223            self.gl450_ext_inst_id,
224            spirv::GlslStd450Op::FClamp,
225            float_type_id,
226            clamp_id,
227            &[original_id, zero_scalar_id, one_scalar_id],
228        ));
229
230        body.push(Instruction::store(frag_depth_id, clamp_id, None));
231        Ok(())
232    }
233
234    fn write_entry_point_return(
235        &mut self,
236        value_id: Word,
237        ir_result: &crate::FunctionResult,
238        result_members: &[ResultMember],
239        body: &mut Vec<Instruction>,
240    ) -> Result<Instruction, Error> {
241        for (index, res_member) in result_members.iter().enumerate() {
242            // This isn't a real builtin, and is handled elsewhere
243            if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
244                return Ok(Instruction::return_value(value_id));
245            }
246            let member_value_id = match ir_result.binding {
247                Some(_) => value_id,
248                None => {
249                    let member_value_id = self.id_gen.next();
250                    body.push(Instruction::composite_extract(
251                        res_member.type_id,
252                        member_value_id,
253                        value_id,
254                        &[index as u32],
255                    ));
256                    member_value_id
257                }
258            };
259
260            self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
261
262            match res_member.built_in {
263                Some(crate::BuiltIn::Position { .. })
264                    if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
265                {
266                    self.write_epilogue_position_y_flip(res_member.id, body)?;
267                }
268                Some(crate::BuiltIn::FragDepth)
269                    if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
270                {
271                    self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
272                }
273                _ => {}
274            }
275        }
276        Ok(Instruction::return_void())
277    }
278}
279
280impl BlockContext<'_> {
281    /// Generates code to ensure that a loop is bounded. Should be called immediately
282    /// after adding the OpLoopMerge instruction to `block`. This function will
283    /// [`consume()`](crate::back::spv::Function::consume) `block` and append its
284    /// instructions to a new [`Block`], which will be returned to the caller for it to
285    /// consumed prior to writing the loop body.
286    ///
287    /// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
288    /// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
289    /// declare the required variables.
290    ///
291    /// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
292    /// of why this is required.
293    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
294        let uint_type_id = self.writer.get_u32_type_id();
295        let uint2_type_id = self.writer.get_vec2u_type_id();
296        let uint2_ptr_type_id = self
297            .writer
298            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
299        let bool_type_id = self.writer.get_bool_type_id();
300        let bool2_type_id = self.writer.get_vec2_bool_type_id();
301        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
302        let zero_uint2_const_id = self.writer.get_constant_composite(
303            LookupType::Local(LocalType::Numeric(NumericType::Vector {
304                size: crate::VectorSize::Bi,
305                scalar: crate::Scalar::U32,
306            })),
307            &[zero_uint_const_id, zero_uint_const_id],
308        );
309        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
310        let max_uint_const_id = self
311            .writer
312            .get_constant_scalar(crate::Literal::U32(u32::MAX));
313        let max_uint2_const_id = self.writer.get_constant_composite(
314            LookupType::Local(LocalType::Numeric(NumericType::Vector {
315                size: crate::VectorSize::Bi,
316                scalar: crate::Scalar::U32,
317            })),
318            &[max_uint_const_id, max_uint_const_id],
319        );
320
321        let loop_counter_var_id = self.gen_id();
322        if self.writer.flags.contains(WriterFlags::DEBUG) {
323            self.writer
324                .debugs
325                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
326        }
327        let var = super::LocalVariable {
328            id: loop_counter_var_id,
329            instruction: Instruction::variable(
330                uint2_ptr_type_id,
331                loop_counter_var_id,
332                spirv::StorageClass::Function,
333                Some(max_uint2_const_id),
334            ),
335        };
336        self.function.force_loop_bounding_vars.push(var);
337
338        let break_if_block = self.gen_id();
339
340        self.function
341            .consume(block, Instruction::branch(break_if_block));
342        block = Block::new(break_if_block);
343
344        // Load the current loop counter value from its variable. We use a vec2<u32> to
345        // simulate a 64-bit counter.
346        let load_id = self.gen_id();
347        block.body.push(Instruction::load(
348            uint2_type_id,
349            load_id,
350            loop_counter_var_id,
351            None,
352        ));
353
354        // If both the high and low u32s have reached 0 then break. ie
355        // if (all(eq(loop_counter, vec2(0)))) { break; }
356        let eq_id = self.gen_id();
357        block.body.push(Instruction::binary(
358            spirv::Op::IEqual,
359            bool2_type_id,
360            eq_id,
361            zero_uint2_const_id,
362            load_id,
363        ));
364        let all_eq_id = self.gen_id();
365        block.body.push(Instruction::relational(
366            spirv::Op::All,
367            bool_type_id,
368            all_eq_id,
369            eq_id,
370        ));
371
372        let inc_counter_block_id = self.gen_id();
373        block.body.push(Instruction::selection_merge(
374            inc_counter_block_id,
375            spirv::SelectionControl::empty(),
376        ));
377        self.function.consume(
378            block,
379            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
380        );
381        block = Block::new(inc_counter_block_id);
382
383        // To simulate a 64-bit counter we always decrement the low u32, and decrement
384        // the high u32 when the low u32 overflows. ie
385        // counter -= vec2(select(0u, 1u, counter.y == 0), 1u);
386        // Count down from u32::MAX rather than up from 0 to avoid hang on
387        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
388        let low_id = self.gen_id();
389        block.body.push(Instruction::composite_extract(
390            uint_type_id,
391            low_id,
392            load_id,
393            &[1],
394        ));
395        let low_overflow_id = self.gen_id();
396        block.body.push(Instruction::binary(
397            spirv::Op::IEqual,
398            bool_type_id,
399            low_overflow_id,
400            low_id,
401            zero_uint_const_id,
402        ));
403        let carry_bit_id = self.gen_id();
404        block.body.push(Instruction::select(
405            uint_type_id,
406            carry_bit_id,
407            low_overflow_id,
408            one_uint_const_id,
409            zero_uint_const_id,
410        ));
411        let decrement_id = self.gen_id();
412        block.body.push(Instruction::composite_construct(
413            uint2_type_id,
414            decrement_id,
415            &[carry_bit_id, one_uint_const_id],
416        ));
417        let result_id = self.gen_id();
418        block.body.push(Instruction::binary(
419            spirv::Op::ISub,
420            uint2_type_id,
421            result_id,
422            load_id,
423            decrement_id,
424        ));
425        block
426            .body
427            .push(Instruction::store(loop_counter_var_id, result_id, None));
428
429        block
430    }
431
432    /// If `pointer` refers to an access chain that contains a dynamic indexing
433    /// of a two-row matrix in the [`Uniform`] address space, write code to
434    /// access the value returning the ID of the result. Else return None.
435    ///
436    /// Two-row matrices in the uniform address space will have been declared
437    /// using a alternative std140 layout compatible type, where each column is
438    /// a member of a containing struct. As a result, SPIR-V is unable to access
439    /// its columns with a non-constant index. To work around this limitation
440    /// this function will call [`Self::write_checked_load()`] to load the
441    /// matrix itself, which handles conversion from the std140 compatible type
442    /// to the real matrix type. It then calls a [`wrapper function`] to obtain
443    /// the correct column from the matrix, and possibly extracts a component
444    /// from the vector too.
445    ///
446    /// [`Uniform`]: crate::AddressSpace::Uniform
447    /// [`wrapper function`]: super::Writer::write_wrapped_matcx2_get_column
448    fn maybe_write_uniform_matcx2_dynamic_access(
449        &mut self,
450        pointer: Handle<crate::Expression>,
451        block: &mut Block,
452    ) -> Result<Option<Word>, Error> {
453        // If this access chain contains a dynamic matrix access, `pointer` is
454        // either a pointer to a vector (the column) or a scalar (a component
455        // within the column). In either case grab the pointer to the column,
456        // and remember the component index if there is one. If `pointer`
457        // points to any other type we're not interested.
458        let (column_pointer, component_index) = match self.fun_info[pointer]
459            .ty
460            .inner_with(&self.ir_module.types)
461            .pointer_base_type()
462        {
463            Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
464                crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
465                    crate::Expression::Access { base, index } => {
466                        (base, Some(GuardedIndex::Expression(index)))
467                    }
468                    crate::Expression::AccessIndex { base, index } => {
469                        (base, Some(GuardedIndex::Known(index)))
470                    }
471                    _ => return Ok(None),
472                },
473                crate::TypeInner::Vector { .. } => (pointer, None),
474                _ => return Ok(None),
475            },
476            None => return Ok(None),
477        };
478
479        // Ensure the column is accessed with a dynamic index (i.e.
480        // `Expression::Access`), and grab the pointer to the matrix.
481        let crate::Expression::Access {
482            base: matrix_pointer,
483            index: column_index,
484        } = self.ir_function.expressions[column_pointer]
485        else {
486            return Ok(None);
487        };
488
489        // Ensure the matrix pointer is in the uniform address space.
490        let crate::TypeInner::Pointer {
491            base: matrix_pointer_base_type,
492            space: crate::AddressSpace::Uniform,
493        } = *self.fun_info[matrix_pointer]
494            .ty
495            .inner_with(&self.ir_module.types)
496        else {
497            return Ok(None);
498        };
499
500        // Ensure the matrix pointer actually points to a Cx2 matrix.
501        let crate::TypeInner::Matrix {
502            columns,
503            rows: rows @ crate::VectorSize::Bi,
504            scalar,
505        } = self.ir_module.types[matrix_pointer_base_type].inner
506        else {
507            return Ok(None);
508        };
509
510        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
511            columns,
512            rows,
513            scalar,
514        });
515        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
516        let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
517        let get_column_function_id = self.writer.wrapped_functions
518            [&WrappedFunction::MatCx2GetColumn {
519                r#type: matrix_pointer_base_type,
520            }];
521
522        let matrix_load_id = self.write_checked_load(
523            matrix_pointer,
524            block,
525            AccessTypeAdjustment::None,
526            matrix_type_id,
527        )?;
528
529        // Naga IR allows the index to be either an I32 or U32 but our wrapper
530        // function expects a U32 argument, so convert it if required.
531        let column_index_id = match *self.fun_info[column_index]
532            .ty
533            .inner_with(&self.ir_module.types)
534        {
535            crate::TypeInner::Scalar(crate::Scalar {
536                kind: crate::ScalarKind::Uint,
537                ..
538            }) => self.cached[column_index],
539            crate::TypeInner::Scalar(crate::Scalar {
540                kind: crate::ScalarKind::Sint,
541                ..
542            }) => {
543                let cast_id = self.gen_id();
544                let u32_type_id = self.writer.get_u32_type_id();
545                block.body.push(Instruction::unary(
546                    spirv::Op::Bitcast,
547                    u32_type_id,
548                    cast_id,
549                    self.cached[column_index],
550                ));
551                cast_id
552            }
553            _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
554        };
555        let column_id = self.gen_id();
556        block.body.push(Instruction::function_call(
557            column_type_id,
558            column_id,
559            get_column_function_id,
560            &[matrix_load_id, column_index_id],
561        ));
562        let result_id = match component_index {
563            Some(index) => self.write_vector_access(
564                component_type_id,
565                column_pointer,
566                Some(column_id),
567                index,
568                block,
569            )?,
570            None => column_id,
571        };
572
573        Ok(Some(result_id))
574    }
575
576    /// If `pointer` refers to two-row matrix that is a member of a struct in
577    /// the [`Uniform`] address space, write code to load the matrix returning
578    /// the ID of the result. Else return None.
579    ///
580    /// Two-row matrices that are struct members in the uniform address space
581    /// will have been decomposed such that the struct contains a separate
582    /// vector member for each column of the matrix. This function will load
583    /// each column separately from the containing struct, then composite them
584    /// into the real matrix type.
585    ///
586    /// [`Uniform`]: crate::AddressSpace::Uniform
587    fn maybe_write_load_uniform_matcx2_struct_member(
588        &mut self,
589        pointer: Handle<crate::Expression>,
590        block: &mut Block,
591    ) -> Result<Option<Word>, Error> {
592        // Check this is a uniform address space pointer to a two-row matrix.
593        let crate::TypeInner::Pointer {
594            base: matrix_type,
595            space: space @ crate::AddressSpace::Uniform,
596        } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
597        else {
598            return Ok(None);
599        };
600
601        let crate::TypeInner::Matrix {
602            columns,
603            rows: rows @ crate::VectorSize::Bi,
604            scalar,
605        } = self.ir_module.types[matrix_type].inner
606        else {
607            return Ok(None);
608        };
609
610        // Check this is a struct member. Note struct members can only be
611        // accessed with `AccessIndex`.
612        let crate::Expression::AccessIndex {
613            base: struct_pointer,
614            index: member_index,
615        } = self.ir_function.expressions[pointer]
616        else {
617            return Ok(None);
618        };
619
620        let crate::TypeInner::Pointer {
621            base: struct_type, ..
622        } = *self.fun_info[struct_pointer]
623            .ty
624            .inner_with(&self.ir_module.types)
625        else {
626            return Ok(None);
627        };
628
629        let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
630            return Ok(None);
631        };
632
633        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
634            columns,
635            rows,
636            scalar,
637        });
638        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
639        let column_pointer_type_id =
640            self.get_pointer_type_id(column_type_id, map_storage_class(space));
641        let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
642            [member_index as usize];
643        let column_indices = (0..columns as u32)
644            .map(|c| self.get_index_constant(column0_index + c))
645            .collect::<ArrayVec<_, 4>>();
646
647        // Load each column from the struct, then composite into the real
648        // matrix type.
649        let load_mat_from_struct =
650            |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
651                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
652                for index in &column_indices {
653                    let column_pointer_id = id_gen.next();
654                    block.body.push(Instruction::access_chain(
655                        column_pointer_type_id,
656                        column_pointer_id,
657                        struct_pointer_id,
658                        &[*index],
659                    ));
660                    let column_id = id_gen.next();
661                    block.body.push(Instruction::load(
662                        column_type_id,
663                        column_id,
664                        column_pointer_id,
665                        None,
666                    ));
667                    column_ids.push(column_id);
668                }
669                let result_id = id_gen.next();
670                block.body.push(Instruction::composite_construct(
671                    matrix_type_id,
672                    result_id,
673                    &column_ids,
674                ));
675                result_id
676            };
677
678        let result_id = match self.write_access_chain(
679            struct_pointer,
680            block,
681            AccessTypeAdjustment::UseStd140CompatType,
682        )? {
683            ExpressionPointer::Ready { pointer_id } => {
684                load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
685            }
686            ExpressionPointer::Conditional { condition, access } => self
687                .write_conditional_indexed_load(
688                    matrix_type_id,
689                    condition,
690                    block,
691                    |id_gen, block| {
692                        let pointer_id = access.result_id.unwrap();
693                        block.body.push(access);
694                        load_mat_from_struct(pointer_id, id_gen, block)
695                    },
696                ),
697        };
698
699        Ok(Some(result_id))
700    }
701
702    /// Cache an expression for a value.
703    pub(super) fn cache_expression_value(
704        &mut self,
705        expr_handle: Handle<crate::Expression>,
706        block: &mut Block,
707    ) -> Result<(), Error> {
708        let is_named_expression = self
709            .ir_function
710            .named_expressions
711            .contains_key(&expr_handle);
712
713        if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
714            return Ok(());
715        }
716
717        let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
718        let id = match self.ir_function.expressions[expr_handle] {
719            crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
720            crate::Expression::Constant(handle) => {
721                let init = self.ir_module.constants[handle].init;
722                self.writer.constant_ids[init]
723            }
724            crate::Expression::Override(_) => return Err(Error::Override),
725            crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
726            crate::Expression::Compose { ty, ref components } => {
727                self.temp_list.clear();
728                if self.expression_constness.is_const(expr_handle) {
729                    self.temp_list.extend(
730                        crate::proc::flatten_compose(
731                            ty,
732                            components,
733                            &self.ir_function.expressions,
734                            &self.ir_module.types,
735                        )
736                        .map(|component| self.cached[component]),
737                    );
738                    self.writer
739                        .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
740                } else {
741                    self.temp_list
742                        .extend(components.iter().map(|&component| self.cached[component]));
743
744                    let id = self.gen_id();
745                    block.body.push(Instruction::composite_construct(
746                        result_type_id,
747                        id,
748                        &self.temp_list,
749                    ));
750                    id
751                }
752            }
753            crate::Expression::Splat { size, value } => {
754                let value_id = self.cached[value];
755                let components = &[value_id; 4][..size as usize];
756
757                if self.expression_constness.is_const(expr_handle) {
758                    let ty = self
759                        .writer
760                        .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
761                    self.writer.get_constant_composite(ty, components)
762                } else {
763                    let id = self.gen_id();
764                    block.body.push(Instruction::composite_construct(
765                        result_type_id,
766                        id,
767                        components,
768                    ));
769                    id
770                }
771            }
772            crate::Expression::Access { base, index } => {
773                let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
774                match *base_ty_inner {
775                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
776                        // When we have a chain of `Access` and `AccessIndex` expressions
777                        // operating on pointers, we want to generate a single
778                        // `OpAccessChain` instruction for the whole chain. Put off
779                        // generating any code for this until we find the `Expression`
780                        // that actually dereferences the pointer.
781                        0
782                    }
783                    _ if self.function.spilled_accesses.contains(base) => {
784                        // As far as Naga IR is concerned, this expression does not yield
785                        // a pointer (we just checked, above), but this backend spilled it
786                        // to a temporary variable, so SPIR-V thinks we're accessing it
787                        // via a pointer.
788
789                        // Since the base expression was spilled, mark this access to it
790                        // as spilled, too.
791                        self.function.spilled_accesses.insert(expr_handle);
792                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
793                    }
794                    crate::TypeInner::Vector { .. } => self.write_vector_access(
795                        result_type_id,
796                        base,
797                        None,
798                        GuardedIndex::Expression(index),
799                        block,
800                    )?,
801                    crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
802                        // See if `index` is known at compile time.
803                        match GuardedIndex::from_expression(
804                            index,
805                            &self.ir_function.expressions,
806                            self.ir_module,
807                        ) {
808                            GuardedIndex::Known(value) => {
809                                // If `index` is known and in bounds, we can just use
810                                // `OpCompositeExtract`.
811                                //
812                                // At the moment, validation rejects programs if this
813                                // index is out of bounds, so we don't need bounds checks.
814                                // However, that rejection is incorrect, since WGSL says
815                                // that `let` bindings are not constant expressions
816                                // (#6396). So eventually we will need to emulate bounds
817                                // checks here.
818                                let id = self.gen_id();
819                                let base_id = self.cached[base];
820                                block.body.push(Instruction::composite_extract(
821                                    result_type_id,
822                                    id,
823                                    base_id,
824                                    &[value],
825                                ));
826                                id
827                            }
828                            GuardedIndex::Expression(_) => {
829                                // We are subscripting an array or matrix that is not
830                                // behind a pointer, using an index computed at runtime.
831                                // SPIR-V has no instructions that do this, so the best we
832                                // can do is spill the value to a new temporary variable,
833                                // at which point we can get a pointer to that and just
834                                // use `OpAccessChain` in the usual way.
835                                self.spill_to_internal_variable(base, block);
836
837                                // Since the base was spilled, mark this access to it as
838                                // spilled, too.
839                                self.function.spilled_accesses.insert(expr_handle);
840                                self.maybe_access_spilled_composite(
841                                    expr_handle,
842                                    block,
843                                    result_type_id,
844                                )?
845                            }
846                        }
847                    }
848                    crate::TypeInner::BindingArray {
849                        base: binding_type, ..
850                    } => {
851                        // Only binding arrays in the `Handle` address space will take
852                        // this path, since we handled the `Pointer` case above.
853                        let result_id = match self.write_access_chain(
854                            expr_handle,
855                            block,
856                            AccessTypeAdjustment::IntroducePointer(
857                                spirv::StorageClass::UniformConstant,
858                            ),
859                        )? {
860                            ExpressionPointer::Ready { pointer_id } => pointer_id,
861                            ExpressionPointer::Conditional { .. } => {
862                                return Err(Error::FeatureNotImplemented(
863                                    "Texture array out-of-bounds handling",
864                                ));
865                            }
866                        };
867
868                        let binding_type_id = self.get_handle_type_id(binding_type);
869
870                        let load_id = self.gen_id();
871                        block.body.push(Instruction::load(
872                            binding_type_id,
873                            load_id,
874                            result_id,
875                            None,
876                        ));
877
878                        // Subsequent image operations require the image/sampler to be decorated as NonUniform
879                        // if the image/sampler binding array was accessed with a non-uniform index
880                        // see VUID-RuntimeSpirv-NonUniform-06274
881                        if self.fun_info[index].uniformity.non_uniform_result.is_some() {
882                            self.writer
883                                .decorate_non_uniform_binding_array_access(load_id)?;
884                        }
885
886                        load_id
887                    }
888                    ref other => {
889                        log::error!(
890                            "Unable to access base {:?} of type {:?}",
891                            self.ir_function.expressions[base],
892                            other
893                        );
894                        return Err(Error::Validation(
895                            "only vectors and arrays may be dynamically indexed by value",
896                        ));
897                    }
898                }
899            }
900            crate::Expression::AccessIndex { base, index } => {
901                match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
902                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
903                        // When we have a chain of `Access` and `AccessIndex` expressions
904                        // operating on pointers, we want to generate a single
905                        // `OpAccessChain` instruction for the whole chain. Put off
906                        // generating any code for this until we find the `Expression`
907                        // that actually dereferences the pointer.
908                        0
909                    }
910                    _ if self.function.spilled_accesses.contains(base) => {
911                        // As far as Naga IR is concerned, this expression does not yield
912                        // a pointer (we just checked, above), but this backend spilled it
913                        // to a temporary variable, so SPIR-V thinks we're accessing it
914                        // via a pointer.
915
916                        // Since the base expression was spilled, mark this access to it
917                        // as spilled, too.
918                        self.function.spilled_accesses.insert(expr_handle);
919                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
920                    }
921                    crate::TypeInner::Vector { .. }
922                    | crate::TypeInner::Matrix { .. }
923                    | crate::TypeInner::Array { .. }
924                    | crate::TypeInner::Struct { .. } => {
925                        // We never need bounds checks here: dynamically sized arrays can
926                        // only appear behind pointers, and are thus handled by the
927                        // `Pointer` / `ValuePointer` case above. Everything else's size is
928                        // statically known and checked in validation.
929                        let id = self.gen_id();
930                        let base_id = self.cached[base];
931                        block.body.push(Instruction::composite_extract(
932                            result_type_id,
933                            id,
934                            base_id,
935                            &[index],
936                        ));
937                        id
938                    }
939                    crate::TypeInner::BindingArray {
940                        base: binding_type, ..
941                    } => {
942                        // Only binding arrays in the `Handle` address space will take
943                        // this path, since we handled the `Pointer` case above.
944                        let result_id = match self.write_access_chain(
945                            expr_handle,
946                            block,
947                            AccessTypeAdjustment::IntroducePointer(
948                                spirv::StorageClass::UniformConstant,
949                            ),
950                        )? {
951                            ExpressionPointer::Ready { pointer_id } => pointer_id,
952                            ExpressionPointer::Conditional { .. } => {
953                                return Err(Error::FeatureNotImplemented(
954                                    "Texture array out-of-bounds handling",
955                                ));
956                            }
957                        };
958
959                        let binding_type_id = self.get_handle_type_id(binding_type);
960
961                        let load_id = self.gen_id();
962                        block.body.push(Instruction::load(
963                            binding_type_id,
964                            load_id,
965                            result_id,
966                            None,
967                        ));
968
969                        load_id
970                    }
971                    ref other => {
972                        log::error!("Unable to access index of {other:?}");
973                        return Err(Error::FeatureNotImplemented("access index for type"));
974                    }
975                }
976            }
977            crate::Expression::GlobalVariable(handle) => {
978                self.writer.global_variables[handle].access_id
979            }
980            crate::Expression::Swizzle {
981                size,
982                vector,
983                pattern,
984            } => {
985                let vector_id = self.cached[vector];
986                self.temp_list.clear();
987                for &sc in pattern[..size as usize].iter() {
988                    self.temp_list.push(sc as Word);
989                }
990                let id = self.gen_id();
991                block.body.push(Instruction::vector_shuffle(
992                    result_type_id,
993                    id,
994                    vector_id,
995                    vector_id,
996                    &self.temp_list,
997                ));
998                id
999            }
1000            crate::Expression::Unary { op, expr } => {
1001                let id = self.gen_id();
1002                let expr_id = self.cached[expr];
1003                let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1004
1005                let spirv_op = match op {
1006                    crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1007                        Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1008                        Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1009                        _ => return Err(Error::Validation("Unexpected kind for negation")),
1010                    },
1011                    crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1012                    crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1013                };
1014
1015                block
1016                    .body
1017                    .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1018                id
1019            }
1020            crate::Expression::Binary { op, left, right } => {
1021                let id = self.gen_id();
1022                let left_id = self.cached[left];
1023                let right_id = self.cached[right];
1024                let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1025                let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1026
1027                if let Some(function_id) =
1028                    self.writer
1029                        .wrapped_functions
1030                        .get(&WrappedFunction::BinaryOp {
1031                            op,
1032                            left_type_id,
1033                            right_type_id,
1034                        })
1035                {
1036                    block.body.push(Instruction::function_call(
1037                        result_type_id,
1038                        id,
1039                        *function_id,
1040                        &[left_id, right_id],
1041                    ));
1042                } else {
1043                    let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1044                    let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1045
1046                    let left_dimension = get_dimension(left_ty_inner);
1047                    let right_dimension = get_dimension(right_ty_inner);
1048
1049                    let mut reverse_operands = false;
1050
1051                    let spirv_op = match op {
1052                        crate::BinaryOperator::Add => match *left_ty_inner {
1053                            crate::TypeInner::Scalar(scalar)
1054                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1055                                crate::ScalarKind::Float => spirv::Op::FAdd,
1056                                _ => spirv::Op::IAdd,
1057                            },
1058                            crate::TypeInner::Matrix {
1059                                columns,
1060                                rows,
1061                                scalar,
1062                            } => {
1063                                //TODO: why not just rely on `FAdd` for matrices?
1064                                self.write_matrix_matrix_column_op(
1065                                    block,
1066                                    id,
1067                                    result_type_id,
1068                                    left_id,
1069                                    right_id,
1070                                    columns,
1071                                    rows,
1072                                    scalar.width,
1073                                    spirv::Op::FAdd,
1074                                );
1075
1076                                self.cached[expr_handle] = id;
1077                                return Ok(());
1078                            }
1079                            crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1080                            _ => unimplemented!(),
1081                        },
1082                        crate::BinaryOperator::Subtract => match *left_ty_inner {
1083                            crate::TypeInner::Scalar(scalar)
1084                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1085                                crate::ScalarKind::Float => spirv::Op::FSub,
1086                                _ => spirv::Op::ISub,
1087                            },
1088                            crate::TypeInner::Matrix {
1089                                columns,
1090                                rows,
1091                                scalar,
1092                            } => {
1093                                self.write_matrix_matrix_column_op(
1094                                    block,
1095                                    id,
1096                                    result_type_id,
1097                                    left_id,
1098                                    right_id,
1099                                    columns,
1100                                    rows,
1101                                    scalar.width,
1102                                    spirv::Op::FSub,
1103                                );
1104
1105                                self.cached[expr_handle] = id;
1106                                return Ok(());
1107                            }
1108                            crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1109                            _ => unimplemented!(),
1110                        },
1111                        crate::BinaryOperator::Multiply => {
1112                            match (left_dimension, right_dimension) {
1113                                (Dimension::Scalar, Dimension::Vector) => {
1114                                    self.write_vector_scalar_mult(
1115                                        block,
1116                                        id,
1117                                        result_type_id,
1118                                        right_id,
1119                                        left_id,
1120                                        right_ty_inner,
1121                                    );
1122
1123                                    self.cached[expr_handle] = id;
1124                                    return Ok(());
1125                                }
1126                                (Dimension::Vector, Dimension::Scalar) => {
1127                                    self.write_vector_scalar_mult(
1128                                        block,
1129                                        id,
1130                                        result_type_id,
1131                                        left_id,
1132                                        right_id,
1133                                        left_ty_inner,
1134                                    );
1135
1136                                    self.cached[expr_handle] = id;
1137                                    return Ok(());
1138                                }
1139                                (Dimension::Vector, Dimension::Matrix) => {
1140                                    spirv::Op::VectorTimesMatrix
1141                                }
1142                                (Dimension::Matrix, Dimension::Scalar)
1143                                | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1144                                    spirv::Op::MatrixTimesScalar
1145                                }
1146                                (Dimension::Scalar, Dimension::Matrix)
1147                                | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1148                                    reverse_operands = true;
1149                                    spirv::Op::MatrixTimesScalar
1150                                }
1151                                (Dimension::Matrix, Dimension::Vector) => {
1152                                    spirv::Op::MatrixTimesVector
1153                                }
1154                                (Dimension::Matrix, Dimension::Matrix) => {
1155                                    spirv::Op::MatrixTimesMatrix
1156                                }
1157                                (Dimension::Vector, Dimension::Vector)
1158                                | (Dimension::Scalar, Dimension::Scalar)
1159                                    if left_ty_inner.scalar_kind()
1160                                        == Some(crate::ScalarKind::Float) =>
1161                                {
1162                                    spirv::Op::FMul
1163                                }
1164                                (Dimension::Vector, Dimension::Vector)
1165                                | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1166                                (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1167                                //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication
1168                                | (Dimension::CooperativeMatrix, _)
1169                                | (_, Dimension::CooperativeMatrix) => {
1170                                    unimplemented!()
1171                                }
1172                            }
1173                        }
1174                        crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1175                            Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1176                            Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1177                            Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1178                            _ => unimplemented!(),
1179                        },
1180                        crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1181                            Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1182                            Some(crate::ScalarKind::Sint) => {
1183                                // Signed `%` is always lowered to `a - b * (a / b)`
1184                                // through a wrapper function (see
1185                                // `write_wrapped_functions`), because `OpSRem` with a
1186                                // negative operand is poison in the Vulkan SPIR-V
1187                                // environment without `VK_KHR_maintenance8`. So this raw
1188                                // path is never reached for signed modulo.
1189                                unreachable!("signed modulo must be lowered via the wrapped path")
1190                            }
1191                            Some(crate::ScalarKind::Uint) => {
1192                                assert!(!self.writer.emit_int_div_checks);
1193                                spirv::Op::UMod
1194                            }
1195                            _ => unimplemented!(),
1196                        },
1197                        crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1198                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1199                                spirv::Op::IEqual
1200                            }
1201                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1202                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1203                            _ => unimplemented!(),
1204                        },
1205                        crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1206                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1207                                spirv::Op::INotEqual
1208                            }
1209                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1210                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1211                            _ => unimplemented!(),
1212                        },
1213                        crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1214                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1215                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1216                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1217                            _ => unimplemented!(),
1218                        },
1219                        crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1220                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1221                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1222                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1223                            _ => unimplemented!(),
1224                        },
1225                        crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1226                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1227                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1228                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1229                            _ => unimplemented!(),
1230                        },
1231                        crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1232                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1233                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1234                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1235                            _ => unimplemented!(),
1236                        },
1237                        crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1238                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1239                            _ => spirv::Op::BitwiseAnd,
1240                        },
1241                        crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1242                        crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1243                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1244                            _ => spirv::Op::BitwiseOr,
1245                        },
1246                        crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1247                        crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1248                        crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1249                        crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1250                            Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1251                            Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1252                            _ => unimplemented!(),
1253                        },
1254                    };
1255
1256                    block.body.push(Instruction::binary(
1257                        spirv_op,
1258                        result_type_id,
1259                        id,
1260                        if reverse_operands { right_id } else { left_id },
1261                        if reverse_operands { left_id } else { right_id },
1262                    ));
1263                }
1264                id
1265            }
1266            crate::Expression::Math {
1267                fun,
1268                arg,
1269                arg1,
1270                arg2,
1271                arg3,
1272            } => {
1273                use crate::MathFunction as Mf;
1274                enum MathOp {
                        // Describes how the math function being compiled is lowered.
                        // The `match fun` that follows produces one of these per function.
                        // NOTE(review): the code that consumes the chosen variant lies
                        // past this point in the function.

                        // The function corresponds to a single extended instruction,
                        // identified by its `GlslStd450Op` opcode (e.g. `FAbs`, `SMin`).
1275                    Ext(spirv::GlslStd450Op),
                        // The match arm has already built the complete `Instruction`
                        // itself, e.g. the `OpCopyObject` used for `abs` on unsigned
                        // integers, or the final min of an integer `clamp`.
1276                    Custom(Instruction),
1277                }
1278
1279                let arg0_id = self.cached[arg];
1280                let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1281                let arg_scalar_kind = arg_ty.scalar_kind();
1282                let arg1_id = match arg1 {
1283                    Some(handle) => self.cached[handle],
1284                    None => 0,
1285                };
1286                let arg2_id = match arg2 {
1287                    Some(handle) => self.cached[handle],
1288                    None => 0,
1289                };
1290                let arg3_id = match arg3 {
1291                    Some(handle) => self.cached[handle],
1292                    None => 0,
1293                };
1294
1295                let id = self.gen_id();
1296                let math_op = match fun {
1297                    // comparison
1298                    Mf::Abs => {
1299                        match arg_scalar_kind {
1300                            Some(crate::ScalarKind::Float) => {
1301                                MathOp::Ext(spirv::GlslStd450Op::FAbs)
1302                            }
1303                            Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GlslStd450Op::SAbs),
1304                            Some(crate::ScalarKind::Uint) => {
1305                                MathOp::Custom(Instruction::unary(
1306                                    spirv::Op::CopyObject, // do nothing
1307                                    result_type_id,
1308                                    id,
1309                                    arg0_id,
1310                                ))
1311                            }
1312                            other => unimplemented!("Unexpected abs({:?})", other),
1313                        }
1314                    }
1315                    Mf::Min => MathOp::Ext(match arg_scalar_kind {
1316                        Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMin,
1317                        Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMin,
1318                        Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMin,
1319                        other => unimplemented!("Unexpected min({:?})", other),
1320                    }),
1321                    Mf::Max => MathOp::Ext(match arg_scalar_kind {
1322                        Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMax,
1323                        Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMax,
1324                        Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMax,
1325                        other => unimplemented!("Unexpected max({:?})", other),
1326                    }),
1327                    Mf::Clamp => match arg_scalar_kind {
1328                        // Clamp is undefined if min > max. In practice this means it can use a median-of-three
1329                        // instruction to determine the value. This is fine according to the WGSL spec for float
1330                        // clamp, but integer clamp _must_ use min-max. As such we write out min/max.
1331                        Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GlslStd450Op::FClamp),
1332                        Some(_) => {
1333                            let (min_op, max_op) = match arg_scalar_kind {
1334                                Some(crate::ScalarKind::Sint) => {
1335                                    (spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::SMax)
1336                                }
1337                                Some(crate::ScalarKind::Uint) => {
1338                                    (spirv::GlslStd450Op::UMin, spirv::GlslStd450Op::UMax)
1339                                }
1340                                _ => unreachable!(),
1341                            };
1342
1343                            let max_id = self.gen_id();
1344                            block.body.push(Instruction::ext_inst_gl_op(
1345                                self.writer.gl450_ext_inst_id,
1346                                max_op,
1347                                result_type_id,
1348                                max_id,
1349                                &[arg0_id, arg1_id],
1350                            ));
1351
1352                            MathOp::Custom(Instruction::ext_inst_gl_op(
1353                                self.writer.gl450_ext_inst_id,
1354                                min_op,
1355                                result_type_id,
1356                                id,
1357                                &[max_id, arg2_id],
1358                            ))
1359                        }
1360                        other => unimplemented!("Unexpected max({:?})", other),
1361                    },
1362                    Mf::Saturate => {
1363                        let (maybe_size, scalar) = match *arg_ty {
1364                            crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1365                            crate::TypeInner::Scalar(scalar) => (None, scalar),
1366                            ref other => unimplemented!("Unexpected saturate({:?})", other),
1367                        };
1368                        let scalar = crate::Scalar::float(scalar.width);
1369                        let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1370                        let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1371
1372                        if let Some(size) = maybe_size {
1373                            let ty =
1374                                LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1375
1376                            self.temp_list.clear();
1377                            self.temp_list.resize(size as _, arg1_id);
1378
1379                            arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1380
1381                            self.temp_list.fill(arg2_id);
1382
1383                            arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1384                        }
1385
1386                        MathOp::Custom(Instruction::ext_inst_gl_op(
1387                            self.writer.gl450_ext_inst_id,
1388                            spirv::GlslStd450Op::FClamp,
1389                            result_type_id,
1390                            id,
1391                            &[arg0_id, arg1_id, arg2_id],
1392                        ))
1393                    }
1394                    // trigonometry
1395                    Mf::Sin => MathOp::Ext(spirv::GlslStd450Op::Sin),
1396                    Mf::Sinh => MathOp::Ext(spirv::GlslStd450Op::Sinh),
1397                    Mf::Asin => MathOp::Ext(spirv::GlslStd450Op::Asin),
1398                    Mf::Cos => MathOp::Ext(spirv::GlslStd450Op::Cos),
1399                    Mf::Cosh => MathOp::Ext(spirv::GlslStd450Op::Cosh),
1400                    Mf::Acos => MathOp::Ext(spirv::GlslStd450Op::Acos),
1401                    Mf::Tan => MathOp::Ext(spirv::GlslStd450Op::Tan),
1402                    Mf::Tanh => MathOp::Ext(spirv::GlslStd450Op::Tanh),
1403                    Mf::Atan => MathOp::Ext(spirv::GlslStd450Op::Atan),
1404                    Mf::Atan2 => MathOp::Ext(spirv::GlslStd450Op::Atan2),
1405                    Mf::Asinh => MathOp::Ext(spirv::GlslStd450Op::Asinh),
1406                    Mf::Acosh => MathOp::Ext(spirv::GlslStd450Op::Acosh),
1407                    Mf::Atanh => MathOp::Ext(spirv::GlslStd450Op::Atanh),
1408                    Mf::Radians => MathOp::Ext(spirv::GlslStd450Op::Radians),
1409                    Mf::Degrees => MathOp::Ext(spirv::GlslStd450Op::Degrees),
1410                    // decomposition
1411                    Mf::Ceil => MathOp::Ext(spirv::GlslStd450Op::Ceil),
1412                    Mf::Round => MathOp::Ext(spirv::GlslStd450Op::RoundEven),
1413                    Mf::Floor => MathOp::Ext(spirv::GlslStd450Op::Floor),
1414                    Mf::Fract => MathOp::Ext(spirv::GlslStd450Op::Fract),
1415                    Mf::Trunc => MathOp::Ext(spirv::GlslStd450Op::Trunc),
1416                    Mf::Modf => MathOp::Ext(spirv::GlslStd450Op::ModfStruct),
1417                    Mf::Frexp => MathOp::Ext(spirv::GlslStd450Op::FrexpStruct),
1418                    Mf::Ldexp => MathOp::Ext(spirv::GlslStd450Op::Ldexp),
1419                    // geometry
1420                    Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1421                        crate::TypeInner::Vector {
1422                            scalar:
1423                                crate::Scalar {
1424                                    kind: crate::ScalarKind::Float,
1425                                    ..
1426                                },
1427                            ..
1428                        } => MathOp::Custom(Instruction::binary(
1429                            spirv::Op::Dot,
1430                            result_type_id,
1431                            id,
1432                            arg0_id,
1433                            arg1_id,
1434                        )),
1435                        // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
1436                        crate::TypeInner::Vector { size, .. } => {
1437                            self.write_dot_product(
1438                                id,
1439                                result_type_id,
1440                                arg0_id,
1441                                arg1_id,
1442                                size as u32,
1443                                block,
1444                                |result_id, composite_id, index| {
1445                                    Instruction::composite_extract(
1446                                        result_type_id,
1447                                        result_id,
1448                                        composite_id,
1449                                        &[index],
1450                                    )
1451                                },
1452                            );
1453                            self.cached[expr_handle] = id;
1454                            return Ok(());
1455                        }
1456                        _ => unreachable!(
1457                            "Correct TypeInner for dot product should be already validated"
1458                        ),
1459                    },
1460                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1461                        if self
1462                            .writer
1463                            .require_all(&[
1464                                spirv::Capability::DotProduct,
1465                                spirv::Capability::DotProductInput4x8BitPacked,
1466                            ])
1467                            .is_ok()
1468                        {
1469                            // Write optimized code using `PackedVectorFormat4x8Bit`.
1470                            if self.writer.lang_version() < (1, 6) {
1471                                // SPIR-V 1.6 supports the required capabilities natively, so the extension
1472                                // is only required for earlier versions. See right column of
1473                                // <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
1474                                self.writer.use_extension("SPV_KHR_integer_dot_product");
1475                            }
1476
1477                            let op = match fun {
1478                                Mf::Dot4I8Packed => spirv::Op::SDot,
1479                                Mf::Dot4U8Packed => spirv::Op::UDot,
1480                                _ => unreachable!(),
1481                            };
1482
1483                            block.body.push(Instruction::ternary(
1484                                op,
1485                                result_type_id,
1486                                id,
1487                                arg0_id,
1488                                arg1_id,
1489                                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1490                            ));
1491                        } else {
1492                            // Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
1493                            let (extract_op, arg0_id, arg1_id) = match fun {
1494                                Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1495                                Mf::Dot4I8Packed => {
1496                                    // Convert both packed arguments to signed integers so that we can apply the
1497                                    // `BitFieldSExtract` operation on them in `write_dot_product` below.
1498                                    let new_arg0_id = self.gen_id();
1499                                    block.body.push(Instruction::unary(
1500                                        spirv::Op::Bitcast,
1501                                        result_type_id,
1502                                        new_arg0_id,
1503                                        arg0_id,
1504                                    ));
1505
1506                                    let new_arg1_id = self.gen_id();
1507                                    block.body.push(Instruction::unary(
1508                                        spirv::Op::Bitcast,
1509                                        result_type_id,
1510                                        new_arg1_id,
1511                                        arg1_id,
1512                                    ));
1513
1514                                    (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1515                                }
1516                                _ => unreachable!(),
1517                            };
1518
1519                            let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1520
1521                            const VEC_LENGTH: u8 = 4;
1522                            let bit_shifts: [_; VEC_LENGTH as usize] =
1523                                core::array::from_fn(|index| {
1524                                    self.writer
1525                                        .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1526                                });
1527
1528                            self.write_dot_product(
1529                                id,
1530                                result_type_id,
1531                                arg0_id,
1532                                arg1_id,
1533                                VEC_LENGTH as Word,
1534                                block,
1535                                |result_id, composite_id, index| {
1536                                    Instruction::ternary(
1537                                        extract_op,
1538                                        result_type_id,
1539                                        result_id,
1540                                        composite_id,
1541                                        bit_shifts[index as usize],
1542                                        eight,
1543                                    )
1544                                },
1545                            );
1546                        }
1547
1548                        self.cached[expr_handle] = id;
1549                        return Ok(());
1550                    }
1551                    Mf::Outer => MathOp::Custom(Instruction::binary(
1552                        spirv::Op::OuterProduct,
1553                        result_type_id,
1554                        id,
1555                        arg0_id,
1556                        arg1_id,
1557                    )),
1558                    Mf::Cross => MathOp::Ext(spirv::GlslStd450Op::Cross),
1559                    Mf::Distance => MathOp::Ext(spirv::GlslStd450Op::Distance),
1560                    Mf::Length => MathOp::Ext(spirv::GlslStd450Op::Length),
1561                    Mf::Normalize => MathOp::Ext(spirv::GlslStd450Op::Normalize),
1562                    Mf::FaceForward => MathOp::Ext(spirv::GlslStd450Op::FaceForward),
1563                    Mf::Reflect => MathOp::Ext(spirv::GlslStd450Op::Reflect),
1564                    Mf::Refract => MathOp::Ext(spirv::GlslStd450Op::Refract),
1565                    // exponent
1566                    Mf::Exp => MathOp::Ext(spirv::GlslStd450Op::Exp),
1567                    Mf::Exp2 => MathOp::Ext(spirv::GlslStd450Op::Exp2),
1568                    Mf::Log => MathOp::Ext(spirv::GlslStd450Op::Log),
1569                    Mf::Log2 => MathOp::Ext(spirv::GlslStd450Op::Log2),
1570                    Mf::Pow => MathOp::Ext(spirv::GlslStd450Op::Pow),
1571                    // computational
1572                    Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1573                        Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FSign,
1574                        Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SSign,
1575                        other => unimplemented!("Unexpected sign({:?})", other),
1576                    }),
1577                    Mf::Fma => MathOp::Ext(spirv::GlslStd450Op::Fma),
1578                    Mf::Mix => {
1579                        let selector = arg2.unwrap();
1580                        let selector_ty =
1581                            self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1582                        match (arg_ty, selector_ty) {
1583                            // if the selector is a scalar, we need to splat it
1584                            (
1585                                &crate::TypeInner::Vector { size, .. },
1586                                &crate::TypeInner::Scalar(scalar),
1587                            ) => {
1588                                let selector_type_id =
1589                                    self.get_numeric_type_id(NumericType::Vector { size, scalar });
1590                                self.temp_list.clear();
1591                                self.temp_list.resize(size as usize, arg2_id);
1592
1593                                let selector_id = self.gen_id();
1594                                block.body.push(Instruction::composite_construct(
1595                                    selector_type_id,
1596                                    selector_id,
1597                                    &self.temp_list,
1598                                ));
1599
1600                                MathOp::Custom(Instruction::ext_inst_gl_op(
1601                                    self.writer.gl450_ext_inst_id,
1602                                    spirv::GlslStd450Op::FMix,
1603                                    result_type_id,
1604                                    id,
1605                                    &[arg0_id, arg1_id, selector_id],
1606                                ))
1607                            }
1608                            _ => MathOp::Ext(spirv::GlslStd450Op::FMix),
1609                        }
1610                    }
1611                    Mf::Step => MathOp::Ext(spirv::GlslStd450Op::Step),
1612                    Mf::SmoothStep => MathOp::Ext(spirv::GlslStd450Op::SmoothStep),
1613                    Mf::Sqrt => MathOp::Ext(spirv::GlslStd450Op::Sqrt),
1614                    Mf::InverseSqrt => MathOp::Ext(spirv::GlslStd450Op::InverseSqrt),
1615                    Mf::Inverse => MathOp::Ext(spirv::GlslStd450Op::MatrixInverse),
1616                    Mf::Transpose => MathOp::Custom(Instruction::unary(
1617                        spirv::Op::Transpose,
1618                        result_type_id,
1619                        id,
1620                        arg0_id,
1621                    )),
1622                    Mf::Determinant => MathOp::Ext(spirv::GlslStd450Op::Determinant),
1623                    Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1624                        spirv::Op::QuantizeToF16,
1625                        result_type_id,
1626                        id,
1627                        arg0_id,
1628                    )),
1629                    Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1630                        spirv::Op::BitReverse,
1631                        result_type_id,
1632                        id,
1633                        arg0_id,
1634                    )),
1635                    Mf::CountTrailingZeros => {
1636                        let uint_id = match *arg_ty {
1637                            crate::TypeInner::Vector { size, scalar } => {
1638                                let ty =
1639                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1640
1641                                self.temp_list.clear();
1642                                self.temp_list.resize(
1643                                    size as _,
1644                                    self.writer
1645                                        .get_constant_scalar_with(scalar.width * 8, scalar)?,
1646                                );
1647
1648                                self.writer.get_constant_composite(ty, &self.temp_list)
1649                            }
1650                            crate::TypeInner::Scalar(scalar) => self
1651                                .writer
1652                                .get_constant_scalar_with(scalar.width * 8, scalar)?,
1653                            _ => unreachable!(),
1654                        };
1655
1656                        let lsb_id = self.gen_id();
1657                        block.body.push(Instruction::ext_inst_gl_op(
1658                            self.writer.gl450_ext_inst_id,
1659                            spirv::GlslStd450Op::FindILsb,
1660                            result_type_id,
1661                            lsb_id,
1662                            &[arg0_id],
1663                        ));
1664
1665                        MathOp::Custom(Instruction::ext_inst_gl_op(
1666                            self.writer.gl450_ext_inst_id,
1667                            spirv::GlslStd450Op::UMin,
1668                            result_type_id,
1669                            id,
1670                            &[uint_id, lsb_id],
1671                        ))
1672                    }
1673                    Mf::CountLeadingZeros => {
1674                        let (int_type_id, int_id, width) = match *arg_ty {
1675                            crate::TypeInner::Vector { size, scalar } => {
1676                                let ty =
1677                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1678
1679                                self.temp_list.clear();
1680                                self.temp_list.resize(
1681                                    size as _,
1682                                    self.writer
1683                                        .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1684                                );
1685
1686                                (
1687                                    self.get_type_id(ty),
1688                                    self.writer.get_constant_composite(ty, &self.temp_list),
1689                                    scalar.width,
1690                                )
1691                            }
1692                            crate::TypeInner::Scalar(scalar) => (
1693                                self.get_numeric_type_id(NumericType::Scalar(scalar)),
1694                                self.writer
1695                                    .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1696                                scalar.width,
1697                            ),
1698                            _ => unreachable!(),
1699                        };
1700
1701                        if width != 4 {
1702                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1703                        };
1704
1705                        let msb_id = self.gen_id();
1706                        block.body.push(Instruction::ext_inst_gl_op(
1707                            self.writer.gl450_ext_inst_id,
1708                            if width != 4 {
1709                                spirv::GlslStd450Op::FindILsb
1710                            } else {
1711                                spirv::GlslStd450Op::FindUMsb
1712                            },
1713                            int_type_id,
1714                            msb_id,
1715                            &[arg0_id],
1716                        ));
1717
1718                        MathOp::Custom(Instruction::binary(
1719                            spirv::Op::ISub,
1720                            result_type_id,
1721                            id,
1722                            int_id,
1723                            msb_id,
1724                        ))
1725                    }
1726                    Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1727                        spirv::Op::BitCount,
1728                        result_type_id,
1729                        id,
1730                        arg0_id,
1731                    )),
1732                    Mf::ExtractBits => {
1733                        let op = match arg_scalar_kind {
1734                            Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1735                            Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1736                            other => unimplemented!("Unexpected sign({:?})", other),
1737                        };
1738
1739                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
1740                        // to sanitize the offset and count first. If we don't do this, AMD and Intel
1741                        // will return out-of-spec values if the extracted range is not within the bit width.
1742                        //
1743                        // This encodes the exact formula specified by the wgsl spec:
1744                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1745                        //
1746                        // w = sizeof(x) * 8
1747                        // o = min(offset, w)
1748                        // tmp = w - o
1749                        // c = min(count, tmp)
1750                        //
1751                        // bitfieldExtract(x, o, c)
1752
1753                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1754                        let width_constant = self
1755                            .writer
1756                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1757
1758                        let u32_type =
1759                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1760
1761                        // o = min(offset, w)
1762                        let offset_id = self.gen_id();
1763                        block.body.push(Instruction::ext_inst_gl_op(
1764                            self.writer.gl450_ext_inst_id,
1765                            spirv::GlslStd450Op::UMin,
1766                            u32_type,
1767                            offset_id,
1768                            &[arg1_id, width_constant],
1769                        ));
1770
1771                        // tmp = w - o
1772                        let max_count_id = self.gen_id();
1773                        block.body.push(Instruction::binary(
1774                            spirv::Op::ISub,
1775                            u32_type,
1776                            max_count_id,
1777                            width_constant,
1778                            offset_id,
1779                        ));
1780
1781                        // c = min(count, tmp)
1782                        let count_id = self.gen_id();
1783                        block.body.push(Instruction::ext_inst_gl_op(
1784                            self.writer.gl450_ext_inst_id,
1785                            spirv::GlslStd450Op::UMin,
1786                            u32_type,
1787                            count_id,
1788                            &[arg2_id, max_count_id],
1789                        ));
1790
1791                        MathOp::Custom(Instruction::ternary(
1792                            op,
1793                            result_type_id,
1794                            id,
1795                            arg0_id,
1796                            offset_id,
1797                            count_id,
1798                        ))
1799                    }
1800                    Mf::InsertBits => {
1801                        // The behavior of InsertBits has the same undefined behavior as ExtractBits.
1802
1803                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1804                        let width_constant = self
1805                            .writer
1806                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1807
1808                        let u32_type =
1809                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1810
1811                        // o = min(offset, w)
1812                        let offset_id = self.gen_id();
1813                        block.body.push(Instruction::ext_inst_gl_op(
1814                            self.writer.gl450_ext_inst_id,
1815                            spirv::GlslStd450Op::UMin,
1816                            u32_type,
1817                            offset_id,
1818                            &[arg2_id, width_constant],
1819                        ));
1820
1821                        // tmp = w - o
1822                        let max_count_id = self.gen_id();
1823                        block.body.push(Instruction::binary(
1824                            spirv::Op::ISub,
1825                            u32_type,
1826                            max_count_id,
1827                            width_constant,
1828                            offset_id,
1829                        ));
1830
1831                        // c = min(count, tmp)
1832                        let count_id = self.gen_id();
1833                        block.body.push(Instruction::ext_inst_gl_op(
1834                            self.writer.gl450_ext_inst_id,
1835                            spirv::GlslStd450Op::UMin,
1836                            u32_type,
1837                            count_id,
1838                            &[arg3_id, max_count_id],
1839                        ));
1840
1841                        MathOp::Custom(Instruction::quaternary(
1842                            spirv::Op::BitFieldInsert,
1843                            result_type_id,
1844                            id,
1845                            arg0_id,
1846                            arg1_id,
1847                            offset_id,
1848                            count_id,
1849                        ))
1850                    }
1851                    Mf::FirstTrailingBit => MathOp::Ext(spirv::GlslStd450Op::FindILsb),
1852                    Mf::FirstLeadingBit => {
1853                        if arg_ty.scalar_width() == Some(4) {
1854                            let thing = match arg_scalar_kind {
1855                                Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::FindUMsb,
1856                                Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::FindSMsb,
1857                                other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1858                            };
1859                            MathOp::Ext(thing)
1860                        } else {
1861                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1862                        }
1863                    }
1864                    Mf::Pack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm4x8),
1865                    Mf::Pack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm4x8),
1866                    Mf::Pack2x16float => MathOp::Ext(spirv::GlslStd450Op::PackHalf2x16),
1867                    Mf::Pack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm2x16),
1868                    Mf::Pack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm2x16),
1869                    fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1870                        let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1871                        let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1872
1873                        let last_instruction =
1874                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1875                                self.write_pack4x8_optimized(
1876                                    block,
1877                                    result_type_id,
1878                                    arg0_id,
1879                                    id,
1880                                    is_signed,
1881                                    should_clamp,
1882                                )
1883                            } else {
1884                                self.write_pack4x8_polyfill(
1885                                    block,
1886                                    result_type_id,
1887                                    arg0_id,
1888                                    id,
1889                                    is_signed,
1890                                    should_clamp,
1891                                )
1892                            };
1893
1894                        MathOp::Custom(last_instruction)
1895                    }
1896                    Mf::Unpack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm4x8),
1897                    Mf::Unpack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm4x8),
1898                    Mf::Unpack2x16float => MathOp::Ext(spirv::GlslStd450Op::UnpackHalf2x16),
1899                    Mf::Unpack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm2x16),
1900                    Mf::Unpack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm2x16),
1901                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1902                        let is_signed = matches!(fun, Mf::Unpack4xI8);
1903
1904                        let last_instruction =
1905                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1906                                self.write_unpack4x8_optimized(
1907                                    block,
1908                                    result_type_id,
1909                                    arg0_id,
1910                                    id,
1911                                    is_signed,
1912                                )
1913                            } else {
1914                                self.write_unpack4x8_polyfill(
1915                                    block,
1916                                    result_type_id,
1917                                    arg0_id,
1918                                    id,
1919                                    is_signed,
1920                                )
1921                            };
1922
1923                        MathOp::Custom(last_instruction)
1924                    }
1925                };
1926
1927                block.body.push(match math_op {
1928                    MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1929                        self.writer.gl450_ext_inst_id,
1930                        op,
1931                        result_type_id,
1932                        id,
1933                        &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1934                    ),
1935                    MathOp::Custom(inst) => inst,
1936                });
1937                id
1938            }
1939            crate::Expression::LocalVariable(variable) => {
1940                if let Some(rq_tracker) = self
1941                    .function
1942                    .ray_query_initialization_tracker_variables
1943                    .get(&variable)
1944                {
1945                    self.ray_query_tracker_expr.insert(
1946                        expr_handle,
1947                        super::RayQueryTrackers {
1948                            initialized_tracker: rq_tracker.id,
1949                            t_max_tracker: self
1950                                .function
1951                                .ray_query_t_max_tracker_variables
1952                                .get(&variable)
1953                                .expect("Both trackers are set at the same time.")
1954                                .id,
1955                        },
1956                    );
1957                }
1958                self.function.variables[&variable].id
1959            }
1960            crate::Expression::Load { pointer } => {
1961                self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1962            }
1963            crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1964            crate::Expression::CallResult(_)
1965            | crate::Expression::AtomicResult { .. }
1966            | crate::Expression::WorkGroupUniformLoadResult { .. }
1967            | crate::Expression::RayQueryProceedResult
1968            | crate::Expression::SubgroupBallotResult
1969            | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1970            crate::Expression::As {
1971                expr,
1972                kind,
1973                convert,
1974            } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1975            crate::Expression::ImageLoad {
1976                image,
1977                coordinate,
1978                array_index,
1979                sample,
1980                level,
1981            } => self.write_image_load(
1982                result_type_id,
1983                image,
1984                coordinate,
1985                array_index,
1986                level,
1987                sample,
1988                block,
1989            )?,
1990            crate::Expression::ImageSample {
1991                image,
1992                sampler,
1993                gather,
1994                coordinate,
1995                array_index,
1996                offset,
1997                level,
1998                depth_ref,
1999                clamp_to_edge,
2000            } => self.write_image_sample(
2001                result_type_id,
2002                image,
2003                sampler,
2004                gather,
2005                coordinate,
2006                array_index,
2007                offset,
2008                level,
2009                depth_ref,
2010                clamp_to_edge,
2011                block,
2012            )?,
2013            crate::Expression::Select {
2014                condition,
2015                accept,
2016                reject,
2017            } => {
2018                let id = self.gen_id();
2019                let mut condition_id = self.cached[condition];
2020                let accept_id = self.cached[accept];
2021                let reject_id = self.cached[reject];
2022
2023                let condition_ty = self.fun_info[condition]
2024                    .ty
2025                    .inner_with(&self.ir_module.types);
2026                let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2027
2028                if let (
2029                    &crate::TypeInner::Scalar(
2030                        condition_scalar @ crate::Scalar {
2031                            kind: crate::ScalarKind::Bool,
2032                            ..
2033                        },
2034                    ),
2035                    &crate::TypeInner::Vector { size, .. },
2036                ) = (condition_ty, object_ty)
2037                {
2038                    self.temp_list.clear();
2039                    self.temp_list.resize(size as usize, condition_id);
2040
2041                    let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2042                        size,
2043                        scalar: condition_scalar,
2044                    });
2045
2046                    let id = self.gen_id();
2047                    block.body.push(Instruction::composite_construct(
2048                        bool_vector_type_id,
2049                        id,
2050                        &self.temp_list,
2051                    ));
2052                    condition_id = id
2053                }
2054
2055                let instruction =
2056                    Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2057                block.body.push(instruction);
2058                id
2059            }
2060            crate::Expression::Derivative { axis, ctrl, expr } => {
2061                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2062                match ctrl {
2063                    Ctrl::Coarse | Ctrl::Fine => {
2064                        self.writer.require_any(
2065                            "DerivativeControl",
2066                            &[spirv::Capability::DerivativeControl],
2067                        )?;
2068                    }
2069                    Ctrl::None => {}
2070                }
2071                let id = self.gen_id();
2072                let expr_id = self.cached[expr];
2073                let op = match (axis, ctrl) {
2074                    (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2075                    (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2076                    (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2077                    (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2078                    (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2079                    (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2080                    (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2081                    (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2082                    (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2083                };
2084                block
2085                    .body
2086                    .push(Instruction::derivative(op, result_type_id, id, expr_id));
2087                id
2088            }
2089            crate::Expression::ImageQuery { image, query } => {
2090                self.write_image_query(result_type_id, image, query, block)?
2091            }
2092            crate::Expression::Relational { fun, argument } => {
2093                use crate::RelationalFunction as Rf;
2094                let arg_id = self.cached[argument];
2095                let op = match fun {
2096                    Rf::All => spirv::Op::All,
2097                    Rf::Any => spirv::Op::Any,
2098                    Rf::IsNan => spirv::Op::IsNan,
2099                    Rf::IsInf => spirv::Op::IsInf,
2100                };
2101                let id = self.gen_id();
2102                block
2103                    .body
2104                    .push(Instruction::relational(op, result_type_id, id, arg_id));
2105                id
2106            }
2107            crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2108            crate::Expression::RayQueryGetIntersection { query, committed } => {
2109                let query_id = self.cached[query];
2110                let init_tracker_id = *self
2111                    .ray_query_tracker_expr
2112                    .get(&query)
2113                    .expect("not a cached ray query");
2114                let func_id = self
2115                    .writer
2116                    .write_ray_query_get_intersection_function(committed, self.ir_module);
2117                let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2118                let intersection_type_id = self.get_handle_type_id(ray_intersection);
2119                let id = self.gen_id();
2120                block.body.push(Instruction::function_call(
2121                    intersection_type_id,
2122                    id,
2123                    func_id,
2124                    &[query_id, init_tracker_id.initialized_tracker],
2125                ));
2126                id
2127            }
2128            crate::Expression::RayQueryVertexPositions { query, committed } => {
2129                self.writer.require_any(
2130                    "RayQueryVertexPositions",
2131                    &[spirv::Capability::RayQueryPositionFetchKHR],
2132                )?;
2133                self.write_ray_query_return_vertex_position(query, block, committed)
2134            }
2135            crate::Expression::CooperativeLoad { ref data, .. } => {
2136                self.writer.require_any(
2137                    "CooperativeMatrix",
2138                    &[spirv::Capability::CooperativeMatrixKHR],
2139                )?;
2140                let layout = if data.row_major {
2141                    spirv::CooperativeMatrixLayout::RowMajorKHR
2142                } else {
2143                    spirv::CooperativeMatrixLayout::ColumnMajorKHR
2144                };
2145                let layout_id = self.get_index_constant(layout as u32);
2146                let stride_id = self.cached[data.stride];
2147                match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2148                    ExpressionPointer::Ready { pointer_id } => {
2149                        let id = self.gen_id();
2150                        block.body.push(Instruction::coop_load(
2151                            result_type_id,
2152                            id,
2153                            pointer_id,
2154                            layout_id,
2155                            stride_id,
2156                        ));
2157                        id
2158                    }
2159                    ExpressionPointer::Conditional { condition, access } => self
2160                        .write_conditional_indexed_load(
2161                            result_type_id,
2162                            condition,
2163                            block,
2164                            |id_gen, block| {
2165                                let pointer_id = access.result_id.unwrap();
2166                                block.body.push(access);
2167                                let id = id_gen.next();
2168                                block.body.push(Instruction::coop_load(
2169                                    result_type_id,
2170                                    id,
2171                                    pointer_id,
2172                                    layout_id,
2173                                    stride_id,
2174                                ));
2175                                id
2176                            },
2177                        ),
2178                }
2179            }
2180            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2181                self.writer.require_any(
2182                    "CooperativeMatrix",
2183                    &[spirv::Capability::CooperativeMatrixKHR],
2184                )?;
2185                let a_id = self.cached[a];
2186                let b_id = self.cached[b];
2187                let c_id = self.cached[c];
2188                let id = self.gen_id();
2189                block.body.push(Instruction::coop_mul_add(
2190                    result_type_id,
2191                    id,
2192                    a_id,
2193                    b_id,
2194                    c_id,
2195                ));
2196                id
2197            }
2198        };
2199
2200        self.cached[expr_handle] = id;
2201        Ok(())
2202    }
2203
2204    /// Helper which focuses on generating the `As` expressions and the various conversions
2205    /// that need to happen because of that.
2206    fn write_as_expression(
2207        &mut self,
2208        expr: Handle<crate::Expression>,
2209        convert: Option<u8>,
2210        kind: crate::ScalarKind,
2211
2212        block: &mut Block,
2213        result_type_id: u32,
2214    ) -> Result<u32, Error> {
2215        use crate::ScalarKind as Sk;
2216        let expr_id = self.cached[expr];
2217        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2218
2219        // Matrix casts needs special treatment in SPIR-V, as the cast functions
2220        // can take vectors or scalars, but not matrices. In order to cast a matrix
2221        // we need to cast each column of the matrix individually and construct a new
2222        // matrix from the converted columns.
2223        if let crate::TypeInner::Matrix {
2224            columns,
2225            rows,
2226            scalar,
2227        } = *ty
2228        {
2229            let Some(convert) = convert else {
2230                // No conversion needs to be done, passes through.
2231                return Ok(expr_id);
2232            };
2233
2234            if convert == scalar.width {
2235                // No conversion needs to be done, passes through.
2236                return Ok(expr_id);
2237            }
2238
2239            if kind != Sk::Float {
2240                // Only float conversions are supported for matrices.
2241                return Err(Error::Validation("Matrices must be floats"));
2242            }
2243
2244            // Type of each extracted column
2245            let column_src_ty =
2246                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2247                    size: rows,
2248                    scalar,
2249                })));
2250
2251            // Type of the column after conversion
2252            let column_dst_ty =
2253                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2254                    size: rows,
2255                    scalar: crate::Scalar {
2256                        kind,
2257                        width: convert,
2258                    },
2259                })));
2260
2261            let mut components = ArrayVec::<Word, 4>::new();
2262
2263            for column in 0..columns as usize {
2264                let column_id = self.gen_id();
2265                block.body.push(Instruction::composite_extract(
2266                    column_src_ty,
2267                    column_id,
2268                    expr_id,
2269                    &[column as u32],
2270                ));
2271
2272                let column_conv_id = self.gen_id();
2273                block.body.push(Instruction::unary(
2274                    spirv::Op::FConvert,
2275                    column_dst_ty,
2276                    column_conv_id,
2277                    column_id,
2278                ));
2279
2280                components.push(column_conv_id);
2281            }
2282
2283            let construct_id = self.gen_id();
2284
2285            block.body.push(Instruction::composite_construct(
2286                result_type_id,
2287                construct_id,
2288                &components,
2289            ));
2290
2291            return Ok(construct_id);
2292        }
2293
2294        let (src_scalar, src_size) = match *ty {
2295            crate::TypeInner::Scalar(scalar) => (scalar, None),
2296            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2297            ref other => {
2298                log::error!("As source {other:?}");
2299                return Err(Error::Validation("Unexpected Expression::As source"));
2300            }
2301        };
2302
2303        enum Cast {
2304            Identity(Word),
2305            Unary(spirv::Op, Word),
2306            Binary(spirv::Op, Word, Word),
2307            Ternary(spirv::Op, Word, Word, Word),
2308        }
2309        let cast = match (src_scalar.kind, kind, convert) {
2310            // Filter out identity casts. Some Adreno drivers are
2311            // confused by no-op OpBitCast instructions.
2312            (src_kind, kind, convert)
2313                if src_kind == kind
2314                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
2315            {
2316                Cast::Identity(expr_id)
2317            }
2318            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2319            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2320            // casting to a bool - generate `OpXxxNotEqual`
2321            (_, Sk::Bool, Some(_)) => {
2322                let op = match src_scalar.kind {
2323                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2324                    Sk::Float => spirv::Op::FUnordNotEqual,
2325                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2326                };
2327                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2328                let zero_id = match src_size {
2329                    Some(size) => {
2330                        let ty = LocalType::Numeric(NumericType::Vector {
2331                            size,
2332                            scalar: src_scalar,
2333                        })
2334                        .into();
2335
2336                        self.temp_list.clear();
2337                        self.temp_list.resize(size as _, zero_scalar_id);
2338
2339                        self.writer.get_constant_composite(ty, &self.temp_list)
2340                    }
2341                    None => zero_scalar_id,
2342                };
2343
2344                Cast::Binary(op, expr_id, zero_id)
2345            }
2346            // casting from a bool - generate `OpSelect`
2347            (Sk::Bool, _, Some(dst_width)) => {
2348                let dst_scalar = crate::Scalar {
2349                    kind,
2350                    width: dst_width,
2351                };
2352                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2353                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2354                let (accept_id, reject_id) = match src_size {
2355                    Some(size) => {
2356                        let ty = LocalType::Numeric(NumericType::Vector {
2357                            size,
2358                            scalar: dst_scalar,
2359                        })
2360                        .into();
2361
2362                        self.temp_list.clear();
2363                        self.temp_list.resize(size as _, zero_scalar_id);
2364
2365                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2366
2367                        self.temp_list.fill(one_scalar_id);
2368
2369                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2370
2371                        (vec1_id, vec0_id)
2372                    }
2373                    None => (one_scalar_id, zero_scalar_id),
2374                };
2375
2376                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2377            }
2378            // Avoid undefined behaviour when casting from a float to integer
2379            // when the value is out of range for the target type. Additionally
2380            // ensure we clamp to the correct value as per the WGSL spec.
2381            //
2382            // https://www.w3.org/TR/WGSL/#floating-point-conversion:
2383            // * If X is exactly representable in the target type T, then the
2384            //   result is that value.
2385            // * Otherwise, the result is the value in T closest to
2386            //   truncate(X) and also exactly representable in the original
2387            //   floating point type.
2388            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2389                let dst_scalar = crate::Scalar { kind, width };
2390                let (min, max) =
2391                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2392                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2393
2394                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2395                    None => const_id,
2396                    Some(size) => {
2397                        let constituent_ids = [const_id; crate::VectorSize::MAX];
2398                        writer.get_constant_composite(
2399                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
2400                                size,
2401                                scalar: src_scalar,
2402                            })),
2403                            &constituent_ids[..size as usize],
2404                        )
2405                    }
2406                };
2407                let min_const_id = self.writer.get_constant_scalar(min);
2408                let min_const_id = maybe_splat_const(self.writer, min_const_id);
2409                let max_const_id = self.writer.get_constant_scalar(max);
2410                let max_const_id = maybe_splat_const(self.writer, max_const_id);
2411
2412                let clamp_id = self.gen_id();
2413                block.body.push(Instruction::ext_inst_gl_op(
2414                    self.writer.gl450_ext_inst_id,
2415                    spirv::GlslStd450Op::FClamp,
2416                    expr_type_id,
2417                    clamp_id,
2418                    &[expr_id, min_const_id, max_const_id],
2419                ));
2420
2421                let op = match dst_scalar.kind {
2422                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2423                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2424                    _ => unreachable!(),
2425                };
2426                Cast::Unary(op, clamp_id)
2427            }
2428            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2429                Cast::Unary(spirv::Op::FConvert, expr_id)
2430            }
2431            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2432            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2433                Cast::Unary(spirv::Op::SConvert, expr_id)
2434            }
2435            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2436            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2437                Cast::Unary(spirv::Op::UConvert, expr_id)
2438            }
2439            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2440                Cast::Unary(spirv::Op::SConvert, expr_id)
2441            }
2442            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2443                Cast::Unary(spirv::Op::UConvert, expr_id)
2444            }
2445            // We assume it's either an identity cast, or int-uint.
2446            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2447        };
2448        Ok(match cast {
2449            Cast::Identity(expr) => expr,
2450            Cast::Unary(op, op1) => {
2451                let id = self.gen_id();
2452                block
2453                    .body
2454                    .push(Instruction::unary(op, result_type_id, id, op1));
2455                id
2456            }
2457            Cast::Binary(op, op1, op2) => {
2458                let id = self.gen_id();
2459                block
2460                    .body
2461                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
2462                id
2463            }
2464            Cast::Ternary(op, op1, op2, op3) => {
2465                let id = self.gen_id();
2466                block
2467                    .body
2468                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2469                id
2470            }
2471        })
2472    }
2473
2474    /// Build an `OpAccessChain` instruction.
2475    ///
2476    /// Emit any needed bounds-checking expressions to `block`.
2477    ///
2478    /// Give the `OpAccessChain` a result type based on `expr_handle`, adjusted
2479    /// according to `type_adjustment`; see the documentation for
2480    /// [`AccessTypeAdjustment`] for details.
2481    ///
2482    /// On success, the return value is an [`ExpressionPointer`] value; see the
2483    /// documentation for that type.
2484    fn write_access_chain(
2485        &mut self,
2486        mut expr_handle: Handle<crate::Expression>,
2487        block: &mut Block,
2488        type_adjustment: AccessTypeAdjustment,
2489    ) -> Result<ExpressionPointer, Error> {
2490        let result_type_id = {
2491            let resolution = &self.fun_info[expr_handle].ty;
2492            match type_adjustment {
2493                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
2494                AccessTypeAdjustment::IntroducePointer(class) => {
2495                    self.writer.get_resolution_pointer_id(resolution, class)
2496                }
2497                AccessTypeAdjustment::UseStd140CompatType => {
2498                    match *resolution.inner_with(&self.ir_module.types) {
2499                        crate::TypeInner::Pointer {
2500                            base,
2501                            space: space @ crate::AddressSpace::Uniform,
2502                        } => self.writer.get_pointer_type_id(
2503                            self.writer.std140_compat_uniform_types[&base].type_id,
2504                            map_storage_class(space),
2505                        ),
2506                        _ => unreachable!(
2507                            "`UseStd140CompatType` must only be used with uniform pointer types"
2508                        ),
2509                    }
2510                }
2511            }
2512        };
2513
2514        // The id of the boolean `and` of all dynamic bounds checks up to this point.
2515        //
2516        // See `extend_bounds_check_condition_chain` for a full explanation.
2517        let mut accumulated_checks = None;
2518
2519        // Is true if we are accessing into a binding array with a non-uniform index.
2520        let mut is_non_uniform_binding_array = false;
2521
2522        // The index value if the previously encountered expression was an
2523        // `AccessIndex` of a matrix which has been decomposed into individual
2524        // column vectors directly in the containing struct. The subsequent
2525        // iteration will append the correct index to the list for accessing
2526        // said column from the containing struct.
2527        let mut prev_decomposed_matrix_index = None;
2528
2529        self.temp_list.clear();
2530        let root_id = loop {
2531            // If `expr_handle` was spilled, then the temporary variable has exactly
2532            // the value we want to start from.
2533            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
2534                // The root id of the `OpAccessChain` instruction is the temporary
2535                // variable we spilled the composite to.
2536                break spilled.id;
2537            }
2538
2539            expr_handle = match self.ir_function.expressions[expr_handle] {
2540                crate::Expression::Access { base, index } => {
2541                    is_non_uniform_binding_array |=
2542                        self.is_nonuniform_binding_array_access(base, index);
2543
2544                    let index = GuardedIndex::Expression(index);
2545                    let index_id =
2546                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
2547                    self.temp_list.push(index_id);
2548
2549                    base
2550                }
2551                crate::Expression::AccessIndex { base, index } => {
2552                    // Decide whether we're indexing a struct (bounds checks
2553                    // forbidden) or anything else (bounds checks required).
2554                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
2555                    let mut base_ty_handle = self.fun_info[base].ty.handle();
2556                    let mut pointer_space = None;
2557                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
2558                        base_ty = &self.ir_module.types[base].inner;
2559                        base_ty_handle = Some(base);
2560                        pointer_space = Some(space);
2561                    }
2562                    match *base_ty {
2563                        // When indexing a struct bounds checks are forbidden. If accessing the
2564                        // struct through a uniform address space pointer, where the struct has
2565                        // been declared with an alternative std140 compatible layout, we must use
2566                        // the remapped member index. Additionally if the previous iteration was
2567                        // accessing a column of a matrix member which has been decomposed directly
2568                        // into the struct, we must ensure we access the correct column.
2569                        crate::TypeInner::Struct { .. } => {
2570                            let index = match base_ty_handle.and_then(|handle| {
2571                                self.writer.std140_compat_uniform_types.get(&handle)
2572                            }) {
2573                                Some(std140_type_info)
2574                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
2575                                {
2576                                    std140_type_info.member_indices[index as usize]
2577                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
2578                                }
2579                                _ => index,
2580                            };
2581                            let index_id = self.get_index_constant(index);
2582                            self.temp_list.push(index_id);
2583                        }
2584                        // Bounds checks are not required when indexing a matrix. If indexing a
2585                        // two-row matrix contained within a struct through a uniform address space
2586                        // pointer then the matrix' columns will have been decomposed directly into
2587                        // the containing struct. We skip adding an index to the list on this
2588                        // iteration and instead adjust the index on the next iteration when
2589                        // accessing the struct member.
2590                        _ if is_uniform_matcx2_struct_member_access(
2591                            self.ir_function,
2592                            self.fun_info,
2593                            self.ir_module,
2594                            base,
2595                        ) =>
2596                        {
2597                            assert!(prev_decomposed_matrix_index.is_none());
2598                            prev_decomposed_matrix_index = Some(index);
2599                        }
2600                        _ => {
2601                            // `index` is constant, so this can't possibly require
2602                            // setting `is_nonuniform_binding_array_access`.
2603
2604                            // Even though the index value is statically known, `base`
2605                            // may be a runtime-sized array, so we still need to go
2606                            // through the bounds check process.
2607                            let index_id = self.write_access_chain_index(
2608                                base,
2609                                GuardedIndex::Known(index),
2610                                &mut accumulated_checks,
2611                                block,
2612                            )?;
2613                            self.temp_list.push(index_id);
2614                        }
2615                    }
2616                    base
2617                }
2618                crate::Expression::GlobalVariable(handle) => {
2619                    let gv = &self.writer.global_variables[handle];
2620                    break gv.access_id;
2621                }
2622                crate::Expression::LocalVariable(variable) => {
2623                    let local_var = &self.function.variables[&variable];
2624                    break local_var.id;
2625                }
2626                crate::Expression::FunctionArgument(index) => {
2627                    break self.function.parameter_id(index);
2628                }
2629                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
2630            }
2631        };
2632
2633        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
2634            (
2635                root_id,
2636                ExpressionPointer::Ready {
2637                    pointer_id: root_id,
2638                },
2639            )
2640        } else {
2641            self.temp_list.reverse();
2642            let pointer_id = self.gen_id();
2643            let access =
2644                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
2645
2646            // If we generated some bounds checks, we need to leave it to our
2647            // caller to generate the branch, the access, the load or store, and
2648            // the zero value (for loads). Otherwise, we can emit the access
2649            // ourselves, and just hand them the id of the pointer.
2650            let expr_pointer = match accumulated_checks {
2651                Some(condition) => ExpressionPointer::Conditional { condition, access },
2652                None => {
2653                    block.body.push(access);
2654                    ExpressionPointer::Ready { pointer_id }
2655                }
2656            };
2657            (pointer_id, expr_pointer)
2658        };
2659        // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform
2660        // if the binding array was accessed with a non-uniform index
2661        // see VUID-RuntimeSpirv-NonUniform-06274
2662        if is_non_uniform_binding_array {
2663            self.writer
2664                .decorate_non_uniform_binding_array_access(pointer_id)?;
2665        }
2666
2667        Ok(expr_pointer)
2668    }
2669
2670    fn is_nonuniform_binding_array_access(
2671        &mut self,
2672        base: Handle<crate::Expression>,
2673        index: Handle<crate::Expression>,
2674    ) -> bool {
2675        let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2676        else {
2677            return false;
2678        };
2679
2680        // The access chain needs to be decorated as NonUniform
2681        // see VUID-RuntimeSpirv-NonUniform-06274
2682        let gvar = &self.ir_module.global_variables[var_handle];
2683        let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2684            return false;
2685        };
2686
2687        self.fun_info[index].uniformity.non_uniform_result.is_some()
2688    }
2689
2690    /// Compute a single index operand to an `OpAccessChain` instruction.
2691    ///
2692    /// Given that we are indexing `base` with `index`, apply the appropriate
2693    /// bounds check policies, emitting code to `block` to clamp `index` or
2694    /// determine whether it's in bounds. Return the SPIR-V instruction id of
2695    /// the index value we should actually use.
2696    ///
2697    /// Extend `accumulated_checks` to include the results of any needed bounds
2698    /// checks. See [`BlockContext::extend_bounds_check_condition_chain`].
2699    fn write_access_chain_index(
2700        &mut self,
2701        base: Handle<crate::Expression>,
2702        index: GuardedIndex,
2703        accumulated_checks: &mut Option<Word>,
2704        block: &mut Block,
2705    ) -> Result<Word, Error> {
2706        match self.write_bounds_check(base, index, block)? {
2707            BoundsCheckResult::KnownInBounds(known_index) => {
2708                // Even if the index is known, `OpAccessChain`
2709                // requires expression operands, not literals.
2710                let scalar = crate::Literal::U32(known_index);
2711                Ok(self.writer.get_constant_scalar(scalar))
2712            }
2713            BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2714            BoundsCheckResult::Conditional {
2715                condition_id: condition,
2716                index_id: index,
2717            } => {
2718                self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2719
2720                // Use the index from the `Access` expression unchanged.
2721                Ok(index)
2722            }
2723        }
2724    }
2725
2726    /// Add a condition to a chain of bounds checks.
2727    ///
2728    /// As we build an `OpAccessChain` instruction govered by
2729    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`], we accumulate a chain of
2730    /// dynamic bounds checks, one for each index in the chain, which must all
2731    /// be true for that `OpAccessChain`'s execution to be well-defined. This
2732    /// function adds the boolean instruction id `comparison_id` to `chain`.
2733    ///
2734    /// If `chain` is `None`, that means there are no bounds checks in the chain
2735    /// yet. If chain is `Some(id)`, then `id` is the conjunction of all the
2736    /// bounds checks in the chain.
2737    ///
2738    /// When we have multiple bounds checks, we combine them with
2739    /// `OpLogicalAnd`, not a short-circuit branch. This means we might do
2740    /// comparisons we don't need to, but we expect these checks to almost
2741    /// always succeed, and keeping branches to a minimum is essential.
2742    ///
2743    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`]: crate::proc::BoundsCheckPolicy
2744    fn extend_bounds_check_condition_chain(
2745        &mut self,
2746        chain: &mut Option<Word>,
2747        comparison_id: Word,
2748        block: &mut Block,
2749    ) {
2750        match *chain {
2751            Some(ref mut prior_checks) => {
2752                let combined = self.gen_id();
2753                block.body.push(Instruction::binary(
2754                    spirv::Op::LogicalAnd,
2755                    self.writer.get_bool_type_id(),
2756                    combined,
2757                    *prior_checks,
2758                    comparison_id,
2759                ));
2760                *prior_checks = combined;
2761            }
2762            None => {
2763                // Start a fresh chain of checks.
2764                *chain = Some(comparison_id);
2765            }
2766        }
2767    }
2768
2769    fn write_checked_load(
2770        &mut self,
2771        pointer: Handle<crate::Expression>,
2772        block: &mut Block,
2773        access_type_adjustment: AccessTypeAdjustment,
2774        result_type_id: Word,
2775    ) -> Result<Word, Error> {
2776        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
2777            Ok(result_id)
2778        } else if let Some(result_id) =
2779            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
2780        {
2781            Ok(result_id)
2782        } else {
2783            // If `pointer` refers to a uniform address space pointer to a type
2784            // which was declared using a std140 compatible type variant (i.e.
2785            // is a two-row matrix, or a struct or array containing such a
2786            // matrix) we must ensure the access chain and the type of the load
2787            // instruction use the std140 compatible type variant.
2788            struct WrappedLoad {
2789                access_type_adjustment: AccessTypeAdjustment,
2790                r#type: Handle<crate::Type>,
2791            }
2792            let mut wrapped_load = None;
2793            if let crate::TypeInner::Pointer {
2794                base: pointer_base_type,
2795                space: crate::AddressSpace::Uniform,
2796            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
2797            {
2798                if self
2799                    .writer
2800                    .std140_compat_uniform_types
2801                    .contains_key(&pointer_base_type)
2802                {
2803                    wrapped_load = Some(WrappedLoad {
2804                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
2805                        r#type: pointer_base_type,
2806                    });
2807                };
2808            };
2809
2810            let (load_type_id, access_type_adjustment) = match wrapped_load {
2811                Some(ref wrapped_load) => (
2812                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
2813                    wrapped_load.access_type_adjustment,
2814                ),
2815                None => (result_type_id, access_type_adjustment),
2816            };
2817
2818            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
2819                ExpressionPointer::Ready { pointer_id } => {
2820                    let id = self.gen_id();
2821                    let atomic_space =
2822                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2823                            crate::TypeInner::Pointer { base, space } => {
2824                                match self.ir_module.types[base].inner {
2825                                    crate::TypeInner::Atomic { .. } => Some(space),
2826                                    _ => None,
2827                                }
2828                            }
2829                            _ => None,
2830                        };
2831                    let instruction = if let Some(space) = atomic_space {
2832                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
2833                        let scope_constant_id = self.get_scope_constant(scope as u32);
2834                        let semantics_id = self.get_index_constant(semantics.bits());
2835                        Instruction::atomic_load(
2836                            result_type_id,
2837                            id,
2838                            pointer_id,
2839                            scope_constant_id,
2840                            semantics_id,
2841                        )
2842                    } else {
2843                        Instruction::load(load_type_id, id, pointer_id, None)
2844                    };
2845                    block.body.push(instruction);
2846                    id
2847                }
2848                ExpressionPointer::Conditional { condition, access } => {
2849                    //TODO: support atomics?
2850                    self.write_conditional_indexed_load(
2851                        load_type_id,
2852                        condition,
2853                        block,
2854                        move |id_gen, block| {
2855                            // The in-bounds path. Perform the access and the load.
2856                            let pointer_id = access.result_id.unwrap();
2857                            let value_id = id_gen.next();
2858                            block.body.push(access);
2859                            block.body.push(Instruction::load(
2860                                load_type_id,
2861                                value_id,
2862                                pointer_id,
2863                                None,
2864                            ));
2865                            value_id
2866                        },
2867                    )
2868                }
2869            };
2870
2871            match wrapped_load {
2872                Some(ref wrapped_load) => {
2873                    // If we loaded a std140 compat type then we must call the
2874                    // function to convert the loaded value to the regular type.
2875                    let result_id = self.gen_id();
2876                    let function_id = self.writer.wrapped_functions
2877                        [&WrappedFunction::ConvertFromStd140CompatType {
2878                            r#type: wrapped_load.r#type,
2879                        }];
2880                    block.body.push(Instruction::function_call(
2881                        result_type_id,
2882                        result_id,
2883                        function_id,
2884                        &[load_id],
2885                    ));
2886                    Ok(result_id)
2887                }
2888                None => Ok(load_id),
2889            }
2890        }
2891    }
2892
2893    fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2894        use indexmap::map::Entry;
2895
2896        // Make sure we have an internal variable to spill `base` to.
2897        let spill_variable_id = match self.function.spilled_composites.entry(base) {
2898            Entry::Occupied(preexisting) => preexisting.get().id,
2899            Entry::Vacant(vacant) => {
2900                // Generate a new internal variable of the appropriate
2901                // type for `base`.
2902                let pointer_type_id = self.writer.get_resolution_pointer_id(
2903                    &self.fun_info[base].ty,
2904                    spirv::StorageClass::Function,
2905                );
2906                let id = self.writer.id_gen.next();
2907                vacant.insert(super::LocalVariable {
2908                    id,
2909                    instruction: Instruction::variable(
2910                        pointer_type_id,
2911                        id,
2912                        spirv::StorageClass::Function,
2913                        None,
2914                    ),
2915                });
2916                id
2917            }
2918        };
2919
2920        // Perform the store even if we already had a spill variable for `base`.
2921        // Consider this code:
2922        //
2923        // var x = ...;
2924        // var y = ...;
2925        // var z = ...;
2926        // for (i = 0; i<2; i++) {
2927        //     let a = array(i, i, i);
2928        //     if (i == 0) {
2929        //         x += a[y];
2930        //     } else [
2931        //         x += a[z];
2932        //     }
2933        // }
2934        //
2935        // The value of `a` needs to be spilled so we can subscript it with `y` and `z`.
2936        //
2937        // When we generate SPIR-V for `a[y]`, we will create the spill
2938        // variable, and store `a`'s value in it.
2939        //
2940        // When we generate SPIR-V for `a[z]`, we will notice that the spill
2941        // variable for `a` has already been declared, but it is still essential
2942        // that we store `a` into it, so that `a[z]` sees this iteration's value
2943        // of `a`.
2944        let base_id = self.cached[base];
2945        block
2946            .body
2947            .push(Instruction::store(spill_variable_id, base_id, None));
2948    }
2949
2950    /// Generate an access to a spilled temporary, if necessary.
2951    ///
2952    /// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers
2953    /// to a component of a composite value that has been spilled to a temporary
2954    /// variable, determine whether other expressions are going to use
2955    /// `access`'s value:
2956    ///
2957    /// - If so, perform the access and cache that as the value of `access`.
2958    ///
2959    /// - Otherwise, generate no code and cache no value for `access`.
2960    ///
2961    /// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded it into
2962    /// the instruction given by `id`.
2963    ///
2964    /// [`Access`]: crate::Expression::Access
2965    /// [`AccessIndex`]: crate::Expression::AccessIndex
2966    fn maybe_access_spilled_composite(
2967        &mut self,
2968        access: Handle<crate::Expression>,
2969        block: &mut Block,
2970        result_type_id: Word,
2971    ) -> Result<Word, Error> {
2972        let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2973        if access_uses == self.fun_info[access].ref_count {
2974            // This expression is only used by other `Access` and
2975            // `AccessIndex` expressions, so we don't need to cache a
2976            // value for it yet.
2977            Ok(0)
2978        } else {
2979            // There are other expressions that are going to expect this
2980            // expression's value to be cached, not just other `Access` or
2981            // `AccessIndex` expressions. We must actually perform the
2982            // access on the spill variable now.
2983            self.write_checked_load(
2984                access,
2985                block,
2986                AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2987                result_type_id,
2988            )
2989        }
2990    }
2991
2992    /// Build the instructions for matrix - matrix column operations
2993    #[allow(clippy::too_many_arguments)]
2994    fn write_matrix_matrix_column_op(
2995        &mut self,
2996        block: &mut Block,
2997        result_id: Word,
2998        result_type_id: Word,
2999        left_id: Word,
3000        right_id: Word,
3001        columns: crate::VectorSize,
3002        rows: crate::VectorSize,
3003        width: u8,
3004        op: spirv::Op,
3005    ) {
3006        self.temp_list.clear();
3007
3008        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3009            size: rows,
3010            scalar: crate::Scalar::float(width),
3011        });
3012
3013        for index in 0..columns as u32 {
3014            let column_id_left = self.gen_id();
3015            let column_id_right = self.gen_id();
3016            let column_id_res = self.gen_id();
3017
3018            block.body.push(Instruction::composite_extract(
3019                vector_type_id,
3020                column_id_left,
3021                left_id,
3022                &[index],
3023            ));
3024            block.body.push(Instruction::composite_extract(
3025                vector_type_id,
3026                column_id_right,
3027                right_id,
3028                &[index],
3029            ));
3030            block.body.push(Instruction::binary(
3031                op,
3032                vector_type_id,
3033                column_id_res,
3034                column_id_left,
3035                column_id_right,
3036            ));
3037
3038            self.temp_list.push(column_id_res);
3039        }
3040
3041        block.body.push(Instruction::composite_construct(
3042            result_type_id,
3043            result_id,
3044            &self.temp_list,
3045        ));
3046    }
3047
3048    /// Build the instructions for vector - scalar multiplication
3049    fn write_vector_scalar_mult(
3050        &mut self,
3051        block: &mut Block,
3052        result_id: Word,
3053        result_type_id: Word,
3054        vector_id: Word,
3055        scalar_id: Word,
3056        vector: &crate::TypeInner,
3057    ) {
3058        let (size, kind) = match *vector {
3059            crate::TypeInner::Vector {
3060                size,
3061                scalar: crate::Scalar { kind, .. },
3062            } => (size, kind),
3063            _ => unreachable!(),
3064        };
3065
3066        let (op, operand_id) = match kind {
3067            crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3068            _ => {
3069                let operand_id = self.gen_id();
3070                self.temp_list.clear();
3071                self.temp_list.resize(size as usize, scalar_id);
3072                block.body.push(Instruction::composite_construct(
3073                    result_type_id,
3074                    operand_id,
3075                    &self.temp_list,
3076                ));
3077                (spirv::Op::IMul, operand_id)
3078            }
3079        };
3080
3081        block.body.push(Instruction::binary(
3082            op,
3083            result_type_id,
3084            result_id,
3085            vector_id,
3086            operand_id,
3087        ));
3088    }
3089
3090    /// Build the instructions for the arithmetic expression of a dot product
3091    ///
3092    /// The argument `extractor` is a function that maps `(result_id,
3093    /// composite_id, index)` to an instruction that extracts the `index`th
3094    /// entry of the value with ID `composite_id` and assigns it to the slot
3095    /// with id `result_id` (which must have type `result_type_id`).
3096    #[expect(clippy::too_many_arguments)]
3097    fn write_dot_product(
3098        &mut self,
3099        result_id: Word,
3100        result_type_id: Word,
3101        arg0_id: Word,
3102        arg1_id: Word,
3103        size: u32,
3104        block: &mut Block,
3105        extractor: impl Fn(Word, Word, Word) -> Instruction,
3106    ) {
3107        let mut partial_sum = self.writer.get_constant_null(result_type_id);
3108        let last_component = size - 1;
3109        for index in 0..=last_component {
3110            // compute the product of the current components
3111            let a_id = self.gen_id();
3112            block.body.push(extractor(a_id, arg0_id, index));
3113            let b_id = self.gen_id();
3114            block.body.push(extractor(b_id, arg1_id, index));
3115            let prod_id = self.gen_id();
3116            block.body.push(Instruction::binary(
3117                spirv::Op::IMul,
3118                result_type_id,
3119                prod_id,
3120                a_id,
3121                b_id,
3122            ));
3123
3124            // choose the id for the next sum, depending on current index
3125            let id = if index == last_component {
3126                result_id
3127            } else {
3128                self.gen_id()
3129            };
3130
3131            // sum the computed product with the partial sum
3132            block.body.push(Instruction::binary(
3133                spirv::Op::IAdd,
3134                result_type_id,
3135                id,
3136                partial_sum,
3137                prod_id,
3138            ));
3139            // set the id of the result as the previous partial sum
3140            partial_sum = id;
3141        }
3142    }
3143
3144    /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is available.
3145    fn write_pack4x8_optimized(
3146        &mut self,
3147        block: &mut Block,
3148        result_type_id: u32,
3149        arg0_id: u32,
3150        id: u32,
3151        is_signed: bool,
3152        should_clamp: bool,
3153    ) -> Instruction {
3154        let int_type = if is_signed {
3155            crate::ScalarKind::Sint
3156        } else {
3157            crate::ScalarKind::Uint
3158        };
3159        let wide_vector_type = NumericType::Vector {
3160            size: crate::VectorSize::Quad,
3161            scalar: crate::Scalar {
3162                kind: int_type,
3163                width: 4,
3164            },
3165        };
3166        let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3167        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3168            size: crate::VectorSize::Quad,
3169            scalar: crate::Scalar {
3170                kind: crate::ScalarKind::Uint,
3171                width: 1,
3172            },
3173        });
3174
3175        let mut wide_vector = arg0_id;
3176        if should_clamp {
3177            let (min, max, clamp_op) = if is_signed {
3178                (
3179                    crate::Literal::I32(-128),
3180                    crate::Literal::I32(127),
3181                    spirv::GlslStd450Op::SClamp,
3182                )
3183            } else {
3184                (
3185                    crate::Literal::U32(0),
3186                    crate::Literal::U32(255),
3187                    spirv::GlslStd450Op::UClamp,
3188                )
3189            };
3190            let [min, max] = [min, max].map(|lit| {
3191                let scalar = self.writer.get_constant_scalar(lit);
3192                self.writer.get_constant_composite(
3193                    LookupType::Local(LocalType::Numeric(wide_vector_type)),
3194                    &[scalar; 4],
3195                )
3196            });
3197
3198            let clamp_id = self.gen_id();
3199            block.body.push(Instruction::ext_inst_gl_op(
3200                self.writer.gl450_ext_inst_id,
3201                clamp_op,
3202                wide_vector_type_id,
3203                clamp_id,
3204                &[wide_vector, min, max],
3205            ));
3206
3207            wide_vector = clamp_id;
3208        }
3209
3210        let packed_vector = self.gen_id();
3211        block.body.push(Instruction::unary(
3212            spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically.
3213            packed_vector_type_id,
3214            packed_vector,
3215            wide_vector,
3216        ));
3217
3218        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
3219        // and a scalar precisely as required by the WGSL spec [2].
3220        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
3221        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
3222        Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3223    }
3224
3225    /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is not available.
3226    fn write_pack4x8_polyfill(
3227        &mut self,
3228        block: &mut Block,
3229        result_type_id: u32,
3230        arg0_id: u32,
3231        id: u32,
3232        is_signed: bool,
3233        should_clamp: bool,
3234    ) -> Instruction {
3235        let int_type = if is_signed {
3236            crate::ScalarKind::Sint
3237        } else {
3238            crate::ScalarKind::Uint
3239        };
3240        let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
3241        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3242            kind: int_type,
3243            width: 4,
3244        }));
3245
3246        let mut last_instruction = Instruction::new(spirv::Op::Nop);
3247
3248        let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
3249        let mut preresult = zero;
3250        block
3251            .body
3252            .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
3253
3254        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3255        const VEC_LENGTH: u8 = 4;
3256        for i in 0..u32::from(VEC_LENGTH) {
3257            let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
3258            let mut extracted = self.gen_id();
3259            block.body.push(Instruction::binary(
3260                spirv::Op::CompositeExtract,
3261                int_type_id,
3262                extracted,
3263                arg0_id,
3264                i,
3265            ));
3266            if is_signed {
3267                let casted = self.gen_id();
3268                block.body.push(Instruction::unary(
3269                    spirv::Op::Bitcast,
3270                    uint_type_id,
3271                    casted,
3272                    extracted,
3273                ));
3274                extracted = casted;
3275            }
3276            if should_clamp {
3277                let (min, max, clamp_op) = if is_signed {
3278                    (
3279                        crate::Literal::I32(-128),
3280                        crate::Literal::I32(127),
3281                        spirv::GlslStd450Op::SClamp,
3282                    )
3283                } else {
3284                    (
3285                        crate::Literal::U32(0),
3286                        crate::Literal::U32(255),
3287                        spirv::GlslStd450Op::UClamp,
3288                    )
3289                };
3290                let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
3291
3292                let clamp_id = self.gen_id();
3293                block.body.push(Instruction::ext_inst_gl_op(
3294                    self.writer.gl450_ext_inst_id,
3295                    clamp_op,
3296                    result_type_id,
3297                    clamp_id,
3298                    &[extracted, min, max],
3299                ));
3300
3301                extracted = clamp_id;
3302            }
3303            let is_last = i == u32::from(VEC_LENGTH - 1);
3304            if is_last {
3305                last_instruction = Instruction::quaternary(
3306                    spirv::Op::BitFieldInsert,
3307                    result_type_id,
3308                    id,
3309                    preresult,
3310                    extracted,
3311                    offset,
3312                    eight,
3313                )
3314            } else {
3315                let new_preresult = self.gen_id();
3316                block.body.push(Instruction::quaternary(
3317                    spirv::Op::BitFieldInsert,
3318                    result_type_id,
3319                    new_preresult,
3320                    preresult,
3321                    extracted,
3322                    offset,
3323                    eight,
3324                ));
3325                preresult = new_preresult;
3326            }
3327        }
3328        last_instruction
3329    }
3330
3331    /// Emit code for `unpack4x{I,U}8` if capability "Int8" is available.
3332    fn write_unpack4x8_optimized(
3333        &mut self,
3334        block: &mut Block,
3335        result_type_id: u32,
3336        arg0_id: u32,
3337        id: u32,
3338        is_signed: bool,
3339    ) -> Instruction {
3340        let (int_type, convert_op) = if is_signed {
3341            (crate::ScalarKind::Sint, spirv::Op::SConvert)
3342        } else {
3343            (crate::ScalarKind::Uint, spirv::Op::UConvert)
3344        };
3345
3346        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3347            size: crate::VectorSize::Quad,
3348            scalar: crate::Scalar {
3349                kind: int_type,
3350                width: 1,
3351            },
3352        });
3353
3354        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
3355        // and a scalar precisely as required by the WGSL spec [2].
3356        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
3357        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
3358        let packed_vector = self.gen_id();
3359        block.body.push(Instruction::unary(
3360            spirv::Op::Bitcast,
3361            packed_vector_type_id,
3362            packed_vector,
3363            arg0_id,
3364        ));
3365
3366        Instruction::unary(convert_op, result_type_id, id, packed_vector)
3367    }
3368
3369    /// Emit code for `unpack4x{I,U}8` if capability "Int8" is not available.
3370    fn write_unpack4x8_polyfill(
3371        &mut self,
3372        block: &mut Block,
3373        result_type_id: u32,
3374        arg0_id: u32,
3375        id: u32,
3376        is_signed: bool,
3377    ) -> Instruction {
3378        let (int_type, extract_op) = if is_signed {
3379            (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3380        } else {
3381            (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3382        };
3383
3384        let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3385
3386        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3387        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3388            kind: int_type,
3389            width: 4,
3390        }));
3391        block
3392            .body
3393            .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3394        let arg_id = if is_signed {
3395            let new_arg_id = self.gen_id();
3396            block.body.push(Instruction::unary(
3397                spirv::Op::Bitcast,
3398                sint_type_id,
3399                new_arg_id,
3400                arg0_id,
3401            ));
3402            new_arg_id
3403        } else {
3404            arg0_id
3405        };
3406
3407        const VEC_LENGTH: u8 = 4;
3408        let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3409        for (i, part_id) in parts.into_iter().enumerate() {
3410            let index = self
3411                .writer
3412                .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3413            block.body.push(Instruction::ternary(
3414                extract_op,
3415                int_type_id,
3416                part_id,
3417                arg_id,
3418                index,
3419                eight,
3420            ));
3421        }
3422
3423        Instruction::composite_construct(result_type_id, id, &parts)
3424    }
3425
3426    /// Generate one or more SPIR-V blocks for `naga_block`.
3427    ///
3428    /// Use `label_id` as the label for the SPIR-V entry point block.
3429    ///
3430    /// If control reaches the end of the SPIR-V block, terminate it according
3431    /// to `exit`. This function's return value indicates whether it acted on
3432    /// this parameter or not; see [`BlockExitDisposition`].
3433    ///
3434    /// If the block contains [`Break`] or [`Continue`] statements,
3435    /// `loop_context` supplies the labels of the SPIR-V blocks to jump to. If
3436    /// either of these labels are `None`, then it should have been a Naga
3437    /// validation error for the corresponding statement to occur in this
3438    /// context.
3439    ///
3440    /// [`Break`]: Statement::Break
3441    /// [`Continue`]: Statement::Continue
3442    fn write_block(
3443        &mut self,
3444        label_id: Word,
3445        naga_block: &crate::Block,
3446        exit: BlockExit,
3447        loop_context: LoopContext,
3448        debug_info: Option<&DebugInfoInner>,
3449    ) -> Result<BlockExitDisposition, Error> {
3450        let mut block = Block::new(label_id);
3451        for (statement, span) in naga_block.span_iter() {
3452            if let (Some(debug_info), false) = (
3453                debug_info,
3454                matches!(
3455                    statement,
3456                    &(Statement::Block(..)
3457                        | Statement::Break
3458                        | Statement::Continue
3459                        | Statement::Kill
3460                        | Statement::Return { .. }
3461                        | Statement::Loop { .. })
3462                ),
3463            ) {
3464                let loc: crate::SourceLocation = span.location(debug_info.source_code);
3465                block.body.push(Instruction::line(
3466                    debug_info.source_file_id,
3467                    loc.line_number,
3468                    loc.line_position,
3469                ));
3470            };
3471            match *statement {
3472                Statement::Emit(ref range) => {
3473                    for handle in range.clone() {
3474                        // omit const expressions as we've already cached those
3475                        if !self.expression_constness.is_const(handle) {
3476                            self.cache_expression_value(handle, &mut block)?;
3477                        }
3478                    }
3479                }
3480                Statement::Block(ref block_statements) => {
3481                    let scope_id = self.gen_id();
3482                    self.function.consume(block, Instruction::branch(scope_id));
3483
3484                    let merge_id = self.gen_id();
3485                    let merge_used = self.write_block(
3486                        scope_id,
3487                        block_statements,
3488                        BlockExit::Branch { target: merge_id },
3489                        loop_context,
3490                        debug_info,
3491                    )?;
3492
3493                    match merge_used {
3494                        BlockExitDisposition::Used => {
3495                            block = Block::new(merge_id);
3496                        }
3497                        BlockExitDisposition::Discarded => {
3498                            return Ok(BlockExitDisposition::Discarded);
3499                        }
3500                    }
3501                }
3502                Statement::If {
3503                    condition,
3504                    ref accept,
3505                    ref reject,
3506                } => {
3507                    // In spirv 1.6, in a conditional branch the two block ids
3508                    // of the branches can't have the same label. If `accept`
3509                    // and `reject` are both empty (e.g. in `if (condition) {}`)
3510                    // merge id will be both labels. Because both branches are
3511                    // empty, we can skip the if statement.
3512                    if !(accept.is_empty() && reject.is_empty()) {
3513                        let condition_id = self.cached[condition];
3514
3515                        let merge_id = self.gen_id();
3516                        block.body.push(Instruction::selection_merge(
3517                            merge_id,
3518                            spirv::SelectionControl::NONE,
3519                        ));
3520
3521                        let accept_id = if accept.is_empty() {
3522                            None
3523                        } else {
3524                            Some(self.gen_id())
3525                        };
3526                        let reject_id = if reject.is_empty() {
3527                            None
3528                        } else {
3529                            Some(self.gen_id())
3530                        };
3531
3532                        self.function.consume(
3533                            block,
3534                            Instruction::branch_conditional(
3535                                condition_id,
3536                                accept_id.unwrap_or(merge_id),
3537                                reject_id.unwrap_or(merge_id),
3538                            ),
3539                        );
3540
3541                        if let Some(block_id) = accept_id {
3542                            // We can ignore the `BlockExitDisposition` returned here because,
3543                            // even if `merge_id` is not actually reachable, it is always
3544                            // referred to by the `OpSelectionMerge` instruction we emitted
3545                            // earlier.
3546                            let _ = self.write_block(
3547                                block_id,
3548                                accept,
3549                                BlockExit::Branch { target: merge_id },
3550                                loop_context,
3551                                debug_info,
3552                            )?;
3553                        }
3554                        if let Some(block_id) = reject_id {
3555                            // We can ignore the `BlockExitDisposition` returned here because,
3556                            // even if `merge_id` is not actually reachable, it is always
3557                            // referred to by the `OpSelectionMerge` instruction we emitted
3558                            // earlier.
3559                            let _ = self.write_block(
3560                                block_id,
3561                                reject,
3562                                BlockExit::Branch { target: merge_id },
3563                                loop_context,
3564                                debug_info,
3565                            )?;
3566                        }
3567
3568                        block = Block::new(merge_id);
3569                    }
3570                }
3571                Statement::Switch {
3572                    selector,
3573                    ref cases,
3574                } => {
3575                    let selector_id = self.cached[selector];
3576
3577                    let merge_id = self.gen_id();
3578                    block.body.push(Instruction::selection_merge(
3579                        merge_id,
3580                        spirv::SelectionControl::NONE,
3581                    ));
3582
3583                    let mut default_id = None;
3584                    // id of previous empty fall-through case
3585                    let mut last_id = None;
3586
3587                    let mut raw_cases = Vec::with_capacity(cases.len());
3588                    let mut case_ids = Vec::with_capacity(cases.len());
3589                    for case in cases.iter() {
3590                        // take id of previous empty fall-through case or generate a new one
3591                        let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3592
3593                        if case.fall_through && case.body.is_empty() {
3594                            last_id = Some(label_id);
3595                        }
3596
3597                        case_ids.push(label_id);
3598
3599                        match case.value {
3600                            crate::SwitchValue::I32(value) => {
3601                                raw_cases.push(super::instructions::Case {
3602                                    value: value as Word,
3603                                    label_id,
3604                                });
3605                            }
3606                            crate::SwitchValue::U32(value) => {
3607                                raw_cases.push(super::instructions::Case { value, label_id });
3608                            }
3609                            crate::SwitchValue::Default => {
3610                                default_id = Some(label_id);
3611                            }
3612                        }
3613                    }
3614
3615                    let default_id = default_id.unwrap();
3616
3617                    self.function.consume(
3618                        block,
3619                        Instruction::switch(selector_id, default_id, &raw_cases),
3620                    );
3621
3622                    let inner_context = LoopContext {
3623                        break_id: Some(merge_id),
3624                        ..loop_context
3625                    };
3626
3627                    for (i, (case, label_id)) in cases
3628                        .iter()
3629                        .zip(case_ids.iter())
3630                        .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3631                        .enumerate()
3632                    {
3633                        let case_finish_id = if case.fall_through {
3634                            case_ids[i + 1]
3635                        } else {
3636                            merge_id
3637                        };
3638                        // We can ignore the `BlockExitDisposition` returned here because
3639                        // `case_finish_id` is always referred to by either:
3640                        //
3641                        // - the `OpSwitch`, if it's the next case's label for a
3642                        //   fall-through, or
3643                        //
3644                        // - the `OpSelectionMerge`, if it's the switch's overall merge
3645                        //   block because there's no fall-through.
3646                        let _ = self.write_block(
3647                            *label_id,
3648                            &case.body,
3649                            BlockExit::Branch {
3650                                target: case_finish_id,
3651                            },
3652                            inner_context,
3653                            debug_info,
3654                        )?;
3655                    }
3656
3657                    block = Block::new(merge_id);
3658                }
3659                Statement::Loop {
3660                    ref body,
3661                    ref continuing,
3662                    break_if,
3663                } => {
3664                    let preamble_id = self.gen_id();
3665                    self.function
3666                        .consume(block, Instruction::branch(preamble_id));
3667
3668                    let merge_id = self.gen_id();
3669                    let body_id = self.gen_id();
3670                    let continuing_id = self.gen_id();
3671
3672                    // SPIR-V requires the continuing to the `OpLoopMerge`,
3673                    // so we have to start a new block with it.
3674                    block = Block::new(preamble_id);
3675                    // HACK the loop statement is begin with branch instruction,
3676                    // so we need to put `OpLine` debug info before merge instruction
3677                    if let Some(debug_info) = debug_info {
3678                        let loc: crate::SourceLocation = span.location(debug_info.source_code);
3679                        block.body.push(Instruction::line(
3680                            debug_info.source_file_id,
3681                            loc.line_number,
3682                            loc.line_position,
3683                        ))
3684                    }
3685                    block.body.push(Instruction::loop_merge(
3686                        merge_id,
3687                        continuing_id,
3688                        spirv::SelectionControl::NONE,
3689                    ));
3690
3691                    if self.force_loop_bounding {
3692                        block = self.write_force_bounded_loop_instructions(block, merge_id);
3693                    }
3694                    self.function.consume(block, Instruction::branch(body_id));
3695
3696                    // We can ignore the `BlockExitDisposition` returned here because,
3697                    // even if `continuing_id` is not actually reachable, it is always
3698                    // referred to by the `OpLoopMerge` instruction we emitted earlier.
3699                    let _ = self.write_block(
3700                        body_id,
3701                        body,
3702                        BlockExit::Branch {
3703                            target: continuing_id,
3704                        },
3705                        LoopContext {
3706                            continuing_id: Some(continuing_id),
3707                            break_id: Some(merge_id),
3708                        },
3709                        debug_info,
3710                    )?;
3711
3712                    let exit = match break_if {
3713                        Some(condition) => BlockExit::BreakIf {
3714                            condition,
3715                            preamble_id,
3716                        },
3717                        None => BlockExit::Branch {
3718                            target: preamble_id,
3719                        },
3720                    };
3721
3722                    // We can ignore the `BlockExitDisposition` returned here because,
3723                    // even if `merge_id` is not actually reachable, it is always referred
3724                    // to by the `OpLoopMerge` instruction we emitted earlier.
3725                    let _ = self.write_block(
3726                        continuing_id,
3727                        continuing,
3728                        exit,
3729                        LoopContext {
3730                            continuing_id: None,
3731                            break_id: Some(merge_id),
3732                        },
3733                        debug_info,
3734                    )?;
3735
3736                    block = Block::new(merge_id);
3737                }
3738                Statement::Break => {
3739                    self.function
3740                        .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3741                    return Ok(BlockExitDisposition::Discarded);
3742                }
3743                Statement::Continue => {
3744                    self.function.consume(
3745                        block,
3746                        Instruction::branch(loop_context.continuing_id.unwrap()),
3747                    );
3748                    return Ok(BlockExitDisposition::Discarded);
3749                }
3750                Statement::Return { value: Some(value) } => {
3751                    let value_id = self.cached[value];
3752                    let instruction = match self.function.entry_point_context {
3753                        // If this is an entry point, and we need to return anything,
3754                        // let's instead store the output variables and return `void`.
3755                        Some(ref context) => self.writer.write_entry_point_return(
3756                            value_id,
3757                            self.ir_function.result.as_ref().unwrap(),
3758                            &context.results,
3759                            &mut block.body,
3760                        )?,
3761                        None => Instruction::return_value(value_id),
3762                    };
3763                    self.function.consume(block, instruction);
3764                    return Ok(BlockExitDisposition::Discarded);
3765                }
3766                Statement::Return { value: None } => {
3767                    self.function.consume(block, Instruction::return_void());
3768                    return Ok(BlockExitDisposition::Discarded);
3769                }
3770                Statement::Kill => {
3771                    self.function.consume(block, Instruction::kill());
3772                    return Ok(BlockExitDisposition::Discarded);
3773                }
3774                Statement::ControlBarrier(flags) => {
3775                    self.writer.write_control_barrier(flags, &mut block.body);
3776                }
3777                Statement::MemoryBarrier(flags) => {
3778                    self.writer.write_memory_barrier(flags, &mut block);
3779                }
3780                Statement::Store { pointer, value } => {
3781                    let value_id = self.cached[value];
3782                    match self.write_access_chain(
3783                        pointer,
3784                        &mut block,
3785                        AccessTypeAdjustment::None,
3786                    )? {
3787                        ExpressionPointer::Ready { pointer_id } => {
3788                            let atomic_space = match *self.fun_info[pointer]
3789                                .ty
3790                                .inner_with(&self.ir_module.types)
3791                            {
3792                                crate::TypeInner::Pointer { base, space } => {
3793                                    match self.ir_module.types[base].inner {
3794                                        crate::TypeInner::Atomic { .. } => Some(space),
3795                                        _ => None,
3796                                    }
3797                                }
3798                                _ => None,
3799                            };
3800                            let instruction = if let Some(space) = atomic_space {
3801                                let (semantics, scope) = space.to_spirv_semantics_and_scope();
3802                                let scope_constant_id = self.get_scope_constant(scope as u32);
3803                                let semantics_id = self.get_index_constant(semantics.bits());
3804                                Instruction::atomic_store(
3805                                    pointer_id,
3806                                    scope_constant_id,
3807                                    semantics_id,
3808                                    value_id,
3809                                )
3810                            } else {
3811                                Instruction::store(pointer_id, value_id, None)
3812                            };
3813                            block.body.push(instruction);
3814                        }
3815                        ExpressionPointer::Conditional { condition, access } => {
3816                            let mut selection = Selection::start(&mut block, ());
3817                            selection.if_true(self, condition, ());
3818
3819                            // The in-bounds path. Perform the access and the store.
3820                            let pointer_id = access.result_id.unwrap();
3821                            selection.block().body.push(access);
3822                            selection
3823                                .block()
3824                                .body
3825                                .push(Instruction::store(pointer_id, value_id, None));
3826
3827                            // Finish the in-bounds block and start the merge block. This
3828                            // is the block we'll leave current on return.
3829                            selection.finish(self, ());
3830                        }
3831                    };
3832                }
3833                Statement::ImageStore {
3834                    image,
3835                    coordinate,
3836                    array_index,
3837                    value,
3838                } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3839                Statement::Call {
3840                    function: local_function,
3841                    ref arguments,
3842                    result,
3843                } => {
3844                    let id = self.gen_id();
3845                    self.temp_list.clear();
3846                    for &argument in arguments {
3847                        self.temp_list.push(self.cached[argument]);
3848                    }
3849
3850                    let type_id = match result {
3851                        Some(expr) => {
3852                            self.cached[expr] = id;
3853                            self.get_expression_type_id(&self.fun_info[expr].ty)
3854                        }
3855                        None => self.writer.void_type,
3856                    };
3857
3858                    block.body.push(Instruction::function_call(
3859                        type_id,
3860                        id,
3861                        self.writer.lookup_function[&local_function],
3862                        &self.temp_list,
3863                    ));
3864                }
3865                Statement::Atomic {
3866                    pointer,
3867                    ref fun,
3868                    value,
3869                    result,
3870                } => {
3871                    let id = self.gen_id();
3872                    // Compare-and-exchange operations produce a struct result,
3873                    // so use `result`'s type if it is available. For no-result
3874                    // operations, fall back to `value`'s type.
3875                    let result_type_id =
3876                        self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3877
3878                    if let Some(result) = result {
3879                        self.cached[result] = id;
3880                    }
3881
3882                    let pointer_id = match self.write_access_chain(
3883                        pointer,
3884                        &mut block,
3885                        AccessTypeAdjustment::None,
3886                    )? {
3887                        ExpressionPointer::Ready { pointer_id } => pointer_id,
3888                        ExpressionPointer::Conditional { .. } => {
3889                            return Err(Error::FeatureNotImplemented(
3890                                "Atomics out-of-bounds handling",
3891                            ));
3892                        }
3893                    };
3894
3895                    let space = self.fun_info[pointer]
3896                        .ty
3897                        .inner_with(&self.ir_module.types)
3898                        .pointer_space()
3899                        .unwrap();
3900                    let (semantics, scope) = space.to_spirv_semantics_and_scope();
3901                    let scope_constant_id = self.get_scope_constant(scope as u32);
3902                    let semantics_id = self.get_index_constant(semantics.bits());
3903                    let value_id = self.cached[value];
3904                    let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3905
3906                    let crate::TypeInner::Scalar(scalar) = *value_inner else {
3907                        return Err(Error::FeatureNotImplemented(
3908                            "Atomics with non-scalar values",
3909                        ));
3910                    };
3911
3912                    let instruction = match *fun {
3913                        crate::AtomicFunction::Add => {
3914                            let spirv_op = match scalar.kind {
3915                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3916                                    spirv::Op::AtomicIAdd
3917                                }
3918                                crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3919                                _ => unimplemented!(),
3920                            };
3921                            Instruction::atomic_binary(
3922                                spirv_op,
3923                                result_type_id,
3924                                id,
3925                                pointer_id,
3926                                scope_constant_id,
3927                                semantics_id,
3928                                value_id,
3929                            )
3930                        }
3931                        crate::AtomicFunction::Subtract => {
3932                            let (spirv_op, value_id) = match scalar.kind {
3933                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3934                                    (spirv::Op::AtomicISub, value_id)
3935                                }
3936                                crate::ScalarKind::Float => {
3937                                    // HACK: SPIR-V doesn't have a atomic subtraction,
3938                                    // so we add the negated value instead.
3939                                    let neg_result_id = self.gen_id();
3940                                    block.body.push(Instruction::unary(
3941                                        spirv::Op::FNegate,
3942                                        result_type_id,
3943                                        neg_result_id,
3944                                        value_id,
3945                                    ));
3946                                    (spirv::Op::AtomicFAddEXT, neg_result_id)
3947                                }
3948                                _ => unimplemented!(),
3949                            };
3950                            Instruction::atomic_binary(
3951                                spirv_op,
3952                                result_type_id,
3953                                id,
3954                                pointer_id,
3955                                scope_constant_id,
3956                                semantics_id,
3957                                value_id,
3958                            )
3959                        }
3960                        crate::AtomicFunction::And => {
3961                            let spirv_op = match scalar.kind {
3962                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3963                                    spirv::Op::AtomicAnd
3964                                }
3965                                _ => unimplemented!(),
3966                            };
3967                            Instruction::atomic_binary(
3968                                spirv_op,
3969                                result_type_id,
3970                                id,
3971                                pointer_id,
3972                                scope_constant_id,
3973                                semantics_id,
3974                                value_id,
3975                            )
3976                        }
3977                        crate::AtomicFunction::InclusiveOr => {
3978                            let spirv_op = match scalar.kind {
3979                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3980                                    spirv::Op::AtomicOr
3981                                }
3982                                _ => unimplemented!(),
3983                            };
3984                            Instruction::atomic_binary(
3985                                spirv_op,
3986                                result_type_id,
3987                                id,
3988                                pointer_id,
3989                                scope_constant_id,
3990                                semantics_id,
3991                                value_id,
3992                            )
3993                        }
3994                        crate::AtomicFunction::ExclusiveOr => {
3995                            let spirv_op = match scalar.kind {
3996                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3997                                    spirv::Op::AtomicXor
3998                                }
3999                                _ => unimplemented!(),
4000                            };
4001                            Instruction::atomic_binary(
4002                                spirv_op,
4003                                result_type_id,
4004                                id,
4005                                pointer_id,
4006                                scope_constant_id,
4007                                semantics_id,
4008                                value_id,
4009                            )
4010                        }
4011                        crate::AtomicFunction::Min => {
4012                            let spirv_op = match scalar.kind {
4013                                crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4014                                crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4015                                _ => unimplemented!(),
4016                            };
4017                            Instruction::atomic_binary(
4018                                spirv_op,
4019                                result_type_id,
4020                                id,
4021                                pointer_id,
4022                                scope_constant_id,
4023                                semantics_id,
4024                                value_id,
4025                            )
4026                        }
4027                        crate::AtomicFunction::Max => {
4028                            let spirv_op = match scalar.kind {
4029                                crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4030                                crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4031                                _ => unimplemented!(),
4032                            };
4033                            Instruction::atomic_binary(
4034                                spirv_op,
4035                                result_type_id,
4036                                id,
4037                                pointer_id,
4038                                scope_constant_id,
4039                                semantics_id,
4040                                value_id,
4041                            )
4042                        }
4043                        crate::AtomicFunction::Exchange { compare: None } => {
4044                            Instruction::atomic_binary(
4045                                spirv::Op::AtomicExchange,
4046                                result_type_id,
4047                                id,
4048                                pointer_id,
4049                                scope_constant_id,
4050                                semantics_id,
4051                                value_id,
4052                            )
4053                        }
4054                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4055                            let scalar_type_id =
4056                                self.get_numeric_type_id(NumericType::Scalar(scalar));
4057                            let bool_type_id =
4058                                self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4059
4060                            let cas_result_id = self.gen_id();
4061                            let equality_result_id = self.gen_id();
4062                            let equality_operator = match scalar.kind {
4063                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4064                                    spirv::Op::IEqual
4065                                }
4066                                _ => unimplemented!(),
4067                            };
4068
4069                            let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4070                            cas_instr.set_type(scalar_type_id);
4071                            cas_instr.set_result(cas_result_id);
4072                            cas_instr.add_operand(pointer_id);
4073                            cas_instr.add_operand(scope_constant_id);
4074                            cas_instr.add_operand(semantics_id); // semantics if equal
4075                            cas_instr.add_operand(semantics_id); // semantics if not equal
4076                            cas_instr.add_operand(value_id);
4077                            cas_instr.add_operand(self.cached[cmp]);
4078                            block.body.push(cas_instr);
4079                            block.body.push(Instruction::binary(
4080                                equality_operator,
4081                                bool_type_id,
4082                                equality_result_id,
4083                                cas_result_id,
4084                                self.cached[cmp],
4085                            ));
4086                            Instruction::composite_construct(
4087                                result_type_id,
4088                                id,
4089                                &[cas_result_id, equality_result_id],
4090                            )
4091                        }
4092                    };
4093
4094                    block.body.push(instruction);
4095                }
4096                Statement::ImageAtomic {
4097                    image,
4098                    coordinate,
4099                    array_index,
4100                    fun,
4101                    value,
4102                } => {
4103                    self.write_image_atomic(
4104                        image,
4105                        coordinate,
4106                        array_index,
4107                        fun,
4108                        value,
4109                        &mut block,
4110                    )?;
4111                }
4112                Statement::WorkGroupUniformLoad { pointer, result } => {
4113                    self.writer
4114                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4115                    let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4116                    // Match `Expression::Load` behavior, including `OpAtomicLoad` when
4117                    // loading from a pointer to `atomic<T>`.
4118                    let id = self.write_checked_load(
4119                        pointer,
4120                        &mut block,
4121                        AccessTypeAdjustment::None,
4122                        result_type_id,
4123                    )?;
4124                    self.cached[result] = id;
4125                    self.writer
4126                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4127                }
4128                Statement::RayQuery { query, ref fun } => {
4129                    self.write_ray_query_function(query, fun, &mut block);
4130                }
4131                Statement::SubgroupBallot {
4132                    result,
4133                    ref predicate,
4134                } => {
4135                    self.write_subgroup_ballot(predicate, result, &mut block)?;
4136                }
4137                Statement::SubgroupCollectiveOperation {
4138                    ref op,
4139                    ref collective_op,
4140                    argument,
4141                    result,
4142                } => {
4143                    self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4144                }
4145                Statement::SubgroupGather {
4146                    ref mode,
4147                    argument,
4148                    result,
4149                } => {
4150                    self.write_subgroup_gather(mode, argument, result, &mut block)?;
4151                }
4152                Statement::CooperativeStore { target, ref data } => {
4153                    let target_id = self.cached[target];
4154                    let layout = if data.row_major {
4155                        spirv::CooperativeMatrixLayout::RowMajorKHR
4156                    } else {
4157                        spirv::CooperativeMatrixLayout::ColumnMajorKHR
4158                    };
4159                    let layout_id = self.get_index_constant(layout as u32);
4160                    let stride_id = self.cached[data.stride];
4161                    match self.write_access_chain(
4162                        data.pointer,
4163                        &mut block,
4164                        AccessTypeAdjustment::None,
4165                    )? {
4166                        ExpressionPointer::Ready { pointer_id } => {
4167                            block.body.push(Instruction::coop_store(
4168                                target_id, pointer_id, layout_id, stride_id,
4169                            ));
4170                        }
4171                        ExpressionPointer::Conditional { condition, access } => {
4172                            let mut selection = Selection::start(&mut block, ());
4173                            selection.if_true(self, condition, ());
4174
4175                            // The in-bounds path. Perform the access and the store.
4176                            let pointer_id = access.result_id.unwrap();
4177                            selection.block().body.push(access);
4178                            selection.block().body.push(Instruction::coop_store(
4179                                target_id, pointer_id, layout_id, stride_id,
4180                            ));
4181
4182                            // Finish the in-bounds block and start the merge block. This
4183                            // is the block we'll leave current on return.
4184                            selection.finish(self, ());
4185                        }
4186                    };
4187                }
4188                Statement::RayPipelineFunction(ref fun) => {
4189                    self.write_ray_tracing_pipeline_function(fun, &mut block);
4190                }
4191            }
4192        }
4193
4194        let termination = match exit {
4195            // We're generating code for the top-level Block of the function, so we
4196            // need to end it with some kind of return instruction.
4197            BlockExit::Return => match self.ir_function.result {
4198                Some(ref result) if self.function.entry_point_context.is_none() => {
4199                    let type_id = self.get_handle_type_id(result.ty);
4200                    let null_id = self.writer.get_constant_null(type_id);
4201                    Instruction::return_value(null_id)
4202                }
4203                _ => Instruction::return_void(),
4204            },
4205            BlockExit::Branch { target } => Instruction::branch(target),
4206            BlockExit::BreakIf {
4207                condition,
4208                preamble_id,
4209            } => {
4210                let condition_id = self.cached[condition];
4211
4212                Instruction::branch_conditional(
4213                    condition_id,
4214                    loop_context.break_id.unwrap(),
4215                    preamble_id,
4216                )
4217            }
4218        };
4219
4220        self.function.consume(block, termination);
4221        Ok(BlockExitDisposition::Used)
4222    }
4223
4224    pub(super) fn write_function_body(
4225        &mut self,
4226        entry_id: Word,
4227        debug_info: Option<&DebugInfoInner>,
4228    ) -> Result<(), Error> {
4229        // We can ignore the `BlockExitDisposition` returned here because
4230        // `BlockExit::Return` doesn't refer to a block.
4231        let _ = self.write_block(
4232            entry_id,
4233            &self.ir_function.body,
4234            BlockExit::Return,
4235            LoopContext::default(),
4236            debug_info,
4237        )?;
4238
4239        Ok(())
4240    }
4241}