// naga/back/spv/block.rs

1/*!
2Implementations for `BlockContext` methods.
3*/
4
5use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11    helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12    BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13    ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16    arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17    proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21    match *type_inner {
22        crate::TypeInner::Scalar(_) => Dimension::Scalar,
23        crate::TypeInner::Vector { .. } => Dimension::Vector,
24        crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25        crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26        _ => unreachable!(),
27    }
28}
29
/// How to derive the type of `OpAccessChain` instructions from Naga IR.
///
/// Most of the time, we compile Naga IR to SPIR-V instructions whose result
/// types are simply the direct SPIR-V analog of the Naga IR's. But in some
/// cases, the Naga IR and SPIR-V types need to diverge.
///
/// This enum specifies how [`BlockContext::write_access_chain`] should
/// choose a SPIR-V result type for the `OpAccessChain` it generates, based on
/// the type of the given Naga IR [`Expression`] it's generating code for.
///
/// [`Expression`]: crate::Expression
#[derive(Copy, Clone)]
enum AccessTypeAdjustment {
    /// No adjustment needed: the SPIR-V type should be the direct
    /// analog of the Naga IR expression type.
    ///
    /// For most access chains, this is the right thing: the Naga IR access
    /// expression produces a [`Pointer`] to the element / component, and the
    /// SPIR-V `OpAccessChain` instruction does the same.
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    None,

    /// The SPIR-V type should be an `OpPointer` to the direct analog of the
    /// Naga IR expression's type.
    ///
    /// This is necessary for indexing binding arrays in the [`Handle`] address
    /// space:
    ///
    /// - In Naga IR, referencing a binding array [`GlobalVariable`] in the
    ///   [`Handle`] address space produces a value of type [`BindingArray`],
    ///   not a pointer to such. And [`Access`] and [`AccessIndex`] expressions
    ///   operate on handle binding arrays by value, and produce handle values,
    ///   not pointers.
    ///
    /// - In SPIR-V, a binding array `OpVariable` produces a pointer to an
    ///   array, and `OpAccessChain` instructions operate on pointers,
    ///   regardless of whether the elements are opaque types or not.
    ///
    /// See also the documentation for [`BindingArray`].
    ///
    /// [`Handle`]: crate::AddressSpace::Handle
    /// [`GlobalVariable`]: crate::GlobalVariable
    /// [`BindingArray`]: crate::TypeInner::BindingArray
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    IntroducePointer(spirv::StorageClass),

    /// The SPIR-V type should be an `OpPointer` to the std140 layout
    /// compatible variant of the Naga IR expression's base type.
    ///
    /// This is used when accessing a type through an [`AddressSpace::Uniform`]
    /// pointer in cases where the original type is incompatible with std140
    /// layout requirements and we have therefore declared the uniform to be of
    /// an alternative std140 compliant type. For example,
    /// `maybe_write_load_uniform_matcx2_struct_member` uses this when loading
    /// a decomposed two-row matrix from its containing struct.
    ///
    /// [`AddressSpace::Uniform`]: crate::AddressSpace::Uniform
    UseStd140CompatType,
}
89
/// The results of emitting code for a left-hand-side expression.
///
/// On success, `write_access_chain` returns one of these.
enum ExpressionPointer {
    /// The pointer to the expression's value is available, as the value of the
    /// expression with the given id.
    Ready { pointer_id: Word },

    /// The access expression must be conditional on the value of `condition`, a boolean
    /// expression that is true if all indices are in bounds. If `condition` is true, then
    /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
    /// expression's value. If `condition` is false, then executing `access` would be
    /// undefined behavior.
    Conditional {
        /// Id of a boolean value that is true iff all indices are in bounds.
        condition: Word,
        /// An `OpAccessChain` instruction, not yet appended to any block; it
        /// is only valid to execute when `condition` is true.
        access: Instruction,
    },
}
108
/// The termination statement to be added to the end of the block
enum BlockExit {
    /// Generates an OpReturn (void return)
    Return,
    /// Generates an OpBranch to the specified block
    Branch {
        /// The branch target block
        target: Word,
    },
    /// Translates a loop `break if` into an `OpBranchConditional`, branching
    /// to the merge block if the condition is true (the merge block is passed
    /// through [`LoopContext::break_id`]), or else to the loop header (passed
    /// through [`preamble_id`]).
    ///
    /// [`preamble_id`]: Self::BreakIf::preamble_id
    BreakIf {
        /// The condition of the `break if`
        condition: Handle<crate::Expression>,
        /// The loop header block id
        preamble_id: Word,
    },
}
130
/// What code generation did with a provided [`BlockExit`] value.
///
/// A function that accepts a [`BlockExit`] argument should return a value of
/// this type, to indicate whether the code it generated ended up using the
/// provided exit, or ignored it and did a non-local exit of some other kind
/// (say, [`Break`] or [`Continue`]). Some callers must use this information to
/// decide whether to generate the target block at all.
///
/// The type is `#[must_use]` so that callers cannot silently discard this
/// information.
///
/// [`Break`]: Statement::Break
/// [`Continue`]: Statement::Continue
#[must_use]
enum BlockExitDisposition {
    /// The generated code used the provided `BlockExit` value. If it included a
    /// block label, the caller should be sure to actually emit the block it
    /// refers to.
    Used,

    /// The generated code did not use the provided `BlockExit` value. If it
    /// included a block label, the caller should not bother to actually emit
    /// the block it refers to, unless it knows the block is needed for
    /// something else.
    Discarded,
}
154
/// Branch targets available while generating code inside a loop.
///
/// The [`Default`] value (both fields `None`) represents code outside any loop.
#[derive(Clone, Copy, Default)]
struct LoopContext {
    // Label id of the current loop's continuing block.
    // NOTE(review): inferred from the name — confirm at the use sites.
    continuing_id: Option<Word>,
    // Label id of the current loop's merge block, used as the `break` target
    // (see [`BlockExit::BreakIf`]).
    break_id: Option<Word>,
}
160
/// Source-level debug information threaded through block code generation.
/// NOTE(review): the consumers of this type are outside this chunk — confirm
/// the exact debug instructions it feeds before relying on these notes.
#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    /// The module's source text.
    pub source_code: &'a str,
    /// SPIR-V id identifying the source file.
    /// NOTE(review): inferred from the field name — confirm at use sites.
    pub source_file_id: Word,
}
166
167impl Writer {
168    // Flip Y coordinate to adjust for coordinate space difference
169    // between SPIR-V and our IR.
170    // The `position_id` argument is a pointer to a `vecN<f32>`,
171    // whose `y` component we will negate.
172    fn write_epilogue_position_y_flip(
173        &mut self,
174        position_id: Word,
175        body: &mut Vec<Instruction>,
176    ) -> Result<(), Error> {
177        let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178        let index_y_id = self.get_index_constant(1);
179        let access_id = self.id_gen.next();
180        body.push(Instruction::access_chain(
181            float_ptr_type_id,
182            access_id,
183            position_id,
184            &[index_y_id],
185        ));
186
187        let float_type_id = self.get_f32_type_id();
188        let load_id = self.id_gen.next();
189        body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191        let neg_id = self.id_gen.next();
192        body.push(Instruction::unary(
193            spirv::Op::FNegate,
194            float_type_id,
195            neg_id,
196            load_id,
197        ));
198
199        body.push(Instruction::store(access_id, neg_id, None));
200        Ok(())
201    }
202
203    // Clamp fragment depth between 0 and 1.
204    fn write_epilogue_frag_depth_clamp(
205        &mut self,
206        frag_depth_id: Word,
207        body: &mut Vec<Instruction>,
208    ) -> Result<(), Error> {
209        let float_type_id = self.get_f32_type_id();
210        let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211        let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213        let original_id = self.id_gen.next();
214        body.push(Instruction::load(
215            float_type_id,
216            original_id,
217            frag_depth_id,
218            None,
219        ));
220
221        let clamp_id = self.id_gen.next();
222        body.push(Instruction::ext_inst_gl_op(
223            self.gl450_ext_inst_id,
224            spirv::GLOp::FClamp,
225            float_type_id,
226            clamp_id,
227            &[original_id, zero_scalar_id, one_scalar_id],
228        ));
229
230        body.push(Instruction::store(frag_depth_id, clamp_id, None));
231        Ok(())
232    }
233
234    fn write_entry_point_return(
235        &mut self,
236        value_id: Word,
237        ir_result: &crate::FunctionResult,
238        result_members: &[ResultMember],
239        body: &mut Vec<Instruction>,
240        task_payload: Option<Word>,
241    ) -> Result<Instruction, Error> {
242        for (index, res_member) in result_members.iter().enumerate() {
243            // This isn't a real builtin, and is handled elsewhere
244            if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
245                continue;
246            }
247            let member_value_id = match ir_result.binding {
248                Some(_) => value_id,
249                None => {
250                    let member_value_id = self.id_gen.next();
251                    body.push(Instruction::composite_extract(
252                        res_member.type_id,
253                        member_value_id,
254                        value_id,
255                        &[index as u32],
256                    ));
257                    member_value_id
258                }
259            };
260
261            self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
262
263            match res_member.built_in {
264                Some(crate::BuiltIn::Position { .. })
265                    if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
266                {
267                    self.write_epilogue_position_y_flip(res_member.id, body)?;
268                }
269                Some(crate::BuiltIn::FragDepth)
270                    if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
271                {
272                    self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
273                }
274                _ => {}
275            }
276        }
277        self.try_write_entry_point_task_return(
278            value_id,
279            ir_result,
280            result_members,
281            body,
282            task_payload,
283        )
284    }
285}
286
287impl BlockContext<'_> {
    /// Generates code to ensure that a loop is bounded. Should be called immediately
    /// after adding the OpLoopMerge instruction to `block`. This function will
    /// [`consume()`](crate::back::spv::Function::consume) `block` and append its
    /// instructions to a new [`Block`], which will be returned to the caller for it to
    /// consume prior to writing the loop body.
    ///
    /// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
    /// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
    /// declare the required variables.
    ///
    /// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
    /// of why this is required.
    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
        // Gather the types and constants the counter needs. Note these getter
        // calls may lazily declare new types/constants in the module, so their
        // order matters for deterministic id allocation.
        let uint_type_id = self.writer.get_u32_type_id();
        let uint2_type_id = self.writer.get_vec2u_type_id();
        let uint2_ptr_type_id = self
            .writer
            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
        let bool_type_id = self.writer.get_bool_type_id();
        let bool2_type_id = self.writer.get_vec2_bool_type_id();
        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let zero_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[zero_uint_const_id, zero_uint_const_id],
        );
        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
        let max_uint_const_id = self
            .writer
            .get_constant_scalar(crate::Literal::U32(u32::MAX));
        let max_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[max_uint_const_id, max_uint_const_id],
        );

        // Create the counter variable, initialized to vec2(u32::MAX). It is
        // recorded in `force_loop_bounding_vars` so that `Function::to_words`
        // will declare it.
        let loop_counter_var_id = self.gen_id();
        if self.writer.flags.contains(WriterFlags::DEBUG) {
            self.writer
                .debugs
                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
        }
        let var = super::LocalVariable {
            id: loop_counter_var_id,
            instruction: Instruction::variable(
                uint2_ptr_type_id,
                loop_counter_var_id,
                spirv::StorageClass::Function,
                Some(max_uint2_const_id),
            ),
        };
        self.function.force_loop_bounding_vars.push(var);

        // Terminate the caller's block with a branch into the block that
        // decides whether to break out of the loop.
        let break_if_block = self.gen_id();

        self.function
            .consume(block, Instruction::branch(break_if_block));
        block = Block::new(break_if_block);

        // Load the current loop counter value from its variable. We use a vec2<u32> to
        // simulate a 64-bit counter.
        let load_id = self.gen_id();
        block.body.push(Instruction::load(
            uint2_type_id,
            load_id,
            loop_counter_var_id,
            None,
        ));

        // If both the high and low u32s have reached 0 then break. ie
        // if (all(eq(loop_counter, vec2(0)))) { break; }
        let eq_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool2_type_id,
            eq_id,
            zero_uint2_const_id,
            load_id,
        ));
        let all_eq_id = self.gen_id();
        block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            all_eq_id,
            eq_id,
        ));

        // Branch to `merge_id` (the loop's merge block) if the counter reached
        // zero, else fall through to the block that decrements the counter.
        let inc_counter_block_id = self.gen_id();
        block.body.push(Instruction::selection_merge(
            inc_counter_block_id,
            spirv::SelectionControl::empty(),
        ));
        self.function.consume(
            block,
            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
        );
        block = Block::new(inc_counter_block_id);

        // To simulate a 64-bit counter we always decrement the low u32, and decrement
        // the high u32 when the low u32 overflows. ie
        // counter -= vec2(select(0u, 1u, counter.y == 0), 1u);
        // Count down from u32::MAX rather than up from 0 to avoid hang on
        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
        let low_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            uint_type_id,
            low_id,
            load_id,
            &[1],
        ));
        let low_overflow_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            low_overflow_id,
            low_id,
            zero_uint_const_id,
        ));
        // carry = select(low == 0, 1u, 0u), i.e. borrow from the high word
        // only when the low word is about to wrap.
        let carry_bit_id = self.gen_id();
        block.body.push(Instruction::select(
            uint_type_id,
            carry_bit_id,
            low_overflow_id,
            one_uint_const_id,
            zero_uint_const_id,
        ));
        let decrement_id = self.gen_id();
        block.body.push(Instruction::composite_construct(
            uint2_type_id,
            decrement_id,
            &[carry_bit_id, one_uint_const_id],
        ));
        let result_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ISub,
            uint2_type_id,
            result_id,
            load_id,
            decrement_id,
        ));
        block
            .body
            .push(Instruction::store(loop_counter_var_id, result_id, None));

        // Return the open block; the caller appends the loop body here and is
        // responsible for consuming it.
        block
    }
438
    /// If `pointer` refers to an access chain that contains a dynamic indexing
    /// of a two-row matrix in the [`Uniform`] address space, write code to
    /// access the value returning the ID of the result. Else return None.
    ///
    /// Two-row matrices in the uniform address space will have been declared
    /// using an alternative std140 layout compatible type, where each column is
    /// a member of a containing struct. As a result, SPIR-V is unable to access
    /// its columns with a non-constant index. To work around this limitation
    /// this function will call [`Self::write_checked_load()`] to load the
    /// matrix itself, which handles conversion from the std140 compatible type
    /// to the real matrix type. It then calls a [`wrapper function`] to obtain
    /// the correct column from the matrix, and possibly extracts a component
    /// from the vector too.
    ///
    /// [`Uniform`]: crate::AddressSpace::Uniform
    /// [`wrapper function`]: super::Writer::write_wrapped_matcx2_get_column
    fn maybe_write_uniform_matcx2_dynamic_access(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<Option<Word>, Error> {
        // If this access chain contains a dynamic matrix access, `pointer` is
        // either a pointer to a vector (the column) or a scalar (a component
        // within the column). In either case grab the pointer to the column,
        // and remember the component index if there is one. If `pointer`
        // points to any other type we're not interested.
        let (column_pointer, component_index) = match self.fun_info[pointer]
            .ty
            .inner_with(&self.ir_module.types)
            .pointer_base_type()
        {
            Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
                crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
                    crate::Expression::Access { base, index } => {
                        (base, Some(GuardedIndex::Expression(index)))
                    }
                    crate::Expression::AccessIndex { base, index } => {
                        (base, Some(GuardedIndex::Known(index)))
                    }
                    _ => return Ok(None),
                },
                crate::TypeInner::Vector { .. } => (pointer, None),
                _ => return Ok(None),
            },
            None => return Ok(None),
        };

        // Ensure the column is accessed with a dynamic index (i.e.
        // `Expression::Access`), and grab the pointer to the matrix.
        let crate::Expression::Access {
            base: matrix_pointer,
            index: column_index,
        } = self.ir_function.expressions[column_pointer]
        else {
            return Ok(None);
        };

        // Ensure the matrix pointer is in the uniform address space.
        let crate::TypeInner::Pointer {
            base: matrix_pointer_base_type,
            space: crate::AddressSpace::Uniform,
        } = *self.fun_info[matrix_pointer]
            .ty
            .inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // Ensure the matrix pointer actually points to a Cx2 matrix.
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = self.ir_module.types[matrix_pointer_base_type].inner
        else {
            return Ok(None);
        };

        // All checks passed: we will definitely emit code. Look up the SPIR-V
        // ids for the matrix, column, and component types, and the wrapper
        // function that extracts a column by dynamic index.
        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns,
            rows,
            scalar,
        });
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
        let get_column_function_id = self.writer.wrapped_functions
            [&WrappedFunction::MatCx2GetColumn {
                r#type: matrix_pointer_base_type,
            }];

        // Load the whole matrix; this handles conversion from the std140
        // compatible representation to the real matrix type.
        let matrix_load_id = self.write_checked_load(
            matrix_pointer,
            block,
            AccessTypeAdjustment::None,
            matrix_type_id,
        )?;

        // Naga IR allows the index to be either an I32 or U32 but our wrapper
        // function expects a U32 argument, so convert it if required.
        let column_index_id = match *self.fun_info[column_index]
            .ty
            .inner_with(&self.ir_module.types)
        {
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Uint,
                ..
            }) => self.cached[column_index],
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Sint,
                ..
            }) => {
                let cast_id = self.gen_id();
                let u32_type_id = self.writer.get_u32_type_id();
                block.body.push(Instruction::unary(
                    spirv::Op::Bitcast,
                    u32_type_id,
                    cast_id,
                    self.cached[column_index],
                ));
                cast_id
            }
            _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
        };
        // Call the wrapper function to obtain the requested column.
        let column_id = self.gen_id();
        block.body.push(Instruction::function_call(
            column_type_id,
            column_id,
            get_column_function_id,
            &[matrix_load_id, column_index_id],
        ));
        // If the original access targeted a single component, extract it from
        // the column vector; otherwise the column itself is the result.
        let result_id = match component_index {
            Some(index) => self.write_vector_access(
                component_type_id,
                column_pointer,
                Some(column_id),
                index,
                block,
            )?,
            None => column_id,
        };

        Ok(Some(result_id))
    }
582
    /// If `pointer` refers to two-row matrix that is a member of a struct in
    /// the [`Uniform`] address space, write code to load the matrix returning
    /// the ID of the result. Else return None.
    ///
    /// Two-row matrices that are struct members in the uniform address space
    /// will have been decomposed such that the struct contains a separate
    /// vector member for each column of the matrix. This function will load
    /// each column separately from the containing struct, then composite them
    /// into the real matrix type.
    ///
    /// [`Uniform`]: crate::AddressSpace::Uniform
    fn maybe_write_load_uniform_matcx2_struct_member(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<Option<Word>, Error> {
        // Check this is a uniform address space pointer to a two-row matrix.
        let crate::TypeInner::Pointer {
            base: matrix_type,
            space: space @ crate::AddressSpace::Uniform,
        } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = self.ir_module.types[matrix_type].inner
        else {
            return Ok(None);
        };

        // Check this is a struct member. Note struct members can only be
        // accessed with `AccessIndex`.
        let crate::Expression::AccessIndex {
            base: struct_pointer,
            index: member_index,
        } = self.ir_function.expressions[pointer]
        else {
            return Ok(None);
        };

        let crate::TypeInner::Pointer {
            base: struct_type, ..
        } = *self.fun_info[struct_pointer]
            .ty
            .inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
            return Ok(None);
        };

        // All checks passed: compute the SPIR-V type ids needed to rebuild the
        // matrix from its decomposed column members.
        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns,
            rows,
            scalar,
        });
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let column_pointer_type_id =
            self.get_pointer_type_id(column_type_id, map_storage_class(space));
        // Map the IR member index to the index of the matrix's first column
        // within the std140-compatible struct; subsequent columns follow
        // consecutively.
        let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
            [member_index as usize];
        let column_indices = (0..columns as u32)
            .map(|c| self.get_index_constant(column0_index + c))
            .collect::<ArrayVec<_, 4>>();

        // Load each column from the struct, then composite into the real
        // matrix type.
        let load_mat_from_struct =
            |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
                for index in &column_indices {
                    let column_pointer_id = id_gen.next();
                    block.body.push(Instruction::access_chain(
                        column_pointer_type_id,
                        column_pointer_id,
                        struct_pointer_id,
                        &[*index],
                    ));
                    let column_id = id_gen.next();
                    block.body.push(Instruction::load(
                        column_type_id,
                        column_id,
                        column_pointer_id,
                        None,
                    ));
                    column_ids.push(column_id);
                }
                let result_id = id_gen.next();
                block.body.push(Instruction::composite_construct(
                    matrix_type_id,
                    result_id,
                    &column_ids,
                ));
                result_id
            };

        // Get a pointer to the containing struct (using its std140-compatible
        // type), then load the matrix. A bounds-checked access chain may be
        // conditional, in which case the load must be wrapped accordingly.
        let result_id = match self.write_access_chain(
            struct_pointer,
            block,
            AccessTypeAdjustment::UseStd140CompatType,
        )? {
            ExpressionPointer::Ready { pointer_id } => {
                load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
            }
            ExpressionPointer::Conditional { condition, access } => self
                .write_conditional_indexed_load(
                    matrix_type_id,
                    condition,
                    block,
                    |id_gen, block| {
                        // Emit the deferred `OpAccessChain`, then load through
                        // the pointer it produces.
                        let pointer_id = access.result_id.unwrap();
                        block.body.push(access);
                        load_mat_from_struct(pointer_id, id_gen, block)
                    },
                ),
        };

        Ok(Some(result_id))
    }
708
709    /// Cache an expression for a value.
710    pub(super) fn cache_expression_value(
711        &mut self,
712        expr_handle: Handle<crate::Expression>,
713        block: &mut Block,
714    ) -> Result<(), Error> {
715        let is_named_expression = self
716            .ir_function
717            .named_expressions
718            .contains_key(&expr_handle);
719
720        if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
721            return Ok(());
722        }
723
724        let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
725        let id = match self.ir_function.expressions[expr_handle] {
726            crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
727            crate::Expression::Constant(handle) => {
728                let init = self.ir_module.constants[handle].init;
729                self.writer.constant_ids[init]
730            }
731            crate::Expression::Override(_) => return Err(Error::Override),
732            crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
733            crate::Expression::Compose { ty, ref components } => {
734                self.temp_list.clear();
735                if self.expression_constness.is_const(expr_handle) {
736                    self.temp_list.extend(
737                        crate::proc::flatten_compose(
738                            ty,
739                            components,
740                            &self.ir_function.expressions,
741                            &self.ir_module.types,
742                        )
743                        .map(|component| self.cached[component]),
744                    );
745                    self.writer
746                        .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
747                } else {
748                    self.temp_list
749                        .extend(components.iter().map(|&component| self.cached[component]));
750
751                    let id = self.gen_id();
752                    block.body.push(Instruction::composite_construct(
753                        result_type_id,
754                        id,
755                        &self.temp_list,
756                    ));
757                    id
758                }
759            }
760            crate::Expression::Splat { size, value } => {
761                let value_id = self.cached[value];
762                let components = &[value_id; 4][..size as usize];
763
764                if self.expression_constness.is_const(expr_handle) {
765                    let ty = self
766                        .writer
767                        .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
768                    self.writer.get_constant_composite(ty, components)
769                } else {
770                    let id = self.gen_id();
771                    block.body.push(Instruction::composite_construct(
772                        result_type_id,
773                        id,
774                        components,
775                    ));
776                    id
777                }
778            }
779            crate::Expression::Access { base, index } => {
780                let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
781                match *base_ty_inner {
782                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
783                        // When we have a chain of `Access` and `AccessIndex` expressions
784                        // operating on pointers, we want to generate a single
785                        // `OpAccessChain` instruction for the whole chain. Put off
786                        // generating any code for this until we find the `Expression`
787                        // that actually dereferences the pointer.
788                        0
789                    }
790                    _ if self.function.spilled_accesses.contains(base) => {
791                        // As far as Naga IR is concerned, this expression does not yield
792                        // a pointer (we just checked, above), but this backend spilled it
793                        // to a temporary variable, so SPIR-V thinks we're accessing it
794                        // via a pointer.
795
796                        // Since the base expression was spilled, mark this access to it
797                        // as spilled, too.
798                        self.function.spilled_accesses.insert(expr_handle);
799                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
800                    }
801                    crate::TypeInner::Vector { .. } => self.write_vector_access(
802                        result_type_id,
803                        base,
804                        None,
805                        GuardedIndex::Expression(index),
806                        block,
807                    )?,
808                    crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
809                        // See if `index` is known at compile time.
810                        match GuardedIndex::from_expression(
811                            index,
812                            &self.ir_function.expressions,
813                            self.ir_module,
814                        ) {
815                            GuardedIndex::Known(value) => {
816                                // If `index` is known and in bounds, we can just use
817                                // `OpCompositeExtract`.
818                                //
819                                // At the moment, validation rejects programs if this
820                                // index is out of bounds, so we don't need bounds checks.
821                                // However, that rejection is incorrect, since WGSL says
822                                // that `let` bindings are not constant expressions
823                                // (#6396). So eventually we will need to emulate bounds
824                                // checks here.
825                                let id = self.gen_id();
826                                let base_id = self.cached[base];
827                                block.body.push(Instruction::composite_extract(
828                                    result_type_id,
829                                    id,
830                                    base_id,
831                                    &[value],
832                                ));
833                                id
834                            }
835                            GuardedIndex::Expression(_) => {
836                                // We are subscripting an array or matrix that is not
837                                // behind a pointer, using an index computed at runtime.
838                                // SPIR-V has no instructions that do this, so the best we
839                                // can do is spill the value to a new temporary variable,
840                                // at which point we can get a pointer to that and just
841                                // use `OpAccessChain` in the usual way.
842                                self.spill_to_internal_variable(base, block);
843
844                                // Since the base was spilled, mark this access to it as
845                                // spilled, too.
846                                self.function.spilled_accesses.insert(expr_handle);
847                                self.maybe_access_spilled_composite(
848                                    expr_handle,
849                                    block,
850                                    result_type_id,
851                                )?
852                            }
853                        }
854                    }
855                    crate::TypeInner::BindingArray {
856                        base: binding_type, ..
857                    } => {
858                        // Only binding arrays in the `Handle` address space will take
859                        // this path, since we handled the `Pointer` case above.
860                        let result_id = match self.write_access_chain(
861                            expr_handle,
862                            block,
863                            AccessTypeAdjustment::IntroducePointer(
864                                spirv::StorageClass::UniformConstant,
865                            ),
866                        )? {
867                            ExpressionPointer::Ready { pointer_id } => pointer_id,
868                            ExpressionPointer::Conditional { .. } => {
869                                return Err(Error::FeatureNotImplemented(
870                                    "Texture array out-of-bounds handling",
871                                ));
872                            }
873                        };
874
875                        let binding_type_id = self.get_handle_type_id(binding_type);
876
877                        let load_id = self.gen_id();
878                        block.body.push(Instruction::load(
879                            binding_type_id,
880                            load_id,
881                            result_id,
882                            None,
883                        ));
884
885                        // Subsequent image operations require the image/sampler to be decorated as NonUniform
886                        // if the image/sampler binding array was accessed with a non-uniform index
887                        // see VUID-RuntimeSpirv-NonUniform-06274
888                        if self.fun_info[index].uniformity.non_uniform_result.is_some() {
889                            self.writer
890                                .decorate_non_uniform_binding_array_access(load_id)?;
891                        }
892
893                        load_id
894                    }
895                    ref other => {
896                        log::error!(
897                            "Unable to access base {:?} of type {:?}",
898                            self.ir_function.expressions[base],
899                            other
900                        );
901                        return Err(Error::Validation(
902                            "only vectors and arrays may be dynamically indexed by value",
903                        ));
904                    }
905                }
906            }
907            crate::Expression::AccessIndex { base, index } => {
908                match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
909                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
910                        // When we have a chain of `Access` and `AccessIndex` expressions
911                        // operating on pointers, we want to generate a single
912                        // `OpAccessChain` instruction for the whole chain. Put off
913                        // generating any code for this until we find the `Expression`
914                        // that actually dereferences the pointer.
915                        0
916                    }
917                    _ if self.function.spilled_accesses.contains(base) => {
918                        // As far as Naga IR is concerned, this expression does not yield
919                        // a pointer (we just checked, above), but this backend spilled it
920                        // to a temporary variable, so SPIR-V thinks we're accessing it
921                        // via a pointer.
922
923                        // Since the base expression was spilled, mark this access to it
924                        // as spilled, too.
925                        self.function.spilled_accesses.insert(expr_handle);
926                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
927                    }
928                    crate::TypeInner::Vector { .. }
929                    | crate::TypeInner::Matrix { .. }
930                    | crate::TypeInner::Array { .. }
931                    | crate::TypeInner::Struct { .. } => {
932                        // We never need bounds checks here: dynamically sized arrays can
933                        // only appear behind pointers, and are thus handled by the
934                        // `is_intermediate` case above. Everything else's size is
935                        // statically known and checked in validation.
936                        let id = self.gen_id();
937                        let base_id = self.cached[base];
938                        block.body.push(Instruction::composite_extract(
939                            result_type_id,
940                            id,
941                            base_id,
942                            &[index],
943                        ));
944                        id
945                    }
946                    crate::TypeInner::BindingArray {
947                        base: binding_type, ..
948                    } => {
949                        // Only binding arrays in the `Handle` address space will take
950                        // this path, since we handled the `Pointer` case above.
951                        let result_id = match self.write_access_chain(
952                            expr_handle,
953                            block,
954                            AccessTypeAdjustment::IntroducePointer(
955                                spirv::StorageClass::UniformConstant,
956                            ),
957                        )? {
958                            ExpressionPointer::Ready { pointer_id } => pointer_id,
959                            ExpressionPointer::Conditional { .. } => {
960                                return Err(Error::FeatureNotImplemented(
961                                    "Texture array out-of-bounds handling",
962                                ));
963                            }
964                        };
965
966                        let binding_type_id = self.get_handle_type_id(binding_type);
967
968                        let load_id = self.gen_id();
969                        block.body.push(Instruction::load(
970                            binding_type_id,
971                            load_id,
972                            result_id,
973                            None,
974                        ));
975
976                        load_id
977                    }
978                    ref other => {
979                        log::error!("Unable to access index of {other:?}");
980                        return Err(Error::FeatureNotImplemented("access index for type"));
981                    }
982                }
983            }
984            crate::Expression::GlobalVariable(handle) => {
985                self.writer.global_variables[handle].access_id
986            }
987            crate::Expression::Swizzle {
988                size,
989                vector,
990                pattern,
991            } => {
992                let vector_id = self.cached[vector];
993                self.temp_list.clear();
994                for &sc in pattern[..size as usize].iter() {
995                    self.temp_list.push(sc as Word);
996                }
997                let id = self.gen_id();
998                block.body.push(Instruction::vector_shuffle(
999                    result_type_id,
1000                    id,
1001                    vector_id,
1002                    vector_id,
1003                    &self.temp_list,
1004                ));
1005                id
1006            }
1007            crate::Expression::Unary { op, expr } => {
1008                let id = self.gen_id();
1009                let expr_id = self.cached[expr];
1010                let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1011
1012                let spirv_op = match op {
1013                    crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1014                        Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1015                        Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1016                        _ => return Err(Error::Validation("Unexpected kind for negation")),
1017                    },
1018                    crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1019                    crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1020                };
1021
1022                block
1023                    .body
1024                    .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1025                id
1026            }
1027            crate::Expression::Binary { op, left, right } => {
1028                let id = self.gen_id();
1029                let left_id = self.cached[left];
1030                let right_id = self.cached[right];
1031                let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1032                let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1033
1034                if let Some(function_id) =
1035                    self.writer
1036                        .wrapped_functions
1037                        .get(&WrappedFunction::BinaryOp {
1038                            op,
1039                            left_type_id,
1040                            right_type_id,
1041                        })
1042                {
1043                    block.body.push(Instruction::function_call(
1044                        result_type_id,
1045                        id,
1046                        *function_id,
1047                        &[left_id, right_id],
1048                    ));
1049                } else {
1050                    let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1051                    let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1052
1053                    let left_dimension = get_dimension(left_ty_inner);
1054                    let right_dimension = get_dimension(right_ty_inner);
1055
1056                    let mut reverse_operands = false;
1057
1058                    let spirv_op = match op {
1059                        crate::BinaryOperator::Add => match *left_ty_inner {
1060                            crate::TypeInner::Scalar(scalar)
1061                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1062                                crate::ScalarKind::Float => spirv::Op::FAdd,
1063                                _ => spirv::Op::IAdd,
1064                            },
1065                            crate::TypeInner::Matrix {
1066                                columns,
1067                                rows,
1068                                scalar,
1069                            } => {
1070                                //TODO: why not just rely on `Fadd` for matrices?
1071                                self.write_matrix_matrix_column_op(
1072                                    block,
1073                                    id,
1074                                    result_type_id,
1075                                    left_id,
1076                                    right_id,
1077                                    columns,
1078                                    rows,
1079                                    scalar.width,
1080                                    spirv::Op::FAdd,
1081                                );
1082
1083                                self.cached[expr_handle] = id;
1084                                return Ok(());
1085                            }
1086                            crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1087                            _ => unimplemented!(),
1088                        },
1089                        crate::BinaryOperator::Subtract => match *left_ty_inner {
1090                            crate::TypeInner::Scalar(scalar)
1091                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1092                                crate::ScalarKind::Float => spirv::Op::FSub,
1093                                _ => spirv::Op::ISub,
1094                            },
1095                            crate::TypeInner::Matrix {
1096                                columns,
1097                                rows,
1098                                scalar,
1099                            } => {
1100                                self.write_matrix_matrix_column_op(
1101                                    block,
1102                                    id,
1103                                    result_type_id,
1104                                    left_id,
1105                                    right_id,
1106                                    columns,
1107                                    rows,
1108                                    scalar.width,
1109                                    spirv::Op::FSub,
1110                                );
1111
1112                                self.cached[expr_handle] = id;
1113                                return Ok(());
1114                            }
1115                            crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1116                            _ => unimplemented!(),
1117                        },
1118                        crate::BinaryOperator::Multiply => {
1119                            match (left_dimension, right_dimension) {
1120                                (Dimension::Scalar, Dimension::Vector) => {
1121                                    self.write_vector_scalar_mult(
1122                                        block,
1123                                        id,
1124                                        result_type_id,
1125                                        right_id,
1126                                        left_id,
1127                                        right_ty_inner,
1128                                    );
1129
1130                                    self.cached[expr_handle] = id;
1131                                    return Ok(());
1132                                }
1133                                (Dimension::Vector, Dimension::Scalar) => {
1134                                    self.write_vector_scalar_mult(
1135                                        block,
1136                                        id,
1137                                        result_type_id,
1138                                        left_id,
1139                                        right_id,
1140                                        left_ty_inner,
1141                                    );
1142
1143                                    self.cached[expr_handle] = id;
1144                                    return Ok(());
1145                                }
1146                                (Dimension::Vector, Dimension::Matrix) => {
1147                                    spirv::Op::VectorTimesMatrix
1148                                }
1149                                (Dimension::Matrix, Dimension::Scalar)
1150                                | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1151                                    spirv::Op::MatrixTimesScalar
1152                                }
1153                                (Dimension::Scalar, Dimension::Matrix)
1154                                | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1155                                    reverse_operands = true;
1156                                    spirv::Op::MatrixTimesScalar
1157                                }
1158                                (Dimension::Matrix, Dimension::Vector) => {
1159                                    spirv::Op::MatrixTimesVector
1160                                }
1161                                (Dimension::Matrix, Dimension::Matrix) => {
1162                                    spirv::Op::MatrixTimesMatrix
1163                                }
1164                                (Dimension::Vector, Dimension::Vector)
1165                                | (Dimension::Scalar, Dimension::Scalar)
1166                                    if left_ty_inner.scalar_kind()
1167                                        == Some(crate::ScalarKind::Float) =>
1168                                {
1169                                    spirv::Op::FMul
1170                                }
1171                                (Dimension::Vector, Dimension::Vector)
1172                                | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1173                                (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1174                                //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication
1175                                | (Dimension::CooperativeMatrix, _)
1176                                | (_, Dimension::CooperativeMatrix) => {
1177                                    unimplemented!()
1178                                }
1179                            }
1180                        }
1181                        crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1182                            Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1183                            Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1184                            Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1185                            _ => unimplemented!(),
1186                        },
1187                        crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1188                            // TODO: handle undefined behavior
1189                            // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
1190                            Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1191                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1192                                unreachable!("Should have been handled by wrapped function")
1193                            }
1194                            _ => unimplemented!(),
1195                        },
1196                        crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1197                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1198                                spirv::Op::IEqual
1199                            }
1200                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1201                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1202                            _ => unimplemented!(),
1203                        },
1204                        crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1205                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1206                                spirv::Op::INotEqual
1207                            }
1208                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1209                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1210                            _ => unimplemented!(),
1211                        },
1212                        crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1213                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1214                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1215                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1216                            _ => unimplemented!(),
1217                        },
1218                        crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1219                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1220                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1221                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1222                            _ => unimplemented!(),
1223                        },
1224                        crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1225                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1226                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1227                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1228                            _ => unimplemented!(),
1229                        },
1230                        crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1231                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1232                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1233                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1234                            _ => unimplemented!(),
1235                        },
1236                        crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1237                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1238                            _ => spirv::Op::BitwiseAnd,
1239                        },
1240                        crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1241                        crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1242                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1243                            _ => spirv::Op::BitwiseOr,
1244                        },
1245                        crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1246                        crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1247                        crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1248                        crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1249                            Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1250                            Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1251                            _ => unimplemented!(),
1252                        },
1253                    };
1254
1255                    block.body.push(Instruction::binary(
1256                        spirv_op,
1257                        result_type_id,
1258                        id,
1259                        if reverse_operands { right_id } else { left_id },
1260                        if reverse_operands { left_id } else { right_id },
1261                    ));
1262                }
1263                id
1264            }
1265            crate::Expression::Math {
1266                fun,
1267                arg,
1268                arg1,
1269                arg2,
1270                arg3,
1271            } => {
1272                use crate::MathFunction as Mf;
1273                enum MathOp {
1274                    Ext(spirv::GLOp),
1275                    Custom(Instruction),
1276                }
1277
1278                let arg0_id = self.cached[arg];
1279                let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1280                let arg_scalar_kind = arg_ty.scalar_kind();
1281                let arg1_id = match arg1 {
1282                    Some(handle) => self.cached[handle],
1283                    None => 0,
1284                };
1285                let arg2_id = match arg2 {
1286                    Some(handle) => self.cached[handle],
1287                    None => 0,
1288                };
1289                let arg3_id = match arg3 {
1290                    Some(handle) => self.cached[handle],
1291                    None => 0,
1292                };
1293
1294                let id = self.gen_id();
1295                let math_op = match fun {
1296                    // comparison
1297                    Mf::Abs => {
1298                        match arg_scalar_kind {
1299                            Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
1300                            Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
1301                            Some(crate::ScalarKind::Uint) => {
1302                                MathOp::Custom(Instruction::unary(
1303                                    spirv::Op::CopyObject, // do nothing
1304                                    result_type_id,
1305                                    id,
1306                                    arg0_id,
1307                                ))
1308                            }
1309                            other => unimplemented!("Unexpected abs({:?})", other),
1310                        }
1311                    }
1312                    Mf::Min => MathOp::Ext(match arg_scalar_kind {
1313                        Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1314                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1315                        Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1316                        other => unimplemented!("Unexpected min({:?})", other),
1317                    }),
1318                    Mf::Max => MathOp::Ext(match arg_scalar_kind {
1319                        Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1320                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1321                        Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1322                        other => unimplemented!("Unexpected max({:?})", other),
1323                    }),
1324                    Mf::Clamp => match arg_scalar_kind {
1325                        // Clamp is undefined if min > max. In practice this means it can use a median-of-three
1326                        // instruction to determine the value. This is fine according to the WGSL spec for float
1327                        // clamp, but integer clamp _must_ use min-max. As such we write out min/max.
1328                        Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1329                        Some(_) => {
1330                            let (min_op, max_op) = match arg_scalar_kind {
1331                                Some(crate::ScalarKind::Sint) => {
1332                                    (spirv::GLOp::SMin, spirv::GLOp::SMax)
1333                                }
1334                                Some(crate::ScalarKind::Uint) => {
1335                                    (spirv::GLOp::UMin, spirv::GLOp::UMax)
1336                                }
1337                                _ => unreachable!(),
1338                            };
1339
1340                            let max_id = self.gen_id();
1341                            block.body.push(Instruction::ext_inst_gl_op(
1342                                self.writer.gl450_ext_inst_id,
1343                                max_op,
1344                                result_type_id,
1345                                max_id,
1346                                &[arg0_id, arg1_id],
1347                            ));
1348
1349                            MathOp::Custom(Instruction::ext_inst_gl_op(
1350                                self.writer.gl450_ext_inst_id,
1351                                min_op,
1352                                result_type_id,
1353                                id,
1354                                &[max_id, arg2_id],
1355                            ))
1356                        }
1357                        other => unimplemented!("Unexpected max({:?})", other),
1358                    },
1359                    Mf::Saturate => {
1360                        let (maybe_size, scalar) = match *arg_ty {
1361                            crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1362                            crate::TypeInner::Scalar(scalar) => (None, scalar),
1363                            ref other => unimplemented!("Unexpected saturate({:?})", other),
1364                        };
1365                        let scalar = crate::Scalar::float(scalar.width);
1366                        let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1367                        let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1368
1369                        if let Some(size) = maybe_size {
1370                            let ty =
1371                                LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1372
1373                            self.temp_list.clear();
1374                            self.temp_list.resize(size as _, arg1_id);
1375
1376                            arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1377
1378                            self.temp_list.fill(arg2_id);
1379
1380                            arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1381                        }
1382
1383                        MathOp::Custom(Instruction::ext_inst_gl_op(
1384                            self.writer.gl450_ext_inst_id,
1385                            spirv::GLOp::FClamp,
1386                            result_type_id,
1387                            id,
1388                            &[arg0_id, arg1_id, arg2_id],
1389                        ))
1390                    }
1391                    // trigonometry
1392                    Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1393                    Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1394                    Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1395                    Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1396                    Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1397                    Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1398                    Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1399                    Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1400                    Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1401                    Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1402                    Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1403                    Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1404                    Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1405                    Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1406                    Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1407                    // decomposition
1408                    Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1409                    Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1410                    Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1411                    Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1412                    Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1413                    Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1414                    Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1415                    Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1416                    // geometry
1417                    Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1418                        crate::TypeInner::Vector {
1419                            scalar:
1420                                crate::Scalar {
1421                                    kind: crate::ScalarKind::Float,
1422                                    ..
1423                                },
1424                            ..
1425                        } => MathOp::Custom(Instruction::binary(
1426                            spirv::Op::Dot,
1427                            result_type_id,
1428                            id,
1429                            arg0_id,
1430                            arg1_id,
1431                        )),
1432                        // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
1433                        crate::TypeInner::Vector { size, .. } => {
1434                            self.write_dot_product(
1435                                id,
1436                                result_type_id,
1437                                arg0_id,
1438                                arg1_id,
1439                                size as u32,
1440                                block,
1441                                |result_id, composite_id, index| {
1442                                    Instruction::composite_extract(
1443                                        result_type_id,
1444                                        result_id,
1445                                        composite_id,
1446                                        &[index],
1447                                    )
1448                                },
1449                            );
1450                            self.cached[expr_handle] = id;
1451                            return Ok(());
1452                        }
1453                        _ => unreachable!(
1454                            "Correct TypeInner for dot product should be already validated"
1455                        ),
1456                    },
1457                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1458                        if self
1459                            .writer
1460                            .require_all(&[
1461                                spirv::Capability::DotProduct,
1462                                spirv::Capability::DotProductInput4x8BitPacked,
1463                            ])
1464                            .is_ok()
1465                        {
1466                            // Write optimized code using `PackedVectorFormat4x8Bit`.
1467                            if self.writer.lang_version() < (1, 6) {
1468                                // SPIR-V 1.6 supports the required capabilities natively, so the extension
1469                                // is only required for earlier versions. See right column of
1470                                // <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
1471                                self.writer.use_extension("SPV_KHR_integer_dot_product");
1472                            }
1473
1474                            let op = match fun {
1475                                Mf::Dot4I8Packed => spirv::Op::SDot,
1476                                Mf::Dot4U8Packed => spirv::Op::UDot,
1477                                _ => unreachable!(),
1478                            };
1479
1480                            block.body.push(Instruction::ternary(
1481                                op,
1482                                result_type_id,
1483                                id,
1484                                arg0_id,
1485                                arg1_id,
1486                                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1487                            ));
1488                        } else {
1489                            // Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
1490                            let (extract_op, arg0_id, arg1_id) = match fun {
1491                                Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1492                                Mf::Dot4I8Packed => {
1493                                    // Convert both packed arguments to signed integers so that we can apply the
1494                                    // `BitFieldSExtract` operation on them in `write_dot_product` below.
1495                                    let new_arg0_id = self.gen_id();
1496                                    block.body.push(Instruction::unary(
1497                                        spirv::Op::Bitcast,
1498                                        result_type_id,
1499                                        new_arg0_id,
1500                                        arg0_id,
1501                                    ));
1502
1503                                    let new_arg1_id = self.gen_id();
1504                                    block.body.push(Instruction::unary(
1505                                        spirv::Op::Bitcast,
1506                                        result_type_id,
1507                                        new_arg1_id,
1508                                        arg1_id,
1509                                    ));
1510
1511                                    (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1512                                }
1513                                _ => unreachable!(),
1514                            };
1515
1516                            let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1517
1518                            const VEC_LENGTH: u8 = 4;
1519                            let bit_shifts: [_; VEC_LENGTH as usize] =
1520                                core::array::from_fn(|index| {
1521                                    self.writer
1522                                        .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1523                                });
1524
1525                            self.write_dot_product(
1526                                id,
1527                                result_type_id,
1528                                arg0_id,
1529                                arg1_id,
1530                                VEC_LENGTH as Word,
1531                                block,
1532                                |result_id, composite_id, index| {
1533                                    Instruction::ternary(
1534                                        extract_op,
1535                                        result_type_id,
1536                                        result_id,
1537                                        composite_id,
1538                                        bit_shifts[index as usize],
1539                                        eight,
1540                                    )
1541                                },
1542                            );
1543                        }
1544
1545                        self.cached[expr_handle] = id;
1546                        return Ok(());
1547                    }
1548                    Mf::Outer => MathOp::Custom(Instruction::binary(
1549                        spirv::Op::OuterProduct,
1550                        result_type_id,
1551                        id,
1552                        arg0_id,
1553                        arg1_id,
1554                    )),
1555                    Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1556                    Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1557                    Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1558                    Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1559                    Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1560                    Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1561                    Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1562                    // exponent
1563                    Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1564                    Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1565                    Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1566                    Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1567                    Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1568                    // computational
1569                    Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1570                        Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1571                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1572                        other => unimplemented!("Unexpected sign({:?})", other),
1573                    }),
1574                    Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1575                    Mf::Mix => {
1576                        let selector = arg2.unwrap();
1577                        let selector_ty =
1578                            self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1579                        match (arg_ty, selector_ty) {
1580                            // if the selector is a scalar, we need to splat it
1581                            (
1582                                &crate::TypeInner::Vector { size, .. },
1583                                &crate::TypeInner::Scalar(scalar),
1584                            ) => {
1585                                let selector_type_id =
1586                                    self.get_numeric_type_id(NumericType::Vector { size, scalar });
1587                                self.temp_list.clear();
1588                                self.temp_list.resize(size as usize, arg2_id);
1589
1590                                let selector_id = self.gen_id();
1591                                block.body.push(Instruction::composite_construct(
1592                                    selector_type_id,
1593                                    selector_id,
1594                                    &self.temp_list,
1595                                ));
1596
1597                                MathOp::Custom(Instruction::ext_inst_gl_op(
1598                                    self.writer.gl450_ext_inst_id,
1599                                    spirv::GLOp::FMix,
1600                                    result_type_id,
1601                                    id,
1602                                    &[arg0_id, arg1_id, selector_id],
1603                                ))
1604                            }
1605                            _ => MathOp::Ext(spirv::GLOp::FMix),
1606                        }
1607                    }
1608                    Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1609                    Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1610                    Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1611                    Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1612                    Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1613                    Mf::Transpose => MathOp::Custom(Instruction::unary(
1614                        spirv::Op::Transpose,
1615                        result_type_id,
1616                        id,
1617                        arg0_id,
1618                    )),
1619                    Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1620                    Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1621                        spirv::Op::QuantizeToF16,
1622                        result_type_id,
1623                        id,
1624                        arg0_id,
1625                    )),
1626                    Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1627                        spirv::Op::BitReverse,
1628                        result_type_id,
1629                        id,
1630                        arg0_id,
1631                    )),
1632                    Mf::CountTrailingZeros => {
1633                        let uint_id = match *arg_ty {
1634                            crate::TypeInner::Vector { size, scalar } => {
1635                                let ty =
1636                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1637
1638                                self.temp_list.clear();
1639                                self.temp_list.resize(
1640                                    size as _,
1641                                    self.writer
1642                                        .get_constant_scalar_with(scalar.width * 8, scalar)?,
1643                                );
1644
1645                                self.writer.get_constant_composite(ty, &self.temp_list)
1646                            }
1647                            crate::TypeInner::Scalar(scalar) => self
1648                                .writer
1649                                .get_constant_scalar_with(scalar.width * 8, scalar)?,
1650                            _ => unreachable!(),
1651                        };
1652
1653                        let lsb_id = self.gen_id();
1654                        block.body.push(Instruction::ext_inst_gl_op(
1655                            self.writer.gl450_ext_inst_id,
1656                            spirv::GLOp::FindILsb,
1657                            result_type_id,
1658                            lsb_id,
1659                            &[arg0_id],
1660                        ));
1661
1662                        MathOp::Custom(Instruction::ext_inst_gl_op(
1663                            self.writer.gl450_ext_inst_id,
1664                            spirv::GLOp::UMin,
1665                            result_type_id,
1666                            id,
1667                            &[uint_id, lsb_id],
1668                        ))
1669                    }
1670                    Mf::CountLeadingZeros => {
1671                        let (int_type_id, int_id, width) = match *arg_ty {
1672                            crate::TypeInner::Vector { size, scalar } => {
1673                                let ty =
1674                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1675
1676                                self.temp_list.clear();
1677                                self.temp_list.resize(
1678                                    size as _,
1679                                    self.writer
1680                                        .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1681                                );
1682
1683                                (
1684                                    self.get_type_id(ty),
1685                                    self.writer.get_constant_composite(ty, &self.temp_list),
1686                                    scalar.width,
1687                                )
1688                            }
1689                            crate::TypeInner::Scalar(scalar) => (
1690                                self.get_numeric_type_id(NumericType::Scalar(scalar)),
1691                                self.writer
1692                                    .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1693                                scalar.width,
1694                            ),
1695                            _ => unreachable!(),
1696                        };
1697
1698                        if width != 4 {
1699                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1700                        };
1701
1702                        let msb_id = self.gen_id();
1703                        block.body.push(Instruction::ext_inst_gl_op(
1704                            self.writer.gl450_ext_inst_id,
1705                            if width != 4 {
1706                                spirv::GLOp::FindILsb
1707                            } else {
1708                                spirv::GLOp::FindUMsb
1709                            },
1710                            int_type_id,
1711                            msb_id,
1712                            &[arg0_id],
1713                        ));
1714
1715                        MathOp::Custom(Instruction::binary(
1716                            spirv::Op::ISub,
1717                            result_type_id,
1718                            id,
1719                            int_id,
1720                            msb_id,
1721                        ))
1722                    }
1723                    Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1724                        spirv::Op::BitCount,
1725                        result_type_id,
1726                        id,
1727                        arg0_id,
1728                    )),
1729                    Mf::ExtractBits => {
1730                        let op = match arg_scalar_kind {
1731                            Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1732                            Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1733                            other => unimplemented!("Unexpected sign({:?})", other),
1734                        };
1735
1736                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
1737                        // to first sanitize the offset and count first. If we don't do this, AMD and Intel
1738                        // will return out-of-spec values if the extracted range is not within the bit width.
1739                        //
1740                        // This encodes the exact formula specified by the wgsl spec:
1741                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1742                        //
1743                        // w = sizeof(x) * 8
1744                        // o = min(offset, w)
1745                        // tmp = w - o
1746                        // c = min(count, tmp)
1747                        //
1748                        // bitfieldExtract(x, o, c)
1749
1750                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1751                        let width_constant = self
1752                            .writer
1753                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1754
1755                        let u32_type =
1756                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1757
1758                        // o = min(offset, w)
1759                        let offset_id = self.gen_id();
1760                        block.body.push(Instruction::ext_inst_gl_op(
1761                            self.writer.gl450_ext_inst_id,
1762                            spirv::GLOp::UMin,
1763                            u32_type,
1764                            offset_id,
1765                            &[arg1_id, width_constant],
1766                        ));
1767
1768                        // tmp = w - o
1769                        let max_count_id = self.gen_id();
1770                        block.body.push(Instruction::binary(
1771                            spirv::Op::ISub,
1772                            u32_type,
1773                            max_count_id,
1774                            width_constant,
1775                            offset_id,
1776                        ));
1777
1778                        // c = min(count, tmp)
1779                        let count_id = self.gen_id();
1780                        block.body.push(Instruction::ext_inst_gl_op(
1781                            self.writer.gl450_ext_inst_id,
1782                            spirv::GLOp::UMin,
1783                            u32_type,
1784                            count_id,
1785                            &[arg2_id, max_count_id],
1786                        ));
1787
1788                        MathOp::Custom(Instruction::ternary(
1789                            op,
1790                            result_type_id,
1791                            id,
1792                            arg0_id,
1793                            offset_id,
1794                            count_id,
1795                        ))
1796                    }
1797                    Mf::InsertBits => {
1798                        // The behavior of InsertBits has the same undefined behavior as ExtractBits.
1799
1800                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1801                        let width_constant = self
1802                            .writer
1803                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1804
1805                        let u32_type =
1806                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1807
1808                        // o = min(offset, w)
1809                        let offset_id = self.gen_id();
1810                        block.body.push(Instruction::ext_inst_gl_op(
1811                            self.writer.gl450_ext_inst_id,
1812                            spirv::GLOp::UMin,
1813                            u32_type,
1814                            offset_id,
1815                            &[arg2_id, width_constant],
1816                        ));
1817
1818                        // tmp = w - o
1819                        let max_count_id = self.gen_id();
1820                        block.body.push(Instruction::binary(
1821                            spirv::Op::ISub,
1822                            u32_type,
1823                            max_count_id,
1824                            width_constant,
1825                            offset_id,
1826                        ));
1827
1828                        // c = min(count, tmp)
1829                        let count_id = self.gen_id();
1830                        block.body.push(Instruction::ext_inst_gl_op(
1831                            self.writer.gl450_ext_inst_id,
1832                            spirv::GLOp::UMin,
1833                            u32_type,
1834                            count_id,
1835                            &[arg3_id, max_count_id],
1836                        ));
1837
1838                        MathOp::Custom(Instruction::quaternary(
1839                            spirv::Op::BitFieldInsert,
1840                            result_type_id,
1841                            id,
1842                            arg0_id,
1843                            arg1_id,
1844                            offset_id,
1845                            count_id,
1846                        ))
1847                    }
1848                    Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1849                    Mf::FirstLeadingBit => {
1850                        if arg_ty.scalar_width() == Some(4) {
1851                            let thing = match arg_scalar_kind {
1852                                Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1853                                Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1854                                other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1855                            };
1856                            MathOp::Ext(thing)
1857                        } else {
1858                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1859                        }
1860                    }
1861                    Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1862                    Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1863                    Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1864                    Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1865                    Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1866                    fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1867                        let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1868                        let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1869
1870                        let last_instruction =
1871                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1872                                self.write_pack4x8_optimized(
1873                                    block,
1874                                    result_type_id,
1875                                    arg0_id,
1876                                    id,
1877                                    is_signed,
1878                                    should_clamp,
1879                                )
1880                            } else {
1881                                self.write_pack4x8_polyfill(
1882                                    block,
1883                                    result_type_id,
1884                                    arg0_id,
1885                                    id,
1886                                    is_signed,
1887                                    should_clamp,
1888                                )
1889                            };
1890
1891                        MathOp::Custom(last_instruction)
1892                    }
1893                    Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1894                    Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1895                    Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1896                    Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1897                    Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1898                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1899                        let is_signed = matches!(fun, Mf::Unpack4xI8);
1900
1901                        let last_instruction =
1902                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1903                                self.write_unpack4x8_optimized(
1904                                    block,
1905                                    result_type_id,
1906                                    arg0_id,
1907                                    id,
1908                                    is_signed,
1909                                )
1910                            } else {
1911                                self.write_unpack4x8_polyfill(
1912                                    block,
1913                                    result_type_id,
1914                                    arg0_id,
1915                                    id,
1916                                    is_signed,
1917                                )
1918                            };
1919
1920                        MathOp::Custom(last_instruction)
1921                    }
1922                };
1923
1924                block.body.push(match math_op {
1925                    MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1926                        self.writer.gl450_ext_inst_id,
1927                        op,
1928                        result_type_id,
1929                        id,
1930                        &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1931                    ),
1932                    MathOp::Custom(inst) => inst,
1933                });
1934                id
1935            }
1936            crate::Expression::LocalVariable(variable) => {
1937                if let Some(rq_tracker) = self
1938                    .function
1939                    .ray_query_initialization_tracker_variables
1940                    .get(&variable)
1941                {
1942                    self.ray_query_tracker_expr.insert(
1943                        expr_handle,
1944                        super::RayQueryTrackers {
1945                            initialized_tracker: rq_tracker.id,
1946                            t_max_tracker: self
1947                                .function
1948                                .ray_query_t_max_tracker_variables
1949                                .get(&variable)
1950                                .expect("Both trackers are set at the same time.")
1951                                .id,
1952                        },
1953                    );
1954                }
1955                self.function.variables[&variable].id
1956            }
1957            crate::Expression::Load { pointer } => {
1958                self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1959            }
1960            crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1961            crate::Expression::CallResult(_)
1962            | crate::Expression::AtomicResult { .. }
1963            | crate::Expression::WorkGroupUniformLoadResult { .. }
1964            | crate::Expression::RayQueryProceedResult
1965            | crate::Expression::SubgroupBallotResult
1966            | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1967            crate::Expression::As {
1968                expr,
1969                kind,
1970                convert,
1971            } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1972            crate::Expression::ImageLoad {
1973                image,
1974                coordinate,
1975                array_index,
1976                sample,
1977                level,
1978            } => self.write_image_load(
1979                result_type_id,
1980                image,
1981                coordinate,
1982                array_index,
1983                level,
1984                sample,
1985                block,
1986            )?,
1987            crate::Expression::ImageSample {
1988                image,
1989                sampler,
1990                gather,
1991                coordinate,
1992                array_index,
1993                offset,
1994                level,
1995                depth_ref,
1996                clamp_to_edge,
1997            } => self.write_image_sample(
1998                result_type_id,
1999                image,
2000                sampler,
2001                gather,
2002                coordinate,
2003                array_index,
2004                offset,
2005                level,
2006                depth_ref,
2007                clamp_to_edge,
2008                block,
2009            )?,
2010            crate::Expression::Select {
2011                condition,
2012                accept,
2013                reject,
2014            } => {
2015                let id = self.gen_id();
2016                let mut condition_id = self.cached[condition];
2017                let accept_id = self.cached[accept];
2018                let reject_id = self.cached[reject];
2019
2020                let condition_ty = self.fun_info[condition]
2021                    .ty
2022                    .inner_with(&self.ir_module.types);
2023                let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2024
2025                if let (
2026                    &crate::TypeInner::Scalar(
2027                        condition_scalar @ crate::Scalar {
2028                            kind: crate::ScalarKind::Bool,
2029                            ..
2030                        },
2031                    ),
2032                    &crate::TypeInner::Vector { size, .. },
2033                ) = (condition_ty, object_ty)
2034                {
2035                    self.temp_list.clear();
2036                    self.temp_list.resize(size as usize, condition_id);
2037
2038                    let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2039                        size,
2040                        scalar: condition_scalar,
2041                    });
2042
2043                    let id = self.gen_id();
2044                    block.body.push(Instruction::composite_construct(
2045                        bool_vector_type_id,
2046                        id,
2047                        &self.temp_list,
2048                    ));
2049                    condition_id = id
2050                }
2051
2052                let instruction =
2053                    Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2054                block.body.push(instruction);
2055                id
2056            }
2057            crate::Expression::Derivative { axis, ctrl, expr } => {
2058                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2059                match ctrl {
2060                    Ctrl::Coarse | Ctrl::Fine => {
2061                        self.writer.require_any(
2062                            "DerivativeControl",
2063                            &[spirv::Capability::DerivativeControl],
2064                        )?;
2065                    }
2066                    Ctrl::None => {}
2067                }
2068                let id = self.gen_id();
2069                let expr_id = self.cached[expr];
2070                let op = match (axis, ctrl) {
2071                    (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2072                    (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2073                    (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2074                    (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2075                    (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2076                    (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2077                    (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2078                    (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2079                    (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2080                };
2081                block
2082                    .body
2083                    .push(Instruction::derivative(op, result_type_id, id, expr_id));
2084                id
2085            }
2086            crate::Expression::ImageQuery { image, query } => {
2087                self.write_image_query(result_type_id, image, query, block)?
2088            }
2089            crate::Expression::Relational { fun, argument } => {
2090                use crate::RelationalFunction as Rf;
2091                let arg_id = self.cached[argument];
2092                let op = match fun {
2093                    Rf::All => spirv::Op::All,
2094                    Rf::Any => spirv::Op::Any,
2095                    Rf::IsNan => spirv::Op::IsNan,
2096                    Rf::IsInf => spirv::Op::IsInf,
2097                };
2098                let id = self.gen_id();
2099                block
2100                    .body
2101                    .push(Instruction::relational(op, result_type_id, id, arg_id));
2102                id
2103            }
2104            crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2105            crate::Expression::RayQueryGetIntersection { query, committed } => {
2106                let query_id = self.cached[query];
2107                let init_tracker_id = *self
2108                    .ray_query_tracker_expr
2109                    .get(&query)
2110                    .expect("not a cached ray query");
2111                let func_id = self
2112                    .writer
2113                    .write_ray_query_get_intersection_function(committed, self.ir_module);
2114                let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2115                let intersection_type_id = self.get_handle_type_id(ray_intersection);
2116                let id = self.gen_id();
2117                block.body.push(Instruction::function_call(
2118                    intersection_type_id,
2119                    id,
2120                    func_id,
2121                    &[query_id, init_tracker_id.initialized_tracker],
2122                ));
2123                id
2124            }
2125            crate::Expression::RayQueryVertexPositions { query, committed } => {
2126                self.writer.require_any(
2127                    "RayQueryVertexPositions",
2128                    &[spirv::Capability::RayQueryPositionFetchKHR],
2129                )?;
2130                self.write_ray_query_return_vertex_position(query, block, committed)
2131            }
2132            crate::Expression::CooperativeLoad { ref data, .. } => {
2133                self.writer.require_any(
2134                    "CooperativeMatrix",
2135                    &[spirv::Capability::CooperativeMatrixKHR],
2136                )?;
2137                let layout = if data.row_major {
2138                    spirv::CooperativeMatrixLayout::RowMajorKHR
2139                } else {
2140                    spirv::CooperativeMatrixLayout::ColumnMajorKHR
2141                };
2142                let layout_id = self.get_index_constant(layout as u32);
2143                let stride_id = self.cached[data.stride];
2144                match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2145                    ExpressionPointer::Ready { pointer_id } => {
2146                        let id = self.gen_id();
2147                        block.body.push(Instruction::coop_load(
2148                            result_type_id,
2149                            id,
2150                            pointer_id,
2151                            layout_id,
2152                            stride_id,
2153                        ));
2154                        id
2155                    }
2156                    ExpressionPointer::Conditional { condition, access } => self
2157                        .write_conditional_indexed_load(
2158                            result_type_id,
2159                            condition,
2160                            block,
2161                            |id_gen, block| {
2162                                let pointer_id = access.result_id.unwrap();
2163                                block.body.push(access);
2164                                let id = id_gen.next();
2165                                block.body.push(Instruction::coop_load(
2166                                    result_type_id,
2167                                    id,
2168                                    pointer_id,
2169                                    layout_id,
2170                                    stride_id,
2171                                ));
2172                                id
2173                            },
2174                        ),
2175                }
2176            }
2177            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2178                self.writer.require_any(
2179                    "CooperativeMatrix",
2180                    &[spirv::Capability::CooperativeMatrixKHR],
2181                )?;
2182                let a_id = self.cached[a];
2183                let b_id = self.cached[b];
2184                let c_id = self.cached[c];
2185                let id = self.gen_id();
2186                block.body.push(Instruction::coop_mul_add(
2187                    result_type_id,
2188                    id,
2189                    a_id,
2190                    b_id,
2191                    c_id,
2192                ));
2193                id
2194            }
2195        };
2196
2197        self.cached[expr_handle] = id;
2198        Ok(())
2199    }
2200
    /// Helper which focuses on generating the `As` expressions and the various conversions
    /// that need to happen because of that.
    ///
    /// `expr` is the Naga IR expression being cast (already cached), `kind` is
    /// the destination scalar kind, and `convert` is the destination width in
    /// bytes — `None` requests a bit reinterpretation rather than a value
    /// conversion. Any instructions needed are appended to `block`, and the id
    /// of the resulting value (whose SPIR-V type is `result_type_id`) is
    /// returned.
    fn write_as_expression(
        &mut self,
        expr: Handle<crate::Expression>,
        convert: Option<u8>,
        kind: crate::ScalarKind,

        block: &mut Block,
        result_type_id: u32,
    ) -> Result<u32, Error> {
        use crate::ScalarKind as Sk;
        let expr_id = self.cached[expr];
        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

        // Matrix casts needs special treatment in SPIR-V, as the cast functions
        // can take vectors or scalars, but not matrices. In order to cast a matrix
        // we need to cast each column of the matrix individually and construct a new
        // matrix from the converted columns.
        if let crate::TypeInner::Matrix {
            columns,
            rows,
            scalar,
        } = *ty
        {
            let Some(convert) = convert else {
                // No conversion needs to be done, passes through.
                return Ok(expr_id);
            };

            if convert == scalar.width {
                // No conversion needs to be done, passes through.
                return Ok(expr_id);
            }

            if kind != Sk::Float {
                // Only float conversions are supported for matrices.
                return Err(Error::Validation("Matrices must be floats"));
            }

            // Type of each extracted column
            let column_src_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar,
                })));

            // Type of the column after conversion
            let column_dst_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar: crate::Scalar {
                        kind,
                        width: convert,
                    },
                })));

            // Converted column ids; a matrix has at most 4 columns, so no heap
            // allocation is needed.
            let mut components = ArrayVec::<Word, 4>::new();

            // Extract each column, convert it with `OpFConvert`, and collect
            // the converted column ids for the final construct.
            for column in 0..columns as usize {
                let column_id = self.gen_id();
                block.body.push(Instruction::composite_extract(
                    column_src_ty,
                    column_id,
                    expr_id,
                    &[column as u32],
                ));

                let column_conv_id = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::FConvert,
                    column_dst_ty,
                    column_conv_id,
                    column_id,
                ));

                components.push(column_conv_id);
            }

            // Reassemble the converted columns into the destination matrix.
            let construct_id = self.gen_id();

            block.body.push(Instruction::composite_construct(
                result_type_id,
                construct_id,
                &components,
            ));

            return Ok(construct_id);
        }

        // Non-matrix sources must be scalars or vectors; record the source
        // scalar type and the vector size (if any) for the cast logic below.
        let (src_scalar, src_size) = match *ty {
            crate::TypeInner::Scalar(scalar) => (scalar, None),
            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
            ref other => {
                log::error!("As source {other:?}");
                return Err(Error::Validation("Unexpected Expression::As source"));
            }
        };

        // Which instruction shape to emit for the cast, along with the operand
        // ids it takes. `Identity` emits nothing and reuses the given id.
        enum Cast {
            Identity(Word),
            Unary(spirv::Op, Word),
            Binary(spirv::Op, Word, Word),
            Ternary(spirv::Op, Word, Word, Word),
        }
        let cast = match (src_scalar.kind, kind, convert) {
            // Filter out identity casts. Some Adreno drivers are
            // confused by no-op OpBitCast instructions.
            //
            // This arm fires when the kinds match and `convert` is either
            // `None` or equal to the source width, i.e. nothing changes.
            (src_kind, kind, convert)
                if src_kind == kind
                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
            {
                Cast::Identity(expr_id)
            }
            // bool-to-bool with a width request: booleans have no width in
            // SPIR-V, so just copy the value.
            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
            // No destination width given: reinterpret the bits.
            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
            // casting to a bool - generate `OpXxxNotEqual`
            (_, Sk::Bool, Some(_)) => {
                let op = match src_scalar.kind {
                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
                    Sk::Float => spirv::Op::FUnordNotEqual,
                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
                };
                // Compare against zero of the source type (splatted to a
                // vector constant when the source is a vector).
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
                let zero_id = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: src_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        self.writer.get_constant_composite(ty, &self.temp_list)
                    }
                    None => zero_scalar_id,
                };

                Cast::Binary(op, expr_id, zero_id)
            }
            // casting from a bool - generate `OpSelect`
            (Sk::Bool, _, Some(dst_width)) => {
                let dst_scalar = crate::Scalar {
                    kind,
                    width: dst_width,
                };
                // Select between one (true) and zero (false) constants of the
                // destination type, splatted to vectors when needed.
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
                let (accept_id, reject_id) = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: dst_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        self.temp_list.fill(one_scalar_id);

                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        (vec1_id, vec0_id)
                    }
                    None => (one_scalar_id, zero_scalar_id),
                };

                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
            }
            // Avoid undefined behaviour when casting from a float to integer
            // when the value is out of range for the target type. Additionally
            // ensure we clamp to the correct value as per the WGSL spec.
            //
            // https://www.w3.org/TR/WGSL/#floating-point-conversion:
            // * If X is exactly representable in the target type T, then the
            //   result is that value.
            // * Otherwise, the result is the value in T closest to
            //   truncate(X) and also exactly representable in the original
            //   floating point type.
            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
                let dst_scalar = crate::Scalar { kind, width };
                let (min, max) =
                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);

                // Splat a scalar constant into a vector constant of the source
                // type when the source is a vector; otherwise pass it through.
                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
                    None => const_id,
                    Some(size) => {
                        let constituent_ids = [const_id; crate::VectorSize::MAX];
                        writer.get_constant_composite(
                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                                size,
                                scalar: src_scalar,
                            })),
                            &constituent_ids[..size as usize],
                        )
                    }
                };
                let min_const_id = self.writer.get_constant_scalar(min);
                let min_const_id = maybe_splat_const(self.writer, min_const_id);
                let max_const_id = self.writer.get_constant_scalar(max);
                let max_const_id = maybe_splat_const(self.writer, max_const_id);

                // Clamp into the representable range first, then convert.
                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.writer.gl450_ext_inst_id,
                    spirv::GLOp::FClamp,
                    expr_type_id,
                    clamp_id,
                    &[expr_id, min_const_id, max_const_id],
                ));

                let op = match dst_scalar.kind {
                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
                    _ => unreachable!(),
                };
                Cast::Unary(op, clamp_id)
            }
            // Width-changing conversions. Sign-changing casts that keep the
            // same width are not matched here (the guards require differing
            // widths) and fall through to the bitcast below.
            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::FConvert, expr_id)
            }
            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            // We assume it's either an identity cast, or int-uint.
            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
        };
        // Emit the chosen instruction (if any) and return the result id.
        Ok(match cast {
            Cast::Identity(expr) => expr,
            Cast::Unary(op, op1) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::unary(op, result_type_id, id, op1));
                id
            }
            Cast::Binary(op, op1, op2) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
                id
            }
            Cast::Ternary(op, op1, op2, op3) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
                id
            }
        })
    }
2470
2471    /// Build an `OpAccessChain` instruction.
2472    ///
2473    /// Emit any needed bounds-checking expressions to `block`.
2474    ///
2475    /// Give the `OpAccessChain` a result type based on `expr_handle`, adjusted
2476    /// according to `type_adjustment`; see the documentation for
2477    /// [`AccessTypeAdjustment`] for details.
2478    ///
2479    /// On success, the return value is an [`ExpressionPointer`] value; see the
2480    /// documentation for that type.
2481    fn write_access_chain(
2482        &mut self,
2483        mut expr_handle: Handle<crate::Expression>,
2484        block: &mut Block,
2485        type_adjustment: AccessTypeAdjustment,
2486    ) -> Result<ExpressionPointer, Error> {
2487        let result_type_id = {
2488            let resolution = &self.fun_info[expr_handle].ty;
2489            match type_adjustment {
2490                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
2491                AccessTypeAdjustment::IntroducePointer(class) => {
2492                    self.writer.get_resolution_pointer_id(resolution, class)
2493                }
2494                AccessTypeAdjustment::UseStd140CompatType => {
2495                    match *resolution.inner_with(&self.ir_module.types) {
2496                        crate::TypeInner::Pointer {
2497                            base,
2498                            space: space @ crate::AddressSpace::Uniform,
2499                        } => self.writer.get_pointer_type_id(
2500                            self.writer.std140_compat_uniform_types[&base].type_id,
2501                            map_storage_class(space),
2502                        ),
2503                        _ => unreachable!(
2504                            "`UseStd140CompatType` must only be used with uniform pointer types"
2505                        ),
2506                    }
2507                }
2508            }
2509        };
2510
2511        // The id of the boolean `and` of all dynamic bounds checks up to this point.
2512        //
2513        // See `extend_bounds_check_condition_chain` for a full explanation.
2514        let mut accumulated_checks = None;
2515
2516        // Is true if we are accessing into a binding array with a non-uniform index.
2517        let mut is_non_uniform_binding_array = false;
2518
2519        // The index value if the previously encountered expression was an
2520        // `AccessIndex` of a matrix which has been decomposed into individual
2521        // column vectors directly in the containing struct. The subsequent
2522        // iteration will append the correct index to the list for accessing
2523        // said column from the containing struct.
2524        let mut prev_decomposed_matrix_index = None;
2525
2526        self.temp_list.clear();
2527        let root_id = loop {
2528            // If `expr_handle` was spilled, then the temporary variable has exactly
2529            // the value we want to start from.
2530            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
2531                // The root id of the `OpAccessChain` instruction is the temporary
2532                // variable we spilled the composite to.
2533                break spilled.id;
2534            }
2535
2536            expr_handle = match self.ir_function.expressions[expr_handle] {
2537                crate::Expression::Access { base, index } => {
2538                    is_non_uniform_binding_array |=
2539                        self.is_nonuniform_binding_array_access(base, index);
2540
2541                    let index = GuardedIndex::Expression(index);
2542                    let index_id =
2543                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
2544                    self.temp_list.push(index_id);
2545
2546                    base
2547                }
2548                crate::Expression::AccessIndex { base, index } => {
2549                    // Decide whether we're indexing a struct (bounds checks
2550                    // forbidden) or anything else (bounds checks required).
2551                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
2552                    let mut base_ty_handle = self.fun_info[base].ty.handle();
2553                    let mut pointer_space = None;
2554                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
2555                        base_ty = &self.ir_module.types[base].inner;
2556                        base_ty_handle = Some(base);
2557                        pointer_space = Some(space);
2558                    }
2559                    match *base_ty {
2560                        // When indexing a struct bounds checks are forbidden. If accessing the
2561                        // struct through a uniform address space pointer, where the struct has
2562                        // been declared with an alternative std140 compatible layout, we must use
2563                        // the remapped member index. Additionally if the previous iteration was
2564                        // accessing a column of a matrix member which has been decomposed directly
2565                        // into the struct, we must ensure we access the correct column.
2566                        crate::TypeInner::Struct { .. } => {
2567                            let index = match base_ty_handle.and_then(|handle| {
2568                                self.writer.std140_compat_uniform_types.get(&handle)
2569                            }) {
2570                                Some(std140_type_info)
2571                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
2572                                {
2573                                    std140_type_info.member_indices[index as usize]
2574                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
2575                                }
2576                                _ => index,
2577                            };
2578                            let index_id = self.get_index_constant(index);
2579                            self.temp_list.push(index_id);
2580                        }
2581                        // Bounds checks are not required when indexing a matrix. If indexing a
2582                        // two-row matrix contained within a struct through a uniform address space
2583                        // pointer then the matrix' columns will have been decomposed directly into
2584                        // the containing struct. We skip adding an index to the list on this
2585                        // iteration and instead adjust the index on the next iteration when
2586                        // accessing the struct member.
2587                        _ if is_uniform_matcx2_struct_member_access(
2588                            self.ir_function,
2589                            self.fun_info,
2590                            self.ir_module,
2591                            base,
2592                        ) =>
2593                        {
2594                            assert!(prev_decomposed_matrix_index.is_none());
2595                            prev_decomposed_matrix_index = Some(index);
2596                        }
2597                        _ => {
2598                            // `index` is constant, so this can't possibly require
2599                            // setting `is_nonuniform_binding_array_access`.
2600
2601                            // Even though the index value is statically known, `base`
2602                            // may be a runtime-sized array, so we still need to go
2603                            // through the bounds check process.
2604                            let index_id = self.write_access_chain_index(
2605                                base,
2606                                GuardedIndex::Known(index),
2607                                &mut accumulated_checks,
2608                                block,
2609                            )?;
2610                            self.temp_list.push(index_id);
2611                        }
2612                    }
2613                    base
2614                }
2615                crate::Expression::GlobalVariable(handle) => {
2616                    let gv = &self.writer.global_variables[handle];
2617                    break gv.access_id;
2618                }
2619                crate::Expression::LocalVariable(variable) => {
2620                    let local_var = &self.function.variables[&variable];
2621                    break local_var.id;
2622                }
2623                crate::Expression::FunctionArgument(index) => {
2624                    break self.function.parameter_id(index);
2625                }
2626                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
2627            }
2628        };
2629
2630        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
2631            (
2632                root_id,
2633                ExpressionPointer::Ready {
2634                    pointer_id: root_id,
2635                },
2636            )
2637        } else {
2638            self.temp_list.reverse();
2639            let pointer_id = self.gen_id();
2640            let access =
2641                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
2642
2643            // If we generated some bounds checks, we need to leave it to our
2644            // caller to generate the branch, the access, the load or store, and
2645            // the zero value (for loads). Otherwise, we can emit the access
2646            // ourselves, and just hand them the id of the pointer.
2647            let expr_pointer = match accumulated_checks {
2648                Some(condition) => ExpressionPointer::Conditional { condition, access },
2649                None => {
2650                    block.body.push(access);
2651                    ExpressionPointer::Ready { pointer_id }
2652                }
2653            };
2654            (pointer_id, expr_pointer)
2655        };
2656        // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform
2657        // if the binding array was accessed with a non-uniform index
2658        // see VUID-RuntimeSpirv-NonUniform-06274
2659        if is_non_uniform_binding_array {
2660            self.writer
2661                .decorate_non_uniform_binding_array_access(pointer_id)?;
2662        }
2663
2664        Ok(expr_pointer)
2665    }
2666
2667    fn is_nonuniform_binding_array_access(
2668        &mut self,
2669        base: Handle<crate::Expression>,
2670        index: Handle<crate::Expression>,
2671    ) -> bool {
2672        let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2673        else {
2674            return false;
2675        };
2676
2677        // The access chain needs to be decorated as NonUniform
2678        // see VUID-RuntimeSpirv-NonUniform-06274
2679        let gvar = &self.ir_module.global_variables[var_handle];
2680        let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2681            return false;
2682        };
2683
2684        self.fun_info[index].uniformity.non_uniform_result.is_some()
2685    }
2686
2687    /// Compute a single index operand to an `OpAccessChain` instruction.
2688    ///
2689    /// Given that we are indexing `base` with `index`, apply the appropriate
2690    /// bounds check policies, emitting code to `block` to clamp `index` or
2691    /// determine whether it's in bounds. Return the SPIR-V instruction id of
2692    /// the index value we should actually use.
2693    ///
2694    /// Extend `accumulated_checks` to include the results of any needed bounds
2695    /// checks. See [`BlockContext::extend_bounds_check_condition_chain`].
2696    fn write_access_chain_index(
2697        &mut self,
2698        base: Handle<crate::Expression>,
2699        index: GuardedIndex,
2700        accumulated_checks: &mut Option<Word>,
2701        block: &mut Block,
2702    ) -> Result<Word, Error> {
2703        match self.write_bounds_check(base, index, block)? {
2704            BoundsCheckResult::KnownInBounds(known_index) => {
2705                // Even if the index is known, `OpAccessChain`
2706                // requires expression operands, not literals.
2707                let scalar = crate::Literal::U32(known_index);
2708                Ok(self.writer.get_constant_scalar(scalar))
2709            }
2710            BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2711            BoundsCheckResult::Conditional {
2712                condition_id: condition,
2713                index_id: index,
2714            } => {
2715                self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2716
2717                // Use the index from the `Access` expression unchanged.
2718                Ok(index)
2719            }
2720        }
2721    }
2722
2723    /// Add a condition to a chain of bounds checks.
2724    ///
2725    /// As we build an `OpAccessChain` instruction govered by
2726    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`], we accumulate a chain of
2727    /// dynamic bounds checks, one for each index in the chain, which must all
2728    /// be true for that `OpAccessChain`'s execution to be well-defined. This
2729    /// function adds the boolean instruction id `comparison_id` to `chain`.
2730    ///
2731    /// If `chain` is `None`, that means there are no bounds checks in the chain
2732    /// yet. If chain is `Some(id)`, then `id` is the conjunction of all the
2733    /// bounds checks in the chain.
2734    ///
2735    /// When we have multiple bounds checks, we combine them with
2736    /// `OpLogicalAnd`, not a short-circuit branch. This means we might do
2737    /// comparisons we don't need to, but we expect these checks to almost
2738    /// always succeed, and keeping branches to a minimum is essential.
2739    ///
2740    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`]: crate::proc::BoundsCheckPolicy
2741    fn extend_bounds_check_condition_chain(
2742        &mut self,
2743        chain: &mut Option<Word>,
2744        comparison_id: Word,
2745        block: &mut Block,
2746    ) {
2747        match *chain {
2748            Some(ref mut prior_checks) => {
2749                let combined = self.gen_id();
2750                block.body.push(Instruction::binary(
2751                    spirv::Op::LogicalAnd,
2752                    self.writer.get_bool_type_id(),
2753                    combined,
2754                    *prior_checks,
2755                    comparison_id,
2756                ));
2757                *prior_checks = combined;
2758            }
2759            None => {
2760                // Start a fresh chain of checks.
2761                *chain = Some(comparison_id);
2762            }
2763        }
2764    }
2765
    /// Emit a bounds-checked load through `pointer`, returning the id of the
    /// loaded value.
    ///
    /// Handles several special cases before doing a plain load:
    /// - dynamic accesses into decomposed uniform `matCx2` matrices, and
    ///   loads of struct members that are such matrices, are delegated to
    ///   dedicated helpers;
    /// - uniform-space pointers to types with a std140-compatible variant are
    ///   loaded as that variant and then converted back via a wrapped
    ///   function call;
    /// - pointers to `Atomic` types are loaded with `OpAtomicLoad`;
    /// - access chains that produced a bounds-check condition are loaded via
    ///   [`Self::write_conditional_indexed_load`].
    fn write_checked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
        access_type_adjustment: AccessTypeAdjustment,
        result_type_id: Word,
    ) -> Result<Word, Error> {
        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
            Ok(result_id)
        } else if let Some(result_id) =
            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
        {
            Ok(result_id)
        } else {
            // If `pointer` refers to a uniform address space pointer to a type
            // which was declared using a std140 compatible type variant (i.e.
            // is a two-row matrix, or a struct or array containing such a
            // matrix) we must ensure the access chain and the type of the load
            // instruction use the std140 compatible type variant.
            struct WrappedLoad {
                access_type_adjustment: AccessTypeAdjustment,
                r#type: Handle<crate::Type>,
            }
            let mut wrapped_load = None;
            if let crate::TypeInner::Pointer {
                base: pointer_base_type,
                space: crate::AddressSpace::Uniform,
            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
            {
                if self
                    .writer
                    .std140_compat_uniform_types
                    .contains_key(&pointer_base_type)
                {
                    wrapped_load = Some(WrappedLoad {
                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
                        r#type: pointer_base_type,
                    });
                };
            };

            // For a wrapped load, both the access chain's result type and the
            // load's result type must be the std140-compatible variant.
            let (load_type_id, access_type_adjustment) = match wrapped_load {
                Some(ref wrapped_load) => (
                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
                    wrapped_load.access_type_adjustment,
                ),
                None => (result_type_id, access_type_adjustment),
            };

            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
                ExpressionPointer::Ready { pointer_id } => {
                    let id = self.gen_id();
                    // Determine whether the pointee is an `Atomic`, in which
                    // case we must use `OpAtomicLoad` rather than `OpLoad`.
                    let atomic_space =
                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
                            crate::TypeInner::Pointer { base, space } => {
                                match self.ir_module.types[base].inner {
                                    crate::TypeInner::Atomic { .. } => Some(space),
                                    _ => None,
                                }
                            }
                            _ => None,
                        };
                    let instruction = if let Some(space) = atomic_space {
                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
                        let scope_constant_id = self.get_scope_constant(scope as u32);
                        let semantics_id = self.get_index_constant(semantics.bits());
                        Instruction::atomic_load(
                            result_type_id,
                            id,
                            pointer_id,
                            scope_constant_id,
                            semantics_id,
                        )
                    } else {
                        Instruction::load(load_type_id, id, pointer_id, None)
                    };
                    block.body.push(instruction);
                    id
                }
                ExpressionPointer::Conditional { condition, access } => {
                    //TODO: support atomics?
                    self.write_conditional_indexed_load(
                        load_type_id,
                        condition,
                        block,
                        move |id_gen, block| {
                            // The in-bounds path. Perform the access and the load.
                            let pointer_id = access.result_id.unwrap();
                            let value_id = id_gen.next();
                            block.body.push(access);
                            block.body.push(Instruction::load(
                                load_type_id,
                                value_id,
                                pointer_id,
                                None,
                            ));
                            value_id
                        },
                    )
                }
            };

            match wrapped_load {
                Some(ref wrapped_load) => {
                    // If we loaded a std140 compat type then we must call the
                    // function to convert the loaded value to the regular type.
                    let result_id = self.gen_id();
                    let function_id = self.writer.wrapped_functions
                        [&WrappedFunction::ConvertFromStd140CompatType {
                            r#type: wrapped_load.r#type,
                        }];
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        result_id,
                        function_id,
                        &[load_id],
                    ));
                    Ok(result_id)
                }
                None => Ok(load_id),
            }
        }
    }
2889
2890    fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2891        use indexmap::map::Entry;
2892
2893        // Make sure we have an internal variable to spill `base` to.
2894        let spill_variable_id = match self.function.spilled_composites.entry(base) {
2895            Entry::Occupied(preexisting) => preexisting.get().id,
2896            Entry::Vacant(vacant) => {
2897                // Generate a new internal variable of the appropriate
2898                // type for `base`.
2899                let pointer_type_id = self.writer.get_resolution_pointer_id(
2900                    &self.fun_info[base].ty,
2901                    spirv::StorageClass::Function,
2902                );
2903                let id = self.writer.id_gen.next();
2904                vacant.insert(super::LocalVariable {
2905                    id,
2906                    instruction: Instruction::variable(
2907                        pointer_type_id,
2908                        id,
2909                        spirv::StorageClass::Function,
2910                        None,
2911                    ),
2912                });
2913                id
2914            }
2915        };
2916
2917        // Perform the store even if we already had a spill variable for `base`.
2918        // Consider this code:
2919        //
2920        // var x = ...;
2921        // var y = ...;
2922        // var z = ...;
2923        // for (i = 0; i<2; i++) {
2924        //     let a = array(i, i, i);
2925        //     if (i == 0) {
2926        //         x += a[y];
2927        //     } else [
2928        //         x += a[z];
2929        //     }
2930        // }
2931        //
2932        // The value of `a` needs to be spilled so we can subscript it with `y` and `z`.
2933        //
2934        // When we generate SPIR-V for `a[y]`, we will create the spill
2935        // variable, and store `a`'s value in it.
2936        //
2937        // When we generate SPIR-V for `a[z]`, we will notice that the spill
2938        // variable for `a` has already been declared, but it is still essential
2939        // that we store `a` into it, so that `a[z]` sees this iteration's value
2940        // of `a`.
2941        let base_id = self.cached[base];
2942        block
2943            .body
2944            .push(Instruction::store(spill_variable_id, base_id, None));
2945    }
2946
2947    /// Generate an access to a spilled temporary, if necessary.
2948    ///
2949    /// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers
2950    /// to a component of a composite value that has been spilled to a temporary
2951    /// variable, determine whether other expressions are going to use
2952    /// `access`'s value:
2953    ///
2954    /// - If so, perform the access and cache that as the value of `access`.
2955    ///
2956    /// - Otherwise, generate no code and cache no value for `access`.
2957    ///
2958    /// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded it into
2959    /// the instruction given by `id`.
2960    ///
2961    /// [`Access`]: crate::Expression::Access
2962    /// [`AccessIndex`]: crate::Expression::AccessIndex
2963    fn maybe_access_spilled_composite(
2964        &mut self,
2965        access: Handle<crate::Expression>,
2966        block: &mut Block,
2967        result_type_id: Word,
2968    ) -> Result<Word, Error> {
2969        let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2970        if access_uses == self.fun_info[access].ref_count {
2971            // This expression is only used by other `Access` and
2972            // `AccessIndex` expressions, so we don't need to cache a
2973            // value for it yet.
2974            Ok(0)
2975        } else {
2976            // There are other expressions that are going to expect this
2977            // expression's value to be cached, not just other `Access` or
2978            // `AccessIndex` expressions. We must actually perform the
2979            // access on the spill variable now.
2980            self.write_checked_load(
2981                access,
2982                block,
2983                AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2984                result_type_id,
2985            )
2986        }
2987    }
2988
2989    /// Build the instructions for matrix - matrix column operations
2990    #[allow(clippy::too_many_arguments)]
2991    fn write_matrix_matrix_column_op(
2992        &mut self,
2993        block: &mut Block,
2994        result_id: Word,
2995        result_type_id: Word,
2996        left_id: Word,
2997        right_id: Word,
2998        columns: crate::VectorSize,
2999        rows: crate::VectorSize,
3000        width: u8,
3001        op: spirv::Op,
3002    ) {
3003        self.temp_list.clear();
3004
3005        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3006            size: rows,
3007            scalar: crate::Scalar::float(width),
3008        });
3009
3010        for index in 0..columns as u32 {
3011            let column_id_left = self.gen_id();
3012            let column_id_right = self.gen_id();
3013            let column_id_res = self.gen_id();
3014
3015            block.body.push(Instruction::composite_extract(
3016                vector_type_id,
3017                column_id_left,
3018                left_id,
3019                &[index],
3020            ));
3021            block.body.push(Instruction::composite_extract(
3022                vector_type_id,
3023                column_id_right,
3024                right_id,
3025                &[index],
3026            ));
3027            block.body.push(Instruction::binary(
3028                op,
3029                vector_type_id,
3030                column_id_res,
3031                column_id_left,
3032                column_id_right,
3033            ));
3034
3035            self.temp_list.push(column_id_res);
3036        }
3037
3038        block.body.push(Instruction::composite_construct(
3039            result_type_id,
3040            result_id,
3041            &self.temp_list,
3042        ));
3043    }
3044
3045    /// Build the instructions for vector - scalar multiplication
3046    fn write_vector_scalar_mult(
3047        &mut self,
3048        block: &mut Block,
3049        result_id: Word,
3050        result_type_id: Word,
3051        vector_id: Word,
3052        scalar_id: Word,
3053        vector: &crate::TypeInner,
3054    ) {
3055        let (size, kind) = match *vector {
3056            crate::TypeInner::Vector {
3057                size,
3058                scalar: crate::Scalar { kind, .. },
3059            } => (size, kind),
3060            _ => unreachable!(),
3061        };
3062
3063        let (op, operand_id) = match kind {
3064            crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3065            _ => {
3066                let operand_id = self.gen_id();
3067                self.temp_list.clear();
3068                self.temp_list.resize(size as usize, scalar_id);
3069                block.body.push(Instruction::composite_construct(
3070                    result_type_id,
3071                    operand_id,
3072                    &self.temp_list,
3073                ));
3074                (spirv::Op::IMul, operand_id)
3075            }
3076        };
3077
3078        block.body.push(Instruction::binary(
3079            op,
3080            result_type_id,
3081            result_id,
3082            vector_id,
3083            operand_id,
3084        ));
3085    }
3086
3087    /// Build the instructions for the arithmetic expression of a dot product
3088    ///
3089    /// The argument `extractor` is a function that maps `(result_id,
3090    /// composite_id, index)` to an instruction that extracts the `index`th
3091    /// entry of the value with ID `composite_id` and assigns it to the slot
3092    /// with id `result_id` (which must have type `result_type_id`).
3093    #[expect(clippy::too_many_arguments)]
3094    fn write_dot_product(
3095        &mut self,
3096        result_id: Word,
3097        result_type_id: Word,
3098        arg0_id: Word,
3099        arg1_id: Word,
3100        size: u32,
3101        block: &mut Block,
3102        extractor: impl Fn(Word, Word, Word) -> Instruction,
3103    ) {
3104        let mut partial_sum = self.writer.get_constant_null(result_type_id);
3105        let last_component = size - 1;
3106        for index in 0..=last_component {
3107            // compute the product of the current components
3108            let a_id = self.gen_id();
3109            block.body.push(extractor(a_id, arg0_id, index));
3110            let b_id = self.gen_id();
3111            block.body.push(extractor(b_id, arg1_id, index));
3112            let prod_id = self.gen_id();
3113            block.body.push(Instruction::binary(
3114                spirv::Op::IMul,
3115                result_type_id,
3116                prod_id,
3117                a_id,
3118                b_id,
3119            ));
3120
3121            // choose the id for the next sum, depending on current index
3122            let id = if index == last_component {
3123                result_id
3124            } else {
3125                self.gen_id()
3126            };
3127
3128            // sum the computed product with the partial sum
3129            block.body.push(Instruction::binary(
3130                spirv::Op::IAdd,
3131                result_type_id,
3132                id,
3133                partial_sum,
3134                prod_id,
3135            ));
3136            // set the id of the result as the previous partial sum
3137            partial_sum = id;
3138        }
3139    }
3140
3141    /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is available.
3142    fn write_pack4x8_optimized(
3143        &mut self,
3144        block: &mut Block,
3145        result_type_id: u32,
3146        arg0_id: u32,
3147        id: u32,
3148        is_signed: bool,
3149        should_clamp: bool,
3150    ) -> Instruction {
3151        let int_type = if is_signed {
3152            crate::ScalarKind::Sint
3153        } else {
3154            crate::ScalarKind::Uint
3155        };
3156        let wide_vector_type = NumericType::Vector {
3157            size: crate::VectorSize::Quad,
3158            scalar: crate::Scalar {
3159                kind: int_type,
3160                width: 4,
3161            },
3162        };
3163        let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3164        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3165            size: crate::VectorSize::Quad,
3166            scalar: crate::Scalar {
3167                kind: crate::ScalarKind::Uint,
3168                width: 1,
3169            },
3170        });
3171
3172        let mut wide_vector = arg0_id;
3173        if should_clamp {
3174            let (min, max, clamp_op) = if is_signed {
3175                (
3176                    crate::Literal::I32(-128),
3177                    crate::Literal::I32(127),
3178                    spirv::GLOp::SClamp,
3179                )
3180            } else {
3181                (
3182                    crate::Literal::U32(0),
3183                    crate::Literal::U32(255),
3184                    spirv::GLOp::UClamp,
3185                )
3186            };
3187            let [min, max] = [min, max].map(|lit| {
3188                let scalar = self.writer.get_constant_scalar(lit);
3189                self.writer.get_constant_composite(
3190                    LookupType::Local(LocalType::Numeric(wide_vector_type)),
3191                    &[scalar; 4],
3192                )
3193            });
3194
3195            let clamp_id = self.gen_id();
3196            block.body.push(Instruction::ext_inst_gl_op(
3197                self.writer.gl450_ext_inst_id,
3198                clamp_op,
3199                wide_vector_type_id,
3200                clamp_id,
3201                &[wide_vector, min, max],
3202            ));
3203
3204            wide_vector = clamp_id;
3205        }
3206
3207        let packed_vector = self.gen_id();
3208        block.body.push(Instruction::unary(
3209            spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically.
3210            packed_vector_type_id,
3211            packed_vector,
3212            wide_vector,
3213        ));
3214
3215        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
3216        // and a scalar precisely as required by the WGSL spec [2].
3217        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
3218        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
3219        Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3220    }
3221
    /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is not available.
    ///
    /// Packs the four components of the 32-bit integer vector `arg0_id` into
    /// the low bytes of a single 32-bit result using `OpBitFieldInsert`, one
    /// byte per component, optionally clamping each component to the 8-bit
    /// range first. All but the final instruction are pushed onto `block`;
    /// the final `OpBitFieldInsert`, which writes `id`, is returned to the
    /// caller rather than pushed.
    fn write_pack4x8_polyfill(
        &mut self,
        block: &mut Block,
        result_type_id: u32,
        arg0_id: u32,
        id: u32,
        is_signed: bool,
        should_clamp: bool,
    ) -> Instruction {
        let int_type = if is_signed {
            crate::ScalarKind::Sint
        } else {
            crate::ScalarKind::Uint
        };
        let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
            kind: int_type,
            width: 4,
        }));

        // Placeholder; always overwritten by the last loop iteration below.
        let mut last_instruction = Instruction::new(spirv::Op::Nop);

        let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
        // Accumulates the partially packed word across iterations.
        let mut preresult = zero;
        block
            .body
            .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));

        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
        const VEC_LENGTH: u8 = 4;
        for i in 0..u32::from(VEC_LENGTH) {
            // Bit offset of this component's byte in the packed word.
            let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
            let mut extracted = self.gen_id();
            // NOTE(review): `binary` is used to encode `OpCompositeExtract`
            // with `i` as its literal index operand.
            block.body.push(Instruction::binary(
                spirv::Op::CompositeExtract,
                int_type_id,
                extracted,
                arg0_id,
                i,
            ));
            if is_signed {
                // Reinterpret the signed component as unsigned bits so the
                // bit-field insertion below operates on a uint.
                let casted = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::Bitcast,
                    uint_type_id,
                    casted,
                    extracted,
                ));
                extracted = casted;
            }
            if should_clamp {
                // Clamp the component into the representable 8-bit range.
                let (min, max, clamp_op) = if is_signed {
                    (
                        crate::Literal::I32(-128),
                        crate::Literal::I32(127),
                        spirv::GLOp::SClamp,
                    )
                } else {
                    (
                        crate::Literal::U32(0),
                        crate::Literal::U32(255),
                        spirv::GLOp::UClamp,
                    )
                };
                let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));

                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.writer.gl450_ext_inst_id,
                    clamp_op,
                    result_type_id,
                    clamp_id,
                    &[extracted, min, max],
                ));

                extracted = clamp_id;
            }
            let is_last = i == u32::from(VEC_LENGTH - 1);
            if is_last {
                // The final insert produces `id`; hand it back to the caller
                // instead of pushing it onto the block.
                last_instruction = Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    id,
                    preresult,
                    extracted,
                    offset,
                    eight,
                )
            } else {
                // Insert this byte and carry the partial result forward.
                let new_preresult = self.gen_id();
                block.body.push(Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    new_preresult,
                    preresult,
                    extracted,
                    offset,
                    eight,
                ));
                preresult = new_preresult;
            }
        }
        last_instruction
    }
3327
3328    /// Emit code for `unpack4x{I,U}8` if capability "Int8" is available.
3329    fn write_unpack4x8_optimized(
3330        &mut self,
3331        block: &mut Block,
3332        result_type_id: u32,
3333        arg0_id: u32,
3334        id: u32,
3335        is_signed: bool,
3336    ) -> Instruction {
3337        let (int_type, convert_op) = if is_signed {
3338            (crate::ScalarKind::Sint, spirv::Op::SConvert)
3339        } else {
3340            (crate::ScalarKind::Uint, spirv::Op::UConvert)
3341        };
3342
3343        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3344            size: crate::VectorSize::Quad,
3345            scalar: crate::Scalar {
3346                kind: int_type,
3347                width: 1,
3348            },
3349        });
3350
3351        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
3352        // and a scalar precisely as required by the WGSL spec [2].
3353        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
3354        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
3355        let packed_vector = self.gen_id();
3356        block.body.push(Instruction::unary(
3357            spirv::Op::Bitcast,
3358            packed_vector_type_id,
3359            packed_vector,
3360            arg0_id,
3361        ));
3362
3363        Instruction::unary(convert_op, result_type_id, id, packed_vector)
3364    }
3365
3366    /// Emit code for `unpack4x{I,U}8` if capability "Int8" is not available.
3367    fn write_unpack4x8_polyfill(
3368        &mut self,
3369        block: &mut Block,
3370        result_type_id: u32,
3371        arg0_id: u32,
3372        id: u32,
3373        is_signed: bool,
3374    ) -> Instruction {
3375        let (int_type, extract_op) = if is_signed {
3376            (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3377        } else {
3378            (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3379        };
3380
3381        let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3382
3383        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3384        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3385            kind: int_type,
3386            width: 4,
3387        }));
3388        block
3389            .body
3390            .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3391        let arg_id = if is_signed {
3392            let new_arg_id = self.gen_id();
3393            block.body.push(Instruction::unary(
3394                spirv::Op::Bitcast,
3395                sint_type_id,
3396                new_arg_id,
3397                arg0_id,
3398            ));
3399            new_arg_id
3400        } else {
3401            arg0_id
3402        };
3403
3404        const VEC_LENGTH: u8 = 4;
3405        let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3406        for (i, part_id) in parts.into_iter().enumerate() {
3407            let index = self
3408                .writer
3409                .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3410            block.body.push(Instruction::ternary(
3411                extract_op,
3412                int_type_id,
3413                part_id,
3414                arg_id,
3415                index,
3416                eight,
3417            ));
3418        }
3419
3420        Instruction::composite_construct(result_type_id, id, &parts)
3421    }
3422
3423    /// Generate one or more SPIR-V blocks for `naga_block`.
3424    ///
3425    /// Use `label_id` as the label for the SPIR-V entry point block.
3426    ///
3427    /// If control reaches the end of the SPIR-V block, terminate it according
3428    /// to `exit`. This function's return value indicates whether it acted on
3429    /// this parameter or not; see [`BlockExitDisposition`].
3430    ///
3431    /// If the block contains [`Break`] or [`Continue`] statements,
3432    /// `loop_context` supplies the labels of the SPIR-V blocks to jump to. If
3433    /// either of these labels are `None`, then it should have been a Naga
3434    /// validation error for the corresponding statement to occur in this
3435    /// context.
3436    ///
3437    /// [`Break`]: Statement::Break
3438    /// [`Continue`]: Statement::Continue
3439    fn write_block(
3440        &mut self,
3441        label_id: Word,
3442        naga_block: &crate::Block,
3443        exit: BlockExit,
3444        loop_context: LoopContext,
3445        debug_info: Option<&DebugInfoInner>,
3446    ) -> Result<BlockExitDisposition, Error> {
3447        let mut block = Block::new(label_id);
3448        for (statement, span) in naga_block.span_iter() {
3449            if let (Some(debug_info), false) = (
3450                debug_info,
3451                matches!(
3452                    statement,
3453                    &(Statement::Block(..)
3454                        | Statement::Break
3455                        | Statement::Continue
3456                        | Statement::Kill
3457                        | Statement::Return { .. }
3458                        | Statement::Loop { .. })
3459                ),
3460            ) {
3461                let loc: crate::SourceLocation = span.location(debug_info.source_code);
3462                block.body.push(Instruction::line(
3463                    debug_info.source_file_id,
3464                    loc.line_number,
3465                    loc.line_position,
3466                ));
3467            };
3468            match *statement {
3469                Statement::Emit(ref range) => {
3470                    for handle in range.clone() {
3471                        // omit const expressions as we've already cached those
3472                        if !self.expression_constness.is_const(handle) {
3473                            self.cache_expression_value(handle, &mut block)?;
3474                        }
3475                    }
3476                }
3477                Statement::Block(ref block_statements) => {
3478                    let scope_id = self.gen_id();
3479                    self.function.consume(block, Instruction::branch(scope_id));
3480
3481                    let merge_id = self.gen_id();
3482                    let merge_used = self.write_block(
3483                        scope_id,
3484                        block_statements,
3485                        BlockExit::Branch { target: merge_id },
3486                        loop_context,
3487                        debug_info,
3488                    )?;
3489
3490                    match merge_used {
3491                        BlockExitDisposition::Used => {
3492                            block = Block::new(merge_id);
3493                        }
3494                        BlockExitDisposition::Discarded => {
3495                            return Ok(BlockExitDisposition::Discarded);
3496                        }
3497                    }
3498                }
3499                Statement::If {
3500                    condition,
3501                    ref accept,
3502                    ref reject,
3503                } => {
3504                    // In spirv 1.6, in a conditional branch the two block ids
3505                    // of the branches can't have the same label. If `accept`
3506                    // and `reject` are both empty (e.g. in `if (condition) {}`)
3507                    // merge id will be both labels. Because both branches are
3508                    // empty, we can skip the if statement.
3509                    if !(accept.is_empty() && reject.is_empty()) {
3510                        let condition_id = self.cached[condition];
3511
3512                        let merge_id = self.gen_id();
3513                        block.body.push(Instruction::selection_merge(
3514                            merge_id,
3515                            spirv::SelectionControl::NONE,
3516                        ));
3517
3518                        let accept_id = if accept.is_empty() {
3519                            None
3520                        } else {
3521                            Some(self.gen_id())
3522                        };
3523                        let reject_id = if reject.is_empty() {
3524                            None
3525                        } else {
3526                            Some(self.gen_id())
3527                        };
3528
3529                        self.function.consume(
3530                            block,
3531                            Instruction::branch_conditional(
3532                                condition_id,
3533                                accept_id.unwrap_or(merge_id),
3534                                reject_id.unwrap_or(merge_id),
3535                            ),
3536                        );
3537
3538                        if let Some(block_id) = accept_id {
3539                            // We can ignore the `BlockExitDisposition` returned here because,
3540                            // even if `merge_id` is not actually reachable, it is always
3541                            // referred to by the `OpSelectionMerge` instruction we emitted
3542                            // earlier.
3543                            let _ = self.write_block(
3544                                block_id,
3545                                accept,
3546                                BlockExit::Branch { target: merge_id },
3547                                loop_context,
3548                                debug_info,
3549                            )?;
3550                        }
3551                        if let Some(block_id) = reject_id {
3552                            // We can ignore the `BlockExitDisposition` returned here because,
3553                            // even if `merge_id` is not actually reachable, it is always
3554                            // referred to by the `OpSelectionMerge` instruction we emitted
3555                            // earlier.
3556                            let _ = self.write_block(
3557                                block_id,
3558                                reject,
3559                                BlockExit::Branch { target: merge_id },
3560                                loop_context,
3561                                debug_info,
3562                            )?;
3563                        }
3564
3565                        block = Block::new(merge_id);
3566                    }
3567                }
3568                Statement::Switch {
3569                    selector,
3570                    ref cases,
3571                } => {
3572                    let selector_id = self.cached[selector];
3573
3574                    let merge_id = self.gen_id();
3575                    block.body.push(Instruction::selection_merge(
3576                        merge_id,
3577                        spirv::SelectionControl::NONE,
3578                    ));
3579
3580                    let mut default_id = None;
3581                    // id of previous empty fall-through case
3582                    let mut last_id = None;
3583
3584                    let mut raw_cases = Vec::with_capacity(cases.len());
3585                    let mut case_ids = Vec::with_capacity(cases.len());
3586                    for case in cases.iter() {
3587                        // take id of previous empty fall-through case or generate a new one
3588                        let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3589
3590                        if case.fall_through && case.body.is_empty() {
3591                            last_id = Some(label_id);
3592                        }
3593
3594                        case_ids.push(label_id);
3595
3596                        match case.value {
3597                            crate::SwitchValue::I32(value) => {
3598                                raw_cases.push(super::instructions::Case {
3599                                    value: value as Word,
3600                                    label_id,
3601                                });
3602                            }
3603                            crate::SwitchValue::U32(value) => {
3604                                raw_cases.push(super::instructions::Case { value, label_id });
3605                            }
3606                            crate::SwitchValue::Default => {
3607                                default_id = Some(label_id);
3608                            }
3609                        }
3610                    }
3611
3612                    let default_id = default_id.unwrap();
3613
3614                    self.function.consume(
3615                        block,
3616                        Instruction::switch(selector_id, default_id, &raw_cases),
3617                    );
3618
3619                    let inner_context = LoopContext {
3620                        break_id: Some(merge_id),
3621                        ..loop_context
3622                    };
3623
3624                    for (i, (case, label_id)) in cases
3625                        .iter()
3626                        .zip(case_ids.iter())
3627                        .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3628                        .enumerate()
3629                    {
3630                        let case_finish_id = if case.fall_through {
3631                            case_ids[i + 1]
3632                        } else {
3633                            merge_id
3634                        };
3635                        // We can ignore the `BlockExitDisposition` returned here because
3636                        // `case_finish_id` is always referred to by either:
3637                        //
3638                        // - the `OpSwitch`, if it's the next case's label for a
3639                        //   fall-through, or
3640                        //
3641                        // - the `OpSelectionMerge`, if it's the switch's overall merge
3642                        //   block because there's no fall-through.
3643                        let _ = self.write_block(
3644                            *label_id,
3645                            &case.body,
3646                            BlockExit::Branch {
3647                                target: case_finish_id,
3648                            },
3649                            inner_context,
3650                            debug_info,
3651                        )?;
3652                    }
3653
3654                    block = Block::new(merge_id);
3655                }
3656                Statement::Loop {
3657                    ref body,
3658                    ref continuing,
3659                    break_if,
3660                } => {
3661                    let preamble_id = self.gen_id();
3662                    self.function
3663                        .consume(block, Instruction::branch(preamble_id));
3664
3665                    let merge_id = self.gen_id();
3666                    let body_id = self.gen_id();
3667                    let continuing_id = self.gen_id();
3668
3669                    // SPIR-V requires the continuing to the `OpLoopMerge`,
3670                    // so we have to start a new block with it.
3671                    block = Block::new(preamble_id);
3672                    // HACK the loop statement is begin with branch instruction,
3673                    // so we need to put `OpLine` debug info before merge instruction
3674                    if let Some(debug_info) = debug_info {
3675                        let loc: crate::SourceLocation = span.location(debug_info.source_code);
3676                        block.body.push(Instruction::line(
3677                            debug_info.source_file_id,
3678                            loc.line_number,
3679                            loc.line_position,
3680                        ))
3681                    }
3682                    block.body.push(Instruction::loop_merge(
3683                        merge_id,
3684                        continuing_id,
3685                        spirv::SelectionControl::NONE,
3686                    ));
3687
3688                    if self.force_loop_bounding {
3689                        block = self.write_force_bounded_loop_instructions(block, merge_id);
3690                    }
3691                    self.function.consume(block, Instruction::branch(body_id));
3692
3693                    // We can ignore the `BlockExitDisposition` returned here because,
3694                    // even if `continuing_id` is not actually reachable, it is always
3695                    // referred to by the `OpLoopMerge` instruction we emitted earlier.
3696                    let _ = self.write_block(
3697                        body_id,
3698                        body,
3699                        BlockExit::Branch {
3700                            target: continuing_id,
3701                        },
3702                        LoopContext {
3703                            continuing_id: Some(continuing_id),
3704                            break_id: Some(merge_id),
3705                        },
3706                        debug_info,
3707                    )?;
3708
3709                    let exit = match break_if {
3710                        Some(condition) => BlockExit::BreakIf {
3711                            condition,
3712                            preamble_id,
3713                        },
3714                        None => BlockExit::Branch {
3715                            target: preamble_id,
3716                        },
3717                    };
3718
3719                    // We can ignore the `BlockExitDisposition` returned here because,
3720                    // even if `merge_id` is not actually reachable, it is always referred
3721                    // to by the `OpLoopMerge` instruction we emitted earlier.
3722                    let _ = self.write_block(
3723                        continuing_id,
3724                        continuing,
3725                        exit,
3726                        LoopContext {
3727                            continuing_id: None,
3728                            break_id: Some(merge_id),
3729                        },
3730                        debug_info,
3731                    )?;
3732
3733                    block = Block::new(merge_id);
3734                }
3735                Statement::Break => {
3736                    self.function
3737                        .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3738                    return Ok(BlockExitDisposition::Discarded);
3739                }
3740                Statement::Continue => {
3741                    self.function.consume(
3742                        block,
3743                        Instruction::branch(loop_context.continuing_id.unwrap()),
3744                    );
3745                    return Ok(BlockExitDisposition::Discarded);
3746                }
3747                Statement::Return { value: Some(value) } => {
3748                    let value_id = self.cached[value];
3749                    let instruction = match self.function.entry_point_context {
3750                        // If this is an entry point, and we need to return anything,
3751                        // let's instead store the output variables and return `void`.
3752                        Some(ref context) => self.writer.write_entry_point_return(
3753                            value_id,
3754                            self.ir_function.result.as_ref().unwrap(),
3755                            &context.results,
3756                            &mut block.body,
3757                            context.task_payload_variable_id,
3758                        )?,
3759                        None => Instruction::return_value(value_id),
3760                    };
3761                    self.function.consume(block, instruction);
3762                    return Ok(BlockExitDisposition::Discarded);
3763                }
3764                Statement::Return { value: None } => {
3765                    if let Some(super::EntryPointContext {
3766                        mesh_state: Some(ref mesh_state),
3767                        ..
3768                    }) = self.function.entry_point_context
3769                    {
3770                        self.function.consume(
3771                            block,
3772                            Instruction::branch(mesh_state.entry_point_epilogue_id),
3773                        );
3774                    } else {
3775                        self.function.consume(block, Instruction::return_void());
3776                    }
3777                    return Ok(BlockExitDisposition::Discarded);
3778                }
3779                Statement::Kill => {
3780                    self.function.consume(block, Instruction::kill());
3781                    return Ok(BlockExitDisposition::Discarded);
3782                }
3783                Statement::ControlBarrier(flags) => {
3784                    self.writer.write_control_barrier(flags, &mut block.body);
3785                }
3786                Statement::MemoryBarrier(flags) => {
3787                    self.writer.write_memory_barrier(flags, &mut block);
3788                }
3789                Statement::Store { pointer, value } => {
3790                    let value_id = self.cached[value];
3791                    match self.write_access_chain(
3792                        pointer,
3793                        &mut block,
3794                        AccessTypeAdjustment::None,
3795                    )? {
3796                        ExpressionPointer::Ready { pointer_id } => {
3797                            let atomic_space = match *self.fun_info[pointer]
3798                                .ty
3799                                .inner_with(&self.ir_module.types)
3800                            {
3801                                crate::TypeInner::Pointer { base, space } => {
3802                                    match self.ir_module.types[base].inner {
3803                                        crate::TypeInner::Atomic { .. } => Some(space),
3804                                        _ => None,
3805                                    }
3806                                }
3807                                _ => None,
3808                            };
3809                            let instruction = if let Some(space) = atomic_space {
3810                                let (semantics, scope) = space.to_spirv_semantics_and_scope();
3811                                let scope_constant_id = self.get_scope_constant(scope as u32);
3812                                let semantics_id = self.get_index_constant(semantics.bits());
3813                                Instruction::atomic_store(
3814                                    pointer_id,
3815                                    scope_constant_id,
3816                                    semantics_id,
3817                                    value_id,
3818                                )
3819                            } else {
3820                                Instruction::store(pointer_id, value_id, None)
3821                            };
3822                            block.body.push(instruction);
3823                        }
3824                        ExpressionPointer::Conditional { condition, access } => {
3825                            let mut selection = Selection::start(&mut block, ());
3826                            selection.if_true(self, condition, ());
3827
3828                            // The in-bounds path. Perform the access and the store.
3829                            let pointer_id = access.result_id.unwrap();
3830                            selection.block().body.push(access);
3831                            selection
3832                                .block()
3833                                .body
3834                                .push(Instruction::store(pointer_id, value_id, None));
3835
3836                            // Finish the in-bounds block and start the merge block. This
3837                            // is the block we'll leave current on return.
3838                            selection.finish(self, ());
3839                        }
3840                    };
3841                }
3842                Statement::ImageStore {
3843                    image,
3844                    coordinate,
3845                    array_index,
3846                    value,
3847                } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3848                Statement::Call {
3849                    function: local_function,
3850                    ref arguments,
3851                    result,
3852                } => {
3853                    let id = self.gen_id();
3854                    self.temp_list.clear();
3855                    for &argument in arguments {
3856                        self.temp_list.push(self.cached[argument]);
3857                    }
3858
3859                    let type_id = match result {
3860                        Some(expr) => {
3861                            self.cached[expr] = id;
3862                            self.get_expression_type_id(&self.fun_info[expr].ty)
3863                        }
3864                        None => self.writer.void_type,
3865                    };
3866
3867                    block.body.push(Instruction::function_call(
3868                        type_id,
3869                        id,
3870                        self.writer.lookup_function[&local_function],
3871                        &self.temp_list,
3872                    ));
3873                }
3874                Statement::Atomic {
3875                    pointer,
3876                    ref fun,
3877                    value,
3878                    result,
3879                } => {
3880                    let id = self.gen_id();
3881                    // Compare-and-exchange operations produce a struct result,
3882                    // so use `result`'s type if it is available. For no-result
3883                    // operations, fall back to `value`'s type.
3884                    let result_type_id =
3885                        self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3886
3887                    if let Some(result) = result {
3888                        self.cached[result] = id;
3889                    }
3890
3891                    let pointer_id = match self.write_access_chain(
3892                        pointer,
3893                        &mut block,
3894                        AccessTypeAdjustment::None,
3895                    )? {
3896                        ExpressionPointer::Ready { pointer_id } => pointer_id,
3897                        ExpressionPointer::Conditional { .. } => {
3898                            return Err(Error::FeatureNotImplemented(
3899                                "Atomics out-of-bounds handling",
3900                            ));
3901                        }
3902                    };
3903
3904                    let space = self.fun_info[pointer]
3905                        .ty
3906                        .inner_with(&self.ir_module.types)
3907                        .pointer_space()
3908                        .unwrap();
3909                    let (semantics, scope) = space.to_spirv_semantics_and_scope();
3910                    let scope_constant_id = self.get_scope_constant(scope as u32);
3911                    let semantics_id = self.get_index_constant(semantics.bits());
3912                    let value_id = self.cached[value];
3913                    let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3914
3915                    let crate::TypeInner::Scalar(scalar) = *value_inner else {
3916                        return Err(Error::FeatureNotImplemented(
3917                            "Atomics with non-scalar values",
3918                        ));
3919                    };
3920
3921                    let instruction = match *fun {
3922                        crate::AtomicFunction::Add => {
3923                            let spirv_op = match scalar.kind {
3924                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3925                                    spirv::Op::AtomicIAdd
3926                                }
3927                                crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3928                                _ => unimplemented!(),
3929                            };
3930                            Instruction::atomic_binary(
3931                                spirv_op,
3932                                result_type_id,
3933                                id,
3934                                pointer_id,
3935                                scope_constant_id,
3936                                semantics_id,
3937                                value_id,
3938                            )
3939                        }
3940                        crate::AtomicFunction::Subtract => {
3941                            let (spirv_op, value_id) = match scalar.kind {
3942                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3943                                    (spirv::Op::AtomicISub, value_id)
3944                                }
3945                                crate::ScalarKind::Float => {
3946                                    // HACK: SPIR-V doesn't have a atomic subtraction,
3947                                    // so we add the negated value instead.
3948                                    let neg_result_id = self.gen_id();
3949                                    block.body.push(Instruction::unary(
3950                                        spirv::Op::FNegate,
3951                                        result_type_id,
3952                                        neg_result_id,
3953                                        value_id,
3954                                    ));
3955                                    (spirv::Op::AtomicFAddEXT, neg_result_id)
3956                                }
3957                                _ => unimplemented!(),
3958                            };
3959                            Instruction::atomic_binary(
3960                                spirv_op,
3961                                result_type_id,
3962                                id,
3963                                pointer_id,
3964                                scope_constant_id,
3965                                semantics_id,
3966                                value_id,
3967                            )
3968                        }
3969                        crate::AtomicFunction::And => {
3970                            let spirv_op = match scalar.kind {
3971                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3972                                    spirv::Op::AtomicAnd
3973                                }
3974                                _ => unimplemented!(),
3975                            };
3976                            Instruction::atomic_binary(
3977                                spirv_op,
3978                                result_type_id,
3979                                id,
3980                                pointer_id,
3981                                scope_constant_id,
3982                                semantics_id,
3983                                value_id,
3984                            )
3985                        }
3986                        crate::AtomicFunction::InclusiveOr => {
3987                            let spirv_op = match scalar.kind {
3988                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3989                                    spirv::Op::AtomicOr
3990                                }
3991                                _ => unimplemented!(),
3992                            };
3993                            Instruction::atomic_binary(
3994                                spirv_op,
3995                                result_type_id,
3996                                id,
3997                                pointer_id,
3998                                scope_constant_id,
3999                                semantics_id,
4000                                value_id,
4001                            )
4002                        }
4003                        crate::AtomicFunction::ExclusiveOr => {
4004                            let spirv_op = match scalar.kind {
4005                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4006                                    spirv::Op::AtomicXor
4007                                }
4008                                _ => unimplemented!(),
4009                            };
4010                            Instruction::atomic_binary(
4011                                spirv_op,
4012                                result_type_id,
4013                                id,
4014                                pointer_id,
4015                                scope_constant_id,
4016                                semantics_id,
4017                                value_id,
4018                            )
4019                        }
4020                        crate::AtomicFunction::Min => {
4021                            let spirv_op = match scalar.kind {
4022                                crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4023                                crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4024                                _ => unimplemented!(),
4025                            };
4026                            Instruction::atomic_binary(
4027                                spirv_op,
4028                                result_type_id,
4029                                id,
4030                                pointer_id,
4031                                scope_constant_id,
4032                                semantics_id,
4033                                value_id,
4034                            )
4035                        }
4036                        crate::AtomicFunction::Max => {
4037                            let spirv_op = match scalar.kind {
4038                                crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4039                                crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4040                                _ => unimplemented!(),
4041                            };
4042                            Instruction::atomic_binary(
4043                                spirv_op,
4044                                result_type_id,
4045                                id,
4046                                pointer_id,
4047                                scope_constant_id,
4048                                semantics_id,
4049                                value_id,
4050                            )
4051                        }
4052                        crate::AtomicFunction::Exchange { compare: None } => {
4053                            Instruction::atomic_binary(
4054                                spirv::Op::AtomicExchange,
4055                                result_type_id,
4056                                id,
4057                                pointer_id,
4058                                scope_constant_id,
4059                                semantics_id,
4060                                value_id,
4061                            )
4062                        }
4063                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4064                            let scalar_type_id =
4065                                self.get_numeric_type_id(NumericType::Scalar(scalar));
4066                            let bool_type_id =
4067                                self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4068
4069                            let cas_result_id = self.gen_id();
4070                            let equality_result_id = self.gen_id();
4071                            let equality_operator = match scalar.kind {
4072                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4073                                    spirv::Op::IEqual
4074                                }
4075                                _ => unimplemented!(),
4076                            };
4077
4078                            let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4079                            cas_instr.set_type(scalar_type_id);
4080                            cas_instr.set_result(cas_result_id);
4081                            cas_instr.add_operand(pointer_id);
4082                            cas_instr.add_operand(scope_constant_id);
4083                            cas_instr.add_operand(semantics_id); // semantics if equal
4084                            cas_instr.add_operand(semantics_id); // semantics if not equal
4085                            cas_instr.add_operand(value_id);
4086                            cas_instr.add_operand(self.cached[cmp]);
4087                            block.body.push(cas_instr);
4088                            block.body.push(Instruction::binary(
4089                                equality_operator,
4090                                bool_type_id,
4091                                equality_result_id,
4092                                cas_result_id,
4093                                self.cached[cmp],
4094                            ));
4095                            Instruction::composite_construct(
4096                                result_type_id,
4097                                id,
4098                                &[cas_result_id, equality_result_id],
4099                            )
4100                        }
4101                    };
4102
4103                    block.body.push(instruction);
4104                }
4105                Statement::ImageAtomic {
4106                    image,
4107                    coordinate,
4108                    array_index,
4109                    fun,
4110                    value,
4111                } => {
4112                    self.write_image_atomic(
4113                        image,
4114                        coordinate,
4115                        array_index,
4116                        fun,
4117                        value,
4118                        &mut block,
4119                    )?;
4120                }
4121                Statement::WorkGroupUniformLoad { pointer, result } => {
4122                    self.writer
4123                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4124                    let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4125                    // Match `Expression::Load` behavior, including `OpAtomicLoad` when
4126                    // loading from a pointer to `atomic<T>`.
4127                    let id = self.write_checked_load(
4128                        pointer,
4129                        &mut block,
4130                        AccessTypeAdjustment::None,
4131                        result_type_id,
4132                    )?;
4133                    self.cached[result] = id;
4134                    self.writer
4135                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4136                }
4137                Statement::RayQuery { query, ref fun } => {
4138                    self.write_ray_query_function(query, fun, &mut block);
4139                }
4140                Statement::SubgroupBallot {
4141                    result,
4142                    ref predicate,
4143                } => {
4144                    self.write_subgroup_ballot(predicate, result, &mut block)?;
4145                }
4146                Statement::SubgroupCollectiveOperation {
4147                    ref op,
4148                    ref collective_op,
4149                    argument,
4150                    result,
4151                } => {
4152                    self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4153                }
4154                Statement::SubgroupGather {
4155                    ref mode,
4156                    argument,
4157                    result,
4158                } => {
4159                    self.write_subgroup_gather(mode, argument, result, &mut block)?;
4160                }
4161                Statement::CooperativeStore { target, ref data } => {
4162                    let target_id = self.cached[target];
4163                    let layout = if data.row_major {
4164                        spirv::CooperativeMatrixLayout::RowMajorKHR
4165                    } else {
4166                        spirv::CooperativeMatrixLayout::ColumnMajorKHR
4167                    };
4168                    let layout_id = self.get_index_constant(layout as u32);
4169                    let stride_id = self.cached[data.stride];
4170                    match self.write_access_chain(
4171                        data.pointer,
4172                        &mut block,
4173                        AccessTypeAdjustment::None,
4174                    )? {
4175                        ExpressionPointer::Ready { pointer_id } => {
4176                            block.body.push(Instruction::coop_store(
4177                                target_id, pointer_id, layout_id, stride_id,
4178                            ));
4179                        }
4180                        ExpressionPointer::Conditional { condition, access } => {
4181                            let mut selection = Selection::start(&mut block, ());
4182                            selection.if_true(self, condition, ());
4183
4184                            // The in-bounds path. Perform the access and the store.
4185                            let pointer_id = access.result_id.unwrap();
4186                            selection.block().body.push(access);
4187                            selection.block().body.push(Instruction::coop_store(
4188                                target_id, pointer_id, layout_id, stride_id,
4189                            ));
4190
4191                            // Finish the in-bounds block and start the merge block. This
4192                            // is the block we'll leave current on return.
4193                            selection.finish(self, ());
4194                        }
4195                    };
4196                }
4197            }
4198        }
4199
4200        let termination = match exit {
4201            // We're generating code for the top-level Block of the function, so we
4202            // need to end it with some kind of return instruction.
4203            BlockExit::Return => match self.ir_function.result {
4204                Some(ref result) if self.function.entry_point_context.is_none() => {
4205                    let type_id = self.get_handle_type_id(result.ty);
4206                    let null_id = self.writer.get_constant_null(type_id);
4207                    Instruction::return_value(null_id)
4208                }
4209                _ => Instruction::return_void(),
4210            },
4211            BlockExit::Branch { target } => Instruction::branch(target),
4212            BlockExit::BreakIf {
4213                condition,
4214                preamble_id,
4215            } => {
4216                let condition_id = self.cached[condition];
4217
4218                Instruction::branch_conditional(
4219                    condition_id,
4220                    loop_context.break_id.unwrap(),
4221                    preamble_id,
4222                )
4223            }
4224        };
4225
4226        self.function.consume(block, termination);
4227        Ok(BlockExitDisposition::Used)
4228    }
4229
4230    pub(super) fn write_function_body(
4231        &mut self,
4232        entry_id: Word,
4233        debug_info: Option<&DebugInfoInner>,
4234    ) -> Result<(), Error> {
4235        // We can ignore the `BlockExitDisposition` returned here because
4236        // `BlockExit::Return` doesn't refer to a block.
4237        let _ = self.write_block(
4238            entry_id,
4239            &self.ir_function.body,
4240            BlockExit::Return,
4241            LoopContext::default(),
4242            debug_info,
4243        )?;
4244        if let Some(super::EntryPointContext {
4245            mesh_state: Some(ref mesh_state),
4246            ..
4247        }) = self.function.entry_point_context
4248        {
4249            let mut block = Block::new(mesh_state.entry_point_epilogue_id);
4250            self.writer
4251                .write_mesh_shader_return(mesh_state, &mut block)?;
4252            self.function.consume(block, Instruction::return_void());
4253        }
4254
4255        Ok(())
4256    }
4257}