naga/back/spv/
block.rs

/*!
Implementations for `BlockContext` methods.
*/

use alloc::vec::Vec;

use arrayvec::ArrayVec;
use spirv::Word;

use super::{
    index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
    Instruction, LocalType, LookupType, NumericType, ResultMember, WrappedFunction, Writer,
    WriterFlags,
};
use crate::{arena::Handle, proc::index::GuardedIndex, Statement};

fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
    match *type_inner {
        crate::TypeInner::Scalar(_) => Dimension::Scalar,
        crate::TypeInner::Vector { .. } => Dimension::Vector,
        crate::TypeInner::Matrix { .. } => Dimension::Matrix,
        _ => unreachable!(),
    }
}

/// How to derive the type of `OpAccessChain` instructions from Naga IR.
///
/// Most of the time, we compile Naga IR to SPIR-V instructions whose result
/// types are simply the direct SPIR-V analog of the Naga IR's. But in some
/// cases, the Naga IR and SPIR-V types need to diverge.
///
/// This enum specifies how [`BlockContext::write_access_chain`] should
/// choose a SPIR-V result type for the `OpAccessChain` it generates, based on
/// the type of the given Naga IR [`Expression`] it's generating code for.
///
/// [`Expression`]: crate::Expression
enum AccessTypeAdjustment {
    /// No adjustment needed: the SPIR-V type should be the direct
    /// analog of the Naga IR expression type.
    ///
    /// For most access chains, this is the right thing: the Naga IR access
    /// expression produces a [`Pointer`] to the element / component, and the
    /// SPIR-V `OpAccessChain` instruction does the same.
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    None,

    /// The SPIR-V type should be an `OpTypePointer` to the direct analog of the
    /// Naga IR expression's type.
    ///
    /// This is necessary for indexing binding arrays in the [`Handle`] address
    /// space:
    ///
    /// - In Naga IR, referencing a binding array [`GlobalVariable`] in the
    ///   [`Handle`] address space produces a value of type [`BindingArray`],
    ///   not a pointer to such. And [`Access`] and [`AccessIndex`] expressions
    ///   operate on handle binding arrays by value, and produce handle values,
    ///   not pointers.
    ///
    /// - In SPIR-V, a binding array `OpVariable` produces a pointer to an
    ///   array, and `OpAccessChain` instructions operate on pointers,
    ///   regardless of whether the elements are opaque types or not.
    ///
    /// See also the documentation for [`BindingArray`].
    ///
    /// [`Handle`]: crate::AddressSpace::Handle
    /// [`GlobalVariable`]: crate::GlobalVariable
    /// [`BindingArray`]: crate::TypeInner::BindingArray
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
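    ///
    /// As an illustrative sketch (not a compiled doctest): when this backend
    /// lowers an `Access` or `AccessIndex` expression whose base is a
    /// handle-space binding array, it calls `write_access_chain` with
    /// `AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::UniformConstant)`
    /// and then emits an `OpLoad` through the pointer that comes back.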
    IntroducePointer(spirv::StorageClass),
}

/// The results of emitting code for a left-hand-side expression.
///
/// On success, `write_access_chain` returns one of these.
enum ExpressionPointer {
    /// The pointer to the expression's value is available, as the value of the
    /// expression with the given id.
    Ready { pointer_id: Word },

    /// The access expression must be conditional on the value of `condition`, a boolean
    /// expression that is true if all indices are in bounds. If `condition` is true, then
    /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
    /// expression's value. If `condition` is false, then executing `access` would be
    /// undefined behavior.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}

/// The termination statement to be added to the end of the block
enum BlockExit {
    /// Generates an OpReturn (void return)
    Return,
    /// Generates an OpBranch to the specified block
    Branch {
        /// The branch target block
        target: Word,
    },
    /// Translates a loop `break if` into an `OpBranchConditional` that branches
    /// to the merge block if the condition is true (the merge block is passed
    /// through [`LoopContext::break_id`]), or else to the loop header (passed
    /// through [`preamble_id`]).
    ///
    /// [`preamble_id`]: Self::BreakIf::preamble_id
    BreakIf {
        /// The condition of the `break if`
        condition: Handle<crate::Expression>,
        /// The loop header block id
        preamble_id: Word,
    },
}

/// What code generation did with a provided [`BlockExit`] value.
///
/// A function that accepts a [`BlockExit`] argument should return a value of
/// this type, to indicate whether the code it generated ended up using the
/// provided exit, or ignored it and did a non-local exit of some other kind
/// (say, [`Break`] or [`Continue`]). Some callers must use this information to
/// decide whether to generate the target block at all.
///
/// [`Break`]: Statement::Break
/// [`Continue`]: Statement::Continue
#[must_use]
enum BlockExitDisposition {
    /// The generated code used the provided `BlockExit` value. If it included a
    /// block label, the caller should be sure to actually emit the block it
    /// refers to.
    Used,

    /// The generated code did not use the provided `BlockExit` value. If it
    /// included a block label, the caller should not bother to actually emit
    /// the block it refers to, unless it knows the block is needed for
    /// something else.
    Discarded,
}

#[derive(Clone, Copy, Default)]
struct LoopContext {
    continuing_id: Option<Word>,
    break_id: Option<Word>,
}

#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    pub source_code: &'a str,
    pub source_file_id: Word,
}

impl Writer {
    // Flip Y coordinate to adjust for coordinate space difference
    // between SPIR-V and our IR.
    // The `position_id` argument is a pointer to a `vecN<f32>`,
    // whose `y` component we will negate.
    fn write_epilogue_position_y_flip(
        &mut self,
        position_id: Word,
        body: &mut Vec<Instruction>,
    ) -> Result<(), Error> {
        let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
        let index_y_id = self.get_index_constant(1);
        let access_id = self.id_gen.next();
        body.push(Instruction::access_chain(
            float_ptr_type_id,
            access_id,
            position_id,
            &[index_y_id],
        ));

        let float_type_id = self.get_f32_type_id();
        let load_id = self.id_gen.next();
        body.push(Instruction::load(float_type_id, load_id, access_id, None));

        let neg_id = self.id_gen.next();
        body.push(Instruction::unary(
            spirv::Op::FNegate,
            float_type_id,
            neg_id,
            load_id,
        ));

        body.push(Instruction::store(access_id, neg_id, None));
        Ok(())
    }

    // Clamp fragment depth between 0 and 1.
    fn write_epilogue_frag_depth_clamp(
        &mut self,
        frag_depth_id: Word,
        body: &mut Vec<Instruction>,
    ) -> Result<(), Error> {
        let float_type_id = self.get_f32_type_id();
        let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
        let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));

        let original_id = self.id_gen.next();
        body.push(Instruction::load(
            float_type_id,
            original_id,
            frag_depth_id,
            None,
        ));

        let clamp_id = self.id_gen.next();
        body.push(Instruction::ext_inst_gl_op(
            self.gl450_ext_inst_id,
            spirv::GLOp::FClamp,
            float_type_id,
            clamp_id,
            &[original_id, zero_scalar_id, one_scalar_id],
        ));

        body.push(Instruction::store(frag_depth_id, clamp_id, None));
        Ok(())
    }

    fn write_entry_point_return(
        &mut self,
        value_id: Word,
        ir_result: &crate::FunctionResult,
        result_members: &[ResultMember],
        body: &mut Vec<Instruction>,
    ) -> Result<(), Error> {
        for (index, res_member) in result_members.iter().enumerate() {
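            // If the IR result carries a single binding, the return value itself is
            // the member value; otherwise the result is a struct, and each member
            // must be pulled out with `OpCompositeExtract`.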
            let member_value_id = match ir_result.binding {
                Some(_) => value_id,
                None => {
                    let member_value_id = self.id_gen.next();
                    body.push(Instruction::composite_extract(
                        res_member.type_id,
                        member_value_id,
                        value_id,
                        &[index as u32],
                    ));
                    member_value_id
                }
            };

            self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);

            match res_member.built_in {
                Some(crate::BuiltIn::Position { .. })
                    if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
                {
                    self.write_epilogue_position_y_flip(res_member.id, body)?;
                }
                Some(crate::BuiltIn::FragDepth)
                    if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
                {
                    self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
                }
                _ => {}
            }
        }
        Ok(())
    }
}

impl BlockContext<'_> {
    /// Generates code to ensure that a loop is bounded. Should be called immediately
    /// after adding the OpLoopMerge instruction to `block`. This function will
    /// [`consume()`](crate::back::spv::Function::consume) `block` and append its
    /// instructions to a new [`Block`], which is returned to the caller to be
    /// consumed prior to writing the loop body.
    ///
    /// Additionally, this function populates [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
    /// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
    /// declare the required variables.
    ///
    /// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
    /// on why this is required.
    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
        let uint_type_id = self.writer.get_u32_type_id();
        let uint2_type_id = self.writer.get_vec2u_type_id();
        let uint2_ptr_type_id = self
            .writer
            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
        let bool_type_id = self.writer.get_bool_type_id();
        let bool2_type_id = self.writer.get_vec2_bool_type_id();
        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let zero_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[zero_uint_const_id, zero_uint_const_id],
        );
        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
        let max_uint_const_id = self
            .writer
            .get_constant_scalar(crate::Literal::U32(u32::MAX));
        let max_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[max_uint_const_id, max_uint_const_id],
        );

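        // Declare a Function-storage `vec2<u32>` local, initialized to
        // (u32::MAX, u32::MAX), to serve as the loop-bound counter.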
        let loop_counter_var_id = self.gen_id();
        if self.writer.flags.contains(WriterFlags::DEBUG) {
            self.writer
                .debugs
                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
        }
        let var = super::LocalVariable {
            id: loop_counter_var_id,
            instruction: Instruction::variable(
                uint2_ptr_type_id,
                loop_counter_var_id,
                spirv::StorageClass::Function,
                Some(max_uint2_const_id),
            ),
        };
        self.function.force_loop_bounding_vars.push(var);

        let break_if_block = self.gen_id();

        self.function
            .consume(block, Instruction::branch(break_if_block));
        block = Block::new(break_if_block);

        // Load the current loop counter value from its variable. We use a vec2<u32> to
        // simulate a 64-bit counter.
        let load_id = self.gen_id();
        block.body.push(Instruction::load(
            uint2_type_id,
            load_id,
            loop_counter_var_id,
            None,
        ));

        // If both the high and low u32s have reached 0 then break. ie
        // if (all(eq(loop_counter, vec2(0)))) { break; }
        let eq_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool2_type_id,
            eq_id,
            zero_uint2_const_id,
            load_id,
        ));
        let all_eq_id = self.gen_id();
        block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            all_eq_id,
            eq_id,
        ));

        let inc_counter_block_id = self.gen_id();
        block.body.push(Instruction::selection_merge(
            inc_counter_block_id,
            spirv::SelectionControl::empty(),
        ));
        self.function.consume(
            block,
            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
        );
        block = Block::new(inc_counter_block_id);

        // To simulate a 64-bit counter we always decrement the low u32, and decrement
        // the high u32 when the low u32 overflows. ie
        // counter -= vec2(select(0u, 1u, counter.y == 0), 1u);
        // Count down from u32::MAX rather than up from 0 to avoid hang on
        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
        let low_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            uint_type_id,
            low_id,
            load_id,
            &[1],
        ));
        let low_overflow_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            low_overflow_id,
            low_id,
            zero_uint_const_id,
        ));
        let carry_bit_id = self.gen_id();
        block.body.push(Instruction::select(
            uint_type_id,
            carry_bit_id,
            low_overflow_id,
            one_uint_const_id,
            zero_uint_const_id,
        ));
        let decrement_id = self.gen_id();
        block.body.push(Instruction::composite_construct(
            uint2_type_id,
            decrement_id,
            &[carry_bit_id, one_uint_const_id],
        ));
        let result_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ISub,
            uint2_type_id,
            result_id,
            load_id,
            decrement_id,
        ));
        block
            .body
            .push(Instruction::store(loop_counter_var_id, result_id, None));

        block
    }

    /// Cache an expression for a value.
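    ///
    /// Emits whatever instructions are needed into `block` to compute the value
    /// of `expr_handle`, and records the resulting SPIR-V id in `self.cached`.
    /// Unreferenced expressions that are not named are skipped.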
    pub(super) fn cache_expression_value(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<(), Error> {
        let is_named_expression = self
            .ir_function
            .named_expressions
            .contains_key(&expr_handle);

        if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
            return Ok(());
        }

        let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
        let id = match self.ir_function.expressions[expr_handle] {
            crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
            crate::Expression::Constant(handle) => {
                let init = self.ir_module.constants[handle].init;
                self.writer.constant_ids[init]
            }
            crate::Expression::Override(_) => return Err(Error::Override),
            crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
            crate::Expression::Compose { ty, ref components } => {
                self.temp_list.clear();
                if self.expression_constness.is_const(expr_handle) {
                    self.temp_list.extend(
                        crate::proc::flatten_compose(
                            ty,
                            components,
                            &self.ir_function.expressions,
                            &self.ir_module.types,
                        )
                        .map(|component| self.cached[component]),
                    );
                    self.writer
                        .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
                } else {
                    self.temp_list
                        .extend(components.iter().map(|&component| self.cached[component]));

                    let id = self.gen_id();
                    block.body.push(Instruction::composite_construct(
                        result_type_id,
                        id,
                        &self.temp_list,
                    ));
                    id
                }
            }
            crate::Expression::Splat { size, value } => {
                let value_id = self.cached[value];
                let components = &[value_id; 4][..size as usize];

                if self.expression_constness.is_const(expr_handle) {
                    let ty = self
                        .writer
                        .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
                    self.writer.get_constant_composite(ty, components)
                } else {
                    let id = self.gen_id();
                    block.body.push(Instruction::composite_construct(
                        result_type_id,
                        id,
                        components,
                    ));
                    id
                }
            }
            crate::Expression::Access { base, index } => {
                let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                match *base_ty_inner {
                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
                        // When we have a chain of `Access` and `AccessIndex` expressions
                        // operating on pointers, we want to generate a single
                        // `OpAccessChain` instruction for the whole chain. Put off
                        // generating any code for this until we find the `Expression`
                        // that actually dereferences the pointer.
                        0
                    }
                    _ if self.function.spilled_accesses.contains(base) => {
                        // As far as Naga IR is concerned, this expression does not yield
                        // a pointer (we just checked, above), but this backend spilled it
                        // to a temporary variable, so SPIR-V thinks we're accessing it
                        // via a pointer.

                        // Since the base expression was spilled, mark this access to it
                        // as spilled, too.
                        self.function.spilled_accesses.insert(expr_handle);
                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
                    }
                    crate::TypeInner::Vector { .. } => {
                        self.write_vector_access(expr_handle, base, index, block)?
                    }
                    crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
                        // See if `index` is known at compile time.
                        match GuardedIndex::from_expression(
                            index,
                            &self.ir_function.expressions,
                            self.ir_module,
                        ) {
                            GuardedIndex::Known(value) => {
                                // If `index` is known and in bounds, we can just use
                                // `OpCompositeExtract`.
                                //
                                // At the moment, validation rejects programs if this
                                // index is out of bounds, so we don't need bounds checks.
                                // However, that rejection is incorrect, since WGSL says
                                // that `let` bindings are not constant expressions
                                // (#6396). So eventually we will need to emulate bounds
                                // checks here.
                                let id = self.gen_id();
                                let base_id = self.cached[base];
                                block.body.push(Instruction::composite_extract(
                                    result_type_id,
                                    id,
                                    base_id,
                                    &[value],
                                ));
                                id
                            }
                            GuardedIndex::Expression(_) => {
                                // We are subscripting an array or matrix that is not
                                // behind a pointer, using an index computed at runtime.
                                // SPIR-V has no instructions that do this, so the best we
                                // can do is spill the value to a new temporary variable,
                                // at which point we can get a pointer to that and just
                                // use `OpAccessChain` in the usual way.
                                self.spill_to_internal_variable(base, block);

                                // Since the base was spilled, mark this access to it as
                                // spilled, too.
                                self.function.spilled_accesses.insert(expr_handle);
                                self.maybe_access_spilled_composite(
                                    expr_handle,
                                    block,
                                    result_type_id,
                                )?
                            }
                        }
                    }
                    crate::TypeInner::BindingArray {
                        base: binding_type, ..
                    } => {
                        // Only binding arrays in the `Handle` address space will take
                        // this path, since we handled the `Pointer` case above.
                        let result_id = match self.write_access_chain(
                            expr_handle,
                            block,
                            AccessTypeAdjustment::IntroducePointer(
                                spirv::StorageClass::UniformConstant,
                            ),
                        )? {
                            ExpressionPointer::Ready { pointer_id } => pointer_id,
                            ExpressionPointer::Conditional { .. } => {
                                return Err(Error::FeatureNotImplemented(
                                    "Texture array out-of-bounds handling",
                                ));
                            }
                        };

                        let binding_type_id = self.get_handle_type_id(binding_type);

                        let load_id = self.gen_id();
                        block.body.push(Instruction::load(
                            binding_type_id,
                            load_id,
                            result_id,
                            None,
                        ));

                        // Subsequent image operations require the image/sampler to be
                        // decorated as NonUniform if the image/sampler binding array was
                        // accessed with a non-uniform index; see VUID-RuntimeSpirv-NonUniform-06274.
                        if self.fun_info[index].uniformity.non_uniform_result.is_some() {
                            self.writer
                                .decorate_non_uniform_binding_array_access(load_id)?;
                        }

                        load_id
                    }
                    ref other => {
                        log::error!(
                            "Unable to access base {:?} of type {:?}",
                            self.ir_function.expressions[base],
                            other
                        );
                        return Err(Error::Validation(
                            "only vectors and arrays may be dynamically indexed by value",
                        ));
                    }
                }
            }
            crate::Expression::AccessIndex { base, index } => {
                match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
                    crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
                        // When we have a chain of `Access` and `AccessIndex` expressions
                        // operating on pointers, we want to generate a single
                        // `OpAccessChain` instruction for the whole chain. Put off
                        // generating any code for this until we find the `Expression`
                        // that actually dereferences the pointer.
                        0
                    }
                    _ if self.function.spilled_accesses.contains(base) => {
                        // As far as Naga IR is concerned, this expression does not yield
                        // a pointer (we just checked, above), but this backend spilled it
                        // to a temporary variable, so SPIR-V thinks we're accessing it
                        // via a pointer.

                        // Since the base expression was spilled, mark this access to it
                        // as spilled, too.
                        self.function.spilled_accesses.insert(expr_handle);
                        self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
                    }
                    crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. }
                    | crate::TypeInner::Array { .. }
                    | crate::TypeInner::Struct { .. } => {
                        // We never need bounds checks here: dynamically sized arrays can
                        // only appear behind pointers, and are thus handled by the
                        // `is_intermediate` case above. Everything else's size is
                        // statically known and checked in validation.
                        let id = self.gen_id();
                        let base_id = self.cached[base];
                        block.body.push(Instruction::composite_extract(
                            result_type_id,
                            id,
                            base_id,
                            &[index],
                        ));
                        id
                    }
                    crate::TypeInner::BindingArray {
                        base: binding_type, ..
                    } => {
                        // Only binding arrays in the `Handle` address space will take
                        // this path, since we handled the `Pointer` case above.
                        let result_id = match self.write_access_chain(
                            expr_handle,
                            block,
                            AccessTypeAdjustment::IntroducePointer(
                                spirv::StorageClass::UniformConstant,
                            ),
                        )? {
                            ExpressionPointer::Ready { pointer_id } => pointer_id,
                            ExpressionPointer::Conditional { .. } => {
                                return Err(Error::FeatureNotImplemented(
                                    "Texture array out-of-bounds handling",
                                ));
                            }
                        };

                        let binding_type_id = self.get_handle_type_id(binding_type);

                        let load_id = self.gen_id();
                        block.body.push(Instruction::load(
                            binding_type_id,
                            load_id,
                            result_id,
                            None,
                        ));

                        load_id
                    }
                    ref other => {
                        log::error!("Unable to access index of {other:?}");
                        return Err(Error::FeatureNotImplemented("access index for type"));
                    }
                }
            }
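            // A global variable expression simply yields the `access_id` the
            // writer recorded for that global.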
            crate::Expression::GlobalVariable(handle) => {
                self.writer.global_variables[handle].access_id
            }
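            // A swizzle is lowered to `OpVectorShuffle`, passing the same vector as
            // both operands and the swizzle pattern as the component indices.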
            crate::Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector_id = self.cached[vector];
                self.temp_list.clear();
                for &sc in pattern[..size as usize].iter() {
                    self.temp_list.push(sc as Word);
                }
                let id = self.gen_id();
                block.body.push(Instruction::vector_shuffle(
                    result_type_id,
                    id,
                    vector_id,
                    vector_id,
                    &self.temp_list,
                ));
                id
            }
            crate::Expression::Unary { op, expr } => {
                let id = self.gen_id();
                let expr_id = self.cached[expr];
                let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

                let spirv_op = match op {
                    crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
                        Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
                        Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
                        _ => return Err(Error::Validation("Unexpected kind for negation")),
                    },
                    crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
                    crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
                };

                block
                    .body
                    .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
                id
            }
            crate::Expression::Binary { op, left, right } => {
                let id = self.gen_id();
                let left_id = self.cached[left];
                let right_id = self.cached[right];
                let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
                let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);

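                // Some binary operations (for example, integer `%`; see the `Modulo`
                // arm below) are implemented as wrapped helper functions pre-generated
                // by the writer. If one exists for this op and these operand types,
                // call it instead of emitting a single instruction.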
                if let Some(function_id) =
                    self.writer
                        .wrapped_functions
                        .get(&WrappedFunction::BinaryOp {
                            op,
                            left_type_id,
                            right_type_id,
                        })
                {
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        id,
                        *function_id,
                        &[left_id, right_id],
                    ));
                } else {
                    let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
                    let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);

                    let left_dimension = get_dimension(left_ty_inner);
                    let right_dimension = get_dimension(right_ty_inner);

                    let mut reverse_operands = false;

                    let spirv_op = match op {
                        crate::BinaryOperator::Add => match *left_ty_inner {
                            crate::TypeInner::Scalar(scalar)
                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
                                crate::ScalarKind::Float => spirv::Op::FAdd,
                                _ => spirv::Op::IAdd,
                            },
                            crate::TypeInner::Matrix {
                                columns,
                                rows,
                                scalar,
                            } => {
                                self.write_matrix_matrix_column_op(
                                    block,
                                    id,
                                    result_type_id,
                                    left_id,
                                    right_id,
                                    columns,
                                    rows,
                                    scalar.width,
                                    spirv::Op::FAdd,
                                );

                                self.cached[expr_handle] = id;
                                return Ok(());
                            }
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Subtract => match *left_ty_inner {
                            crate::TypeInner::Scalar(scalar)
                            | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
                                crate::ScalarKind::Float => spirv::Op::FSub,
                                _ => spirv::Op::ISub,
                            },
                            crate::TypeInner::Matrix {
                                columns,
                                rows,
                                scalar,
                            } => {
                                self.write_matrix_matrix_column_op(
                                    block,
                                    id,
                                    result_type_id,
                                    left_id,
                                    right_id,
                                    columns,
                                    rows,
                                    scalar.width,
                                    spirv::Op::FSub,
                                );

                                self.cached[expr_handle] = id;
                                return Ok(());
                            }
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Multiply => {
                            match (left_dimension, right_dimension) {
                                (Dimension::Scalar, Dimension::Vector) => {
                                    self.write_vector_scalar_mult(
                                        block,
                                        id,
                                        result_type_id,
                                        right_id,
                                        left_id,
                                        right_ty_inner,
                                    );

                                    self.cached[expr_handle] = id;
                                    return Ok(());
                                }
                                (Dimension::Vector, Dimension::Scalar) => {
                                    self.write_vector_scalar_mult(
                                        block,
                                        id,
                                        result_type_id,
                                        left_id,
                                        right_id,
                                        left_ty_inner,
                                    );

                                    self.cached[expr_handle] = id;
                                    return Ok(());
                                }
                                (Dimension::Vector, Dimension::Matrix) => {
                                    spirv::Op::VectorTimesMatrix
                                }
                                (Dimension::Matrix, Dimension::Scalar) => {
                                    spirv::Op::MatrixTimesScalar
                                }
                                (Dimension::Scalar, Dimension::Matrix) => {
                                    reverse_operands = true;
                                    spirv::Op::MatrixTimesScalar
                                }
                                (Dimension::Matrix, Dimension::Vector) => {
                                    spirv::Op::MatrixTimesVector
                                }
                                (Dimension::Matrix, Dimension::Matrix) => {
                                    spirv::Op::MatrixTimesMatrix
                                }
                                (Dimension::Vector, Dimension::Vector)
                                | (Dimension::Scalar, Dimension::Scalar)
                                    if left_ty_inner.scalar_kind()
                                        == Some(crate::ScalarKind::Float) =>
                                {
                                    spirv::Op::FMul
                                }
                                (Dimension::Vector, Dimension::Vector)
                                | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
                            }
                        }
                        crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
                            Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
                            Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
                            // TODO: handle undefined behavior
                            // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
                            Some(crate::ScalarKind::Float) => spirv::Op::FRem,
                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
                                unreachable!("Should have been handled by wrapped function")
                            }
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
                                spirv::Op::IEqual
                            }
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
                                spirv::Op::INotEqual
                            }
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
                            Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
                            Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
                            Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
                            _ => unimplemented!(),
                        },
                        crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
                            _ => spirv::Op::BitwiseAnd,
                        },
                        crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
                        crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
                            _ => spirv::Op::BitwiseOr,
                        },
                        crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
                        crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
                        crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
                        crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
                            Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
                            Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
                            _ => unimplemented!(),
                        },
                    };

                    block.body.push(Instruction::binary(
                        spirv_op,
                        result_type_id,
                        id,
                        if reverse_operands { right_id } else { left_id },
                        if reverse_operands { left_id } else { right_id },
                    ));
                }
                id
            }
            crate::Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;
                enum MathOp {
                    Ext(spirv::GLOp),
                    Custom(Instruction),
                }

                let arg0_id = self.cached[arg];
                let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
                let arg_scalar_kind = arg_ty.scalar_kind();
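                // Optional arguments that are absent get id 0; each arm below only
                // reads the argument ids its `MathFunction` actually takes.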
                let arg1_id = match arg1 {
                    Some(handle) => self.cached[handle],
                    None => 0,
                };
                let arg2_id = match arg2 {
                    Some(handle) => self.cached[handle],
                    None => 0,
                };
                let arg3_id = match arg3 {
                    Some(handle) => self.cached[handle],
                    None => 0,
                };

                let id = self.gen_id();
                let math_op = match fun {
                    // comparison
                    Mf::Abs => {
                        match arg_scalar_kind {
                            Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
                            Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
                            Some(crate::ScalarKind::Uint) => {
                                MathOp::Custom(Instruction::unary(
                                    spirv::Op::CopyObject, // do nothing
                                    result_type_id,
                                    id,
                                    arg0_id,
                                ))
                            }
                            other => unimplemented!("Unexpected abs({:?})", other),
                        }
                    }
                    Mf::Min => MathOp::Ext(match arg_scalar_kind {
                        Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
                        Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
                        other => unimplemented!("Unexpected min({:?})", other),
                    }),
                    Mf::Max => MathOp::Ext(match arg_scalar_kind {
                        Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
                        Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
                        other => unimplemented!("Unexpected max({:?})", other),
                    }),
                    Mf::Clamp => match arg_scalar_kind {
                        // Clamp is undefined if min > max. In practice this means it can use a median-of-three
                        // instruction to determine the value. This is fine according to the WGSL spec for float
                        // clamp, but integer clamp _must_ use min-max. As such we write out min/max.
                        Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
                        Some(_) => {
                            let (min_op, max_op) = match arg_scalar_kind {
                                Some(crate::ScalarKind::Sint) => {
                                    (spirv::GLOp::SMin, spirv::GLOp::SMax)
                                }
                                Some(crate::ScalarKind::Uint) => {
                                    (spirv::GLOp::UMin, spirv::GLOp::UMax)
                                }
                                _ => unreachable!(),
                            };

                            let max_id = self.gen_id();
                            block.body.push(Instruction::ext_inst_gl_op(
                                self.writer.gl450_ext_inst_id,
                                max_op,
                                result_type_id,
                                max_id,
                                &[arg0_id, arg1_id],
                            ));

                            MathOp::Custom(Instruction::ext_inst_gl_op(
                                self.writer.gl450_ext_inst_id,
                                min_op,
                                result_type_id,
                                id,
                                &[max_id, arg2_id],
                            ))
                        }
                        other => unimplemented!("Unexpected clamp({:?})", other),
1046                    },
1047                    Mf::Saturate => {
1048                        let (maybe_size, scalar) = match *arg_ty {
1049                            crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1050                            crate::TypeInner::Scalar(scalar) => (None, scalar),
1051                            ref other => unimplemented!("Unexpected saturate({:?})", other),
1052                        };
1053                        let scalar = crate::Scalar::float(scalar.width);
1054                        let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1055                        let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1056
1057                        if let Some(size) = maybe_size {
1058                            let ty =
1059                                LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1060
1061                            self.temp_list.clear();
1062                            self.temp_list.resize(size as _, arg1_id);
1063
1064                            arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1065
1066                            self.temp_list.fill(arg2_id);
1067
1068                            arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1069                        }
1070
1071                        MathOp::Custom(Instruction::ext_inst_gl_op(
1072                            self.writer.gl450_ext_inst_id,
1073                            spirv::GLOp::FClamp,
1074                            result_type_id,
1075                            id,
1076                            &[arg0_id, arg1_id, arg2_id],
1077                        ))
1078                    }
1079                    // trigonometry
1080                    Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1081                    Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1082                    Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1083                    Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1084                    Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1085                    Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1086                    Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1087                    Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1088                    Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1089                    Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1090                    Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1091                    Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1092                    Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1093                    Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1094                    Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1095                    // decomposition
1096                    Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1097                    Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1098                    Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1099                    Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1100                    Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1101                    Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1102                    Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1103                    Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1104                    // geometry
1105                    Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1106                        crate::TypeInner::Vector {
1107                            scalar:
1108                                crate::Scalar {
1109                                    kind: crate::ScalarKind::Float,
1110                                    ..
1111                                },
1112                            ..
1113                        } => MathOp::Custom(Instruction::binary(
1114                            spirv::Op::Dot,
1115                            result_type_id,
1116                            id,
1117                            arg0_id,
1118                            arg1_id,
1119                        )),
1120                        // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
1121                        crate::TypeInner::Vector { size, .. } => {
1122                            self.write_dot_product(
1123                                id,
1124                                result_type_id,
1125                                arg0_id,
1126                                arg1_id,
1127                                size as u32,
1128                                block,
1129                                |result_id, composite_id, index| {
1130                                    Instruction::composite_extract(
1131                                        result_type_id,
1132                                        result_id,
1133                                        composite_id,
1134                                        &[index],
1135                                    )
1136                                },
1137                            );
1138                            self.cached[expr_handle] = id;
1139                            return Ok(());
1140                        }
1141                        _ => unreachable!(
1142                            "Correct TypeInner for dot product should be already validated"
1143                        ),
1144                    },
1145                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146                        if self
1147                            .writer
1148                            .require_all(&[
1149                                spirv::Capability::DotProduct,
1150                                spirv::Capability::DotProductInput4x8BitPacked,
1151                            ])
1152                            .is_ok()
1153                        {
1154                            // Write optimized code using `PackedVectorFormat4x8Bit`.
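                            // Roughly, `dot4I8Packed(a, b)` then lowers to a single
                            // `%r = OpSDot %int %a %b PackedVectorFormat4x8Bit`
                            // (or `OpUDot` with an unsigned result for `dot4U8Packed`).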
1155                            if self.writer.lang_version() < (1, 6) {
1156                                // SPIR-V 1.6 supports the required capabilities natively, so the extension
1157                                // is only required for earlier versions. See right column of
1158                                // <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
1159                                self.writer.use_extension("SPV_KHR_integer_dot_product");
1160                            }
1161
1162                            let op = match fun {
1163                                Mf::Dot4I8Packed => spirv::Op::SDot,
1164                                Mf::Dot4U8Packed => spirv::Op::UDot,
1165                                _ => unreachable!(),
1166                            };
1167
1168                            block.body.push(Instruction::ternary(
1169                                op,
1170                                result_type_id,
1171                                id,
1172                                arg0_id,
1173                                arg1_id,
1174                                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1175                            ));
1176                        } else {
1177                            // Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
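                            // A rough sketch of what the polyfill computes, for
                            // `dot4U8Packed(a, b)`:
                            //
                            //     sum over i in 0..4 of
                            //         bitfieldExtract(a, 8 * i, 8) * bitfieldExtract(b, 8 * i, 8)
                            //
                            // The extracts, multiplies, and adds are emitted by
                            // `write_dot_product` below. For `dot4I8Packed` the arguments are
                            // first bitcast to i32 so that `BitFieldSExtract` sign-extends
                            // each 8-bit lane.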
1178                            let (extract_op, arg0_id, arg1_id) = match fun {
1179                                Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1180                                Mf::Dot4I8Packed => {
1181                                    // Convert both packed arguments to signed integers so that we can apply the
1182                                    // `BitFieldSExtract` operation on them in `write_dot_product` below.
1183                                    let new_arg0_id = self.gen_id();
1184                                    block.body.push(Instruction::unary(
1185                                        spirv::Op::Bitcast,
1186                                        result_type_id,
1187                                        new_arg0_id,
1188                                        arg0_id,
1189                                    ));
1190
1191                                    let new_arg1_id = self.gen_id();
1192                                    block.body.push(Instruction::unary(
1193                                        spirv::Op::Bitcast,
1194                                        result_type_id,
1195                                        new_arg1_id,
1196                                        arg1_id,
1197                                    ));
1198
1199                                    (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1200                                }
1201                                _ => unreachable!(),
1202                            };
1203
1204                            let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1205
1206                            const VEC_LENGTH: u8 = 4;
1207                            let bit_shifts: [_; VEC_LENGTH as usize] =
1208                                core::array::from_fn(|index| {
1209                                    self.writer
1210                                        .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1211                                });
1212
1213                            self.write_dot_product(
1214                                id,
1215                                result_type_id,
1216                                arg0_id,
1217                                arg1_id,
1218                                VEC_LENGTH as Word,
1219                                block,
1220                                |result_id, composite_id, index| {
1221                                    Instruction::ternary(
1222                                        extract_op,
1223                                        result_type_id,
1224                                        result_id,
1225                                        composite_id,
1226                                        bit_shifts[index as usize],
1227                                        eight,
1228                                    )
1229                                },
1230                            );
1231                        }
1232
1233                        self.cached[expr_handle] = id;
1234                        return Ok(());
1235                    }
1236                    Mf::Outer => MathOp::Custom(Instruction::binary(
1237                        spirv::Op::OuterProduct,
1238                        result_type_id,
1239                        id,
1240                        arg0_id,
1241                        arg1_id,
1242                    )),
1243                    Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1244                    Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1245                    Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1246                    Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1247                    Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1248                    Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1249                    Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1250                    // exponent
1251                    Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1252                    Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1253                    Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1254                    Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1255                    Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1256                    // computational
1257                    Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1258                        Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1259                        Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1260                        other => unimplemented!("Unexpected sign({:?})", other),
1261                    }),
1262                    Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1263                    Mf::Mix => {
1264                        let selector = arg2.unwrap();
1265                        let selector_ty =
1266                            self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1267                        match (arg_ty, selector_ty) {
1268                            // if the selector is a scalar, we need to splat it
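                            // For example, `mix(vec3(a), vec3(b), 0.5)` needs the scalar 0.5
                            // splatted to a `vec3` first, since GLSL.std.450 `FMix` requires
                            // the interpolant to have the same type as the other operands.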
1269                            (
1270                                &crate::TypeInner::Vector { size, .. },
1271                                &crate::TypeInner::Scalar(scalar),
1272                            ) => {
1273                                let selector_type_id =
1274                                    self.get_numeric_type_id(NumericType::Vector { size, scalar });
1275                                self.temp_list.clear();
1276                                self.temp_list.resize(size as usize, arg2_id);
1277
1278                                let selector_id = self.gen_id();
1279                                block.body.push(Instruction::composite_construct(
1280                                    selector_type_id,
1281                                    selector_id,
1282                                    &self.temp_list,
1283                                ));
1284
1285                                MathOp::Custom(Instruction::ext_inst_gl_op(
1286                                    self.writer.gl450_ext_inst_id,
1287                                    spirv::GLOp::FMix,
1288                                    result_type_id,
1289                                    id,
1290                                    &[arg0_id, arg1_id, selector_id],
1291                                ))
1292                            }
1293                            _ => MathOp::Ext(spirv::GLOp::FMix),
1294                        }
1295                    }
1296                    Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1297                    Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1298                    Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1299                    Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1300                    Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1301                    Mf::Transpose => MathOp::Custom(Instruction::unary(
1302                        spirv::Op::Transpose,
1303                        result_type_id,
1304                        id,
1305                        arg0_id,
1306                    )),
1307                    Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1308                    Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1309                        spirv::Op::QuantizeToF16,
1310                        result_type_id,
1311                        id,
1312                        arg0_id,
1313                    )),
1314                    Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1315                        spirv::Op::BitReverse,
1316                        result_type_id,
1317                        id,
1318                        arg0_id,
1319                    )),
1320                    Mf::CountTrailingZeros => {
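                        // `FindILsb` returns -1 (all bits set) when its input is zero, but
                        // WGSL's `countTrailingZeros` must return the bit width in that case.
                        // Taking `UMin` against a bit-width constant (splatted for vectors)
                        // turns that -1, viewed as an unsigned value, into the bit width.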
1321                        let uint_id = match *arg_ty {
1322                            crate::TypeInner::Vector { size, scalar } => {
1323                                let ty =
1324                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1325
1326                                self.temp_list.clear();
1327                                self.temp_list.resize(
1328                                    size as _,
1329                                    self.writer
1330                                        .get_constant_scalar_with(scalar.width * 8, scalar)?,
1331                                );
1332
1333                                self.writer.get_constant_composite(ty, &self.temp_list)
1334                            }
1335                            crate::TypeInner::Scalar(scalar) => self
1336                                .writer
1337                                .get_constant_scalar_with(scalar.width * 8, scalar)?,
1338                            _ => unreachable!(),
1339                        };
1340
1341                        let lsb_id = self.gen_id();
1342                        block.body.push(Instruction::ext_inst_gl_op(
1343                            self.writer.gl450_ext_inst_id,
1344                            spirv::GLOp::FindILsb,
1345                            result_type_id,
1346                            lsb_id,
1347                            &[arg0_id],
1348                        ));
1349
1350                        MathOp::Custom(Instruction::ext_inst_gl_op(
1351                            self.writer.gl450_ext_inst_id,
1352                            spirv::GLOp::UMin,
1353                            result_type_id,
1354                            id,
1355                            &[uint_id, lsb_id],
1356                        ))
1357                    }
1358                    Mf::CountLeadingZeros => {
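                        // `countLeadingZeros(x)` is computed as `31 - FindUMsb(x)`, treating
                        // the argument's bit pattern as unsigned. `FindUMsb` returns -1 for a
                        // zero input, which yields the required 32. Only 32-bit integers
                        // reach this point (see the width check below), so the `FindILsb`
                        // branch in the op selection is currently never taken.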
1359                        let (int_type_id, int_id, width) = match *arg_ty {
1360                            crate::TypeInner::Vector { size, scalar } => {
1361                                let ty =
1362                                    LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1363
1364                                self.temp_list.clear();
1365                                self.temp_list.resize(
1366                                    size as _,
1367                                    self.writer
1368                                        .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1369                                );
1370
1371                                (
1372                                    self.get_type_id(ty),
1373                                    self.writer.get_constant_composite(ty, &self.temp_list),
1374                                    scalar.width,
1375                                )
1376                            }
1377                            crate::TypeInner::Scalar(scalar) => (
1378                                self.get_numeric_type_id(NumericType::Scalar(scalar)),
1379                                self.writer
1380                                    .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1381                                scalar.width,
1382                            ),
1383                            _ => unreachable!(),
1384                        };
1385
1386                        if width != 4 {
1387                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1388                        };
1389
1390                        let msb_id = self.gen_id();
1391                        block.body.push(Instruction::ext_inst_gl_op(
1392                            self.writer.gl450_ext_inst_id,
1393                            if width != 4 {
1394                                spirv::GLOp::FindILsb
1395                            } else {
1396                                spirv::GLOp::FindUMsb
1397                            },
1398                            int_type_id,
1399                            msb_id,
1400                            &[arg0_id],
1401                        ));
1402
1403                        MathOp::Custom(Instruction::binary(
1404                            spirv::Op::ISub,
1405                            result_type_id,
1406                            id,
1407                            int_id,
1408                            msb_id,
1409                        ))
1410                    }
1411                    Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1412                        spirv::Op::BitCount,
1413                        result_type_id,
1414                        id,
1415                        arg0_id,
1416                    )),
1417                    Mf::ExtractBits => {
1418                        let op = match arg_scalar_kind {
1419                            Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1420                            Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1421                            other => unimplemented!("Unexpected extractBits({:?})", other),
1422                        };
1423
1424                        // The behavior of ExtractBits is undefined when offset + count > bit_width, so we
1425                        // need to sanitize the offset and count first. If we don't, AMD and Intel drivers
1426                        // return out-of-spec values when the extracted range exceeds the bit width.
1427                        //
1428                        // This encodes the exact formula specified by the wgsl spec:
1429                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1430                        //
1431                        // w = sizeof(x) * 8
1432                        // o = min(offset, w)
1433                        // tmp = w - o
1434                        // c = min(count, tmp)
1435                        //
1436                        // bitfieldExtract(x, o, c)
1437
1438                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1439                        let width_constant = self
1440                            .writer
1441                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1442
1443                        let u32_type =
1444                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1445
1446                        // o = min(offset, w)
1447                        let offset_id = self.gen_id();
1448                        block.body.push(Instruction::ext_inst_gl_op(
1449                            self.writer.gl450_ext_inst_id,
1450                            spirv::GLOp::UMin,
1451                            u32_type,
1452                            offset_id,
1453                            &[arg1_id, width_constant],
1454                        ));
1455
1456                        // tmp = w - o
1457                        let max_count_id = self.gen_id();
1458                        block.body.push(Instruction::binary(
1459                            spirv::Op::ISub,
1460                            u32_type,
1461                            max_count_id,
1462                            width_constant,
1463                            offset_id,
1464                        ));
1465
1466                        // c = min(count, tmp)
1467                        let count_id = self.gen_id();
1468                        block.body.push(Instruction::ext_inst_gl_op(
1469                            self.writer.gl450_ext_inst_id,
1470                            spirv::GLOp::UMin,
1471                            u32_type,
1472                            count_id,
1473                            &[arg2_id, max_count_id],
1474                        ));
1475
1476                        MathOp::Custom(Instruction::ternary(
1477                            op,
1478                            result_type_id,
1479                            id,
1480                            arg0_id,
1481                            offset_id,
1482                            count_id,
1483                        ))
1484                    }
1485                    Mf::InsertBits => {
1486                        // InsertBits has the same undefined behavior as ExtractBits, so clamp the offset and count the same way.
1487
1488                        let bit_width = arg_ty.scalar_width().unwrap() * 8;
1489                        let width_constant = self
1490                            .writer
1491                            .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1492
1493                        let u32_type =
1494                            self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1495
1496                        // o = min(offset, w)
1497                        let offset_id = self.gen_id();
1498                        block.body.push(Instruction::ext_inst_gl_op(
1499                            self.writer.gl450_ext_inst_id,
1500                            spirv::GLOp::UMin,
1501                            u32_type,
1502                            offset_id,
1503                            &[arg2_id, width_constant],
1504                        ));
1505
1506                        // tmp = w - o
1507                        let max_count_id = self.gen_id();
1508                        block.body.push(Instruction::binary(
1509                            spirv::Op::ISub,
1510                            u32_type,
1511                            max_count_id,
1512                            width_constant,
1513                            offset_id,
1514                        ));
1515
1516                        // c = min(count, tmp)
1517                        let count_id = self.gen_id();
1518                        block.body.push(Instruction::ext_inst_gl_op(
1519                            self.writer.gl450_ext_inst_id,
1520                            spirv::GLOp::UMin,
1521                            u32_type,
1522                            count_id,
1523                            &[arg3_id, max_count_id],
1524                        ));
1525
1526                        MathOp::Custom(Instruction::quaternary(
1527                            spirv::Op::BitFieldInsert,
1528                            result_type_id,
1529                            id,
1530                            arg0_id,
1531                            arg1_id,
1532                            offset_id,
1533                            count_id,
1534                        ))
1535                    }
1536                    Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1537                    Mf::FirstLeadingBit => {
1538                        if arg_ty.scalar_width() == Some(4) {
1539                            let op = match arg_scalar_kind {
1540                                Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1541                                Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1542                                other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1543                            };
1544                            MathOp::Ext(op)
1545                        } else {
1546                            unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1547                        }
1548                    }
1549                    Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1550                    Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1551                    Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1552                    Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1553                    Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1554                    fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1555                        let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1556                        let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1557
1558                        let last_instruction =
1559                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1560                                self.write_pack4x8_optimized(
1561                                    block,
1562                                    result_type_id,
1563                                    arg0_id,
1564                                    id,
1565                                    is_signed,
1566                                    should_clamp,
1567                                )
1568                            } else {
1569                                self.write_pack4x8_polyfill(
1570                                    block,
1571                                    result_type_id,
1572                                    arg0_id,
1573                                    id,
1574                                    is_signed,
1575                                    should_clamp,
1576                                )
1577                            };
1578
1579                        MathOp::Custom(last_instruction)
1580                    }
1581                    Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1582                    Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1583                    Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1584                    Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1585                    Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1586                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1587                        let is_signed = matches!(fun, Mf::Unpack4xI8);
1588
1589                        let last_instruction =
1590                            if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1591                                self.write_unpack4x8_optimized(
1592                                    block,
1593                                    result_type_id,
1594                                    arg0_id,
1595                                    id,
1596                                    is_signed,
1597                                )
1598                            } else {
1599                                self.write_unpack4x8_polyfill(
1600                                    block,
1601                                    result_type_id,
1602                                    arg0_id,
1603                                    id,
1604                                    is_signed,
1605                                )
1606                            };
1607
1608                        MathOp::Custom(last_instruction)
1609                    }
1610                };
1611
1612                block.body.push(match math_op {
1613                    MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1614                        self.writer.gl450_ext_inst_id,
1615                        op,
1616                        result_type_id,
1617                        id,
1618                        &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1619                    ),
1620                    MathOp::Custom(inst) => inst,
1621                });
1622                id
1623            }
1624            crate::Expression::LocalVariable(variable) => {
1625                if let Some(rq_tracker) = self
1626                    .function
1627                    .ray_query_initialization_tracker_variables
1628                    .get(&variable)
1629                {
1630                    self.ray_query_tracker_expr.insert(
1631                        expr_handle,
1632                        super::RayQueryTrackers {
1633                            initialized_tracker: rq_tracker.id,
1634                            t_max_tracker: self
1635                                .function
1636                                .ray_query_t_max_tracker_variables
1637                                .get(&variable)
1638                                .expect("Both trackers are set at the same time.")
1639                                .id,
1640                        },
1641                    );
1642                }
1643                self.function.variables[&variable].id
1644            }
1645            crate::Expression::Load { pointer } => {
1646                self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1647            }
1648            crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1649            crate::Expression::CallResult(_)
1650            | crate::Expression::AtomicResult { .. }
1651            | crate::Expression::WorkGroupUniformLoadResult { .. }
1652            | crate::Expression::RayQueryProceedResult
1653            | crate::Expression::SubgroupBallotResult
1654            | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1655            crate::Expression::As {
1656                expr,
1657                kind,
1658                convert,
1659            } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1660            crate::Expression::ImageLoad {
1661                image,
1662                coordinate,
1663                array_index,
1664                sample,
1665                level,
1666            } => self.write_image_load(
1667                result_type_id,
1668                image,
1669                coordinate,
1670                array_index,
1671                level,
1672                sample,
1673                block,
1674            )?,
1675            crate::Expression::ImageSample {
1676                image,
1677                sampler,
1678                gather,
1679                coordinate,
1680                array_index,
1681                offset,
1682                level,
1683                depth_ref,
1684                clamp_to_edge,
1685            } => self.write_image_sample(
1686                result_type_id,
1687                image,
1688                sampler,
1689                gather,
1690                coordinate,
1691                array_index,
1692                offset,
1693                level,
1694                depth_ref,
1695                clamp_to_edge,
1696                block,
1697            )?,
1698            crate::Expression::Select {
1699                condition,
1700                accept,
1701                reject,
1702            } => {
1703                let id = self.gen_id();
1704                let mut condition_id = self.cached[condition];
1705                let accept_id = self.cached[accept];
1706                let reject_id = self.cached[reject];
1707
1708                let condition_ty = self.fun_info[condition]
1709                    .ty
1710                    .inner_with(&self.ir_module.types);
1711                let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
1712
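                // SPIR-V only allows `OpSelect` with a scalar condition and vector
                // operands starting with version 1.4, so when WGSL's `select` has a
                // scalar condition and vector operands we splat the condition into a
                // boolean vector to stay compatible with earlier versions.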
1713                if let (
1714                    &crate::TypeInner::Scalar(
1715                        condition_scalar @ crate::Scalar {
1716                            kind: crate::ScalarKind::Bool,
1717                            ..
1718                        },
1719                    ),
1720                    &crate::TypeInner::Vector { size, .. },
1721                ) = (condition_ty, object_ty)
1722                {
1723                    self.temp_list.clear();
1724                    self.temp_list.resize(size as usize, condition_id);
1725
1726                    let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1727                        size,
1728                        scalar: condition_scalar,
1729                    });
1730
1731                    let id = self.gen_id();
1732                    block.body.push(Instruction::composite_construct(
1733                        bool_vector_type_id,
1734                        id,
1735                        &self.temp_list,
1736                    ));
1737                    condition_id = id
1738                }
1739
1740                let instruction =
1741                    Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
1742                block.body.push(instruction);
1743                id
1744            }
1745            crate::Expression::Derivative { axis, ctrl, expr } => {
1746                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1747                match ctrl {
1748                    Ctrl::Coarse | Ctrl::Fine => {
1749                        self.writer.require_any(
1750                            "DerivativeControl",
1751                            &[spirv::Capability::DerivativeControl],
1752                        )?;
1753                    }
1754                    Ctrl::None => {}
1755                }
1756                let id = self.gen_id();
1757                let expr_id = self.cached[expr];
1758                let op = match (axis, ctrl) {
1759                    (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
1760                    (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
1761                    (Axis::X, Ctrl::None) => spirv::Op::DPdx,
1762                    (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
1763                    (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
1764                    (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
1765                    (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
1766                    (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
1767                    (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
1768                };
1769                block
1770                    .body
1771                    .push(Instruction::derivative(op, result_type_id, id, expr_id));
1772                id
1773            }
1774            crate::Expression::ImageQuery { image, query } => {
1775                self.write_image_query(result_type_id, image, query, block)?
1776            }
1777            crate::Expression::Relational { fun, argument } => {
1778                use crate::RelationalFunction as Rf;
1779                let arg_id = self.cached[argument];
1780                let op = match fun {
1781                    Rf::All => spirv::Op::All,
1782                    Rf::Any => spirv::Op::Any,
1783                    Rf::IsNan => spirv::Op::IsNan,
1784                    Rf::IsInf => spirv::Op::IsInf,
1785                };
1786                let id = self.gen_id();
1787                block
1788                    .body
1789                    .push(Instruction::relational(op, result_type_id, id, arg_id));
1790                id
1791            }
1792            crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
1793            crate::Expression::RayQueryGetIntersection { query, committed } => {
1794                let query_id = self.cached[query];
1795                let init_tracker_id = *self
1796                    .ray_query_tracker_expr
1797                    .get(&query)
1798                    .expect("not a cached ray query");
1799                let func_id = self
1800                    .writer
1801                    .write_ray_query_get_intersection_function(committed, self.ir_module);
1802                let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
1803                let intersection_type_id = self.get_handle_type_id(ray_intersection);
1804                let id = self.gen_id();
1805                block.body.push(Instruction::function_call(
1806                    intersection_type_id,
1807                    id,
1808                    func_id,
1809                    &[query_id, init_tracker_id.initialized_tracker],
1810                ));
1811                id
1812            }
1813            crate::Expression::RayQueryVertexPositions { query, committed } => {
1814                self.writer.require_any(
1815                    "RayQueryVertexPositions",
1816                    &[spirv::Capability::RayQueryPositionFetchKHR],
1817                )?;
1818                self.write_ray_query_return_vertex_position(query, block, committed)
1819            }
1820        };
1821
1822        self.cached[expr_handle] = id;
1823        Ok(())
1824    }
1825
1826    /// Helper for generating `As` expressions, along with the various conversions
1827    /// they require.
1828    fn write_as_expression(
1829        &mut self,
1830        expr: Handle<crate::Expression>,
1831        convert: Option<u8>,
1832        kind: crate::ScalarKind,
1833
1834        block: &mut Block,
1835        result_type_id: u32,
1836    ) -> Result<u32, Error> {
1837        use crate::ScalarKind as Sk;
1838        let expr_id = self.cached[expr];
1839        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1840
1841        // Matrix casts need special treatment in SPIR-V, as the cast functions
1842        // can take vectors or scalars, but not matrices. To cast a matrix, we
1843        // cast each of its columns individually and construct a new matrix from
1844        // the converted columns.
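        //
        // For example (illustrative only), converting a `mat2x2<f32>` to a
        // `mat2x2<f16>` emits two `OpCompositeExtract` + `OpFConvert` pairs and
        // a final `OpCompositeConstruct` over the converted columns.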
1845        if let crate::TypeInner::Matrix {
1846            columns,
1847            rows,
1848            scalar,
1849        } = *ty
1850        {
1851            let Some(convert) = convert else {
1852                // No conversion is needed; the value passes through unchanged.
1853                return Ok(expr_id);
1854            };
1855
1856            if convert == scalar.width {
1857                // No conversion is needed; the value passes through unchanged.
1858                return Ok(expr_id);
1859            }
1860
1861            if kind != Sk::Float {
1862                // Only float conversions are supported for matrices.
1863                return Err(Error::Validation("Matrices must be floats"));
1864            }
1865
1866            // Type of each extracted column
1867            let column_src_ty =
1868                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
1869                    size: rows,
1870                    scalar,
1871                })));
1872
1873            // Type of the column after conversion
1874            let column_dst_ty =
1875                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
1876                    size: rows,
1877                    scalar: crate::Scalar {
1878                        kind,
1879                        width: convert,
1880                    },
1881                })));
1882
1883            let mut components = ArrayVec::<Word, 4>::new();
1884
1885            for column in 0..columns as usize {
1886                let column_id = self.gen_id();
1887                block.body.push(Instruction::composite_extract(
1888                    column_src_ty,
1889                    column_id,
1890                    expr_id,
1891                    &[column as u32],
1892                ));
1893
1894                let column_conv_id = self.gen_id();
1895                block.body.push(Instruction::unary(
1896                    spirv::Op::FConvert,
1897                    column_dst_ty,
1898                    column_conv_id,
1899                    column_id,
1900                ));
1901
1902                components.push(column_conv_id);
1903            }
1904
1905            let construct_id = self.gen_id();
1906
1907            block.body.push(Instruction::composite_construct(
1908                result_type_id,
1909                construct_id,
1910                &components,
1911            ));
1912
1913            return Ok(construct_id);
1914        }
1915
1916        let (src_scalar, src_size) = match *ty {
1917            crate::TypeInner::Scalar(scalar) => (scalar, None),
1918            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
1919            ref other => {
1920                log::error!("As source {other:?}");
1921                return Err(Error::Validation("Unexpected Expression::As source"));
1922            }
1923        };
1924
1925        enum Cast {
1926            Identity(Word),
1927            Unary(spirv::Op, Word),
1928            Binary(spirv::Op, Word, Word),
1929            Ternary(spirv::Op, Word, Word, Word),
1930        }
1931        let cast = match (src_scalar.kind, kind, convert) {
1932            // Filter out identity casts. Some Adreno drivers are
1933            // confused by no-op OpBitCast instructions.
1934            (src_kind, kind, convert)
1935                if src_kind == kind
1936                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
1937            {
1938                Cast::Identity(expr_id)
1939            }
1940            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
1941            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
1942            // casting to a bool - generate `OpXxxNotEqual`
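            // e.g. `bool(x)` for an i32 `x` becomes `OpINotEqual(x, 0)`; for
            // vectors, the zero operand is a splatted constant composite.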
1943            (_, Sk::Bool, Some(_)) => {
1944                let op = match src_scalar.kind {
1945                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
1946                    Sk::Float => spirv::Op::FUnordNotEqual,
1947                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
1948                };
1949                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
1950                let zero_id = match src_size {
1951                    Some(size) => {
1952                        let ty = LocalType::Numeric(NumericType::Vector {
1953                            size,
1954                            scalar: src_scalar,
1955                        })
1956                        .into();
1957
1958                        self.temp_list.clear();
1959                        self.temp_list.resize(size as _, zero_scalar_id);
1960
1961                        self.writer.get_constant_composite(ty, &self.temp_list)
1962                    }
1963                    None => zero_scalar_id,
1964                };
1965
1966                Cast::Binary(op, expr_id, zero_id)
1967            }
1968            // casting from a bool - generate `OpSelect`
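            // e.g. `u32(b)` becomes `OpSelect(b, 1u, 0u)`, with the constants
            // splatted for vector sources.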
1969            (Sk::Bool, _, Some(dst_width)) => {
1970                let dst_scalar = crate::Scalar {
1971                    kind,
1972                    width: dst_width,
1973                };
1974                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
1975                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
1976                let (accept_id, reject_id) = match src_size {
1977                    Some(size) => {
1978                        let ty = LocalType::Numeric(NumericType::Vector {
1979                            size,
1980                            scalar: dst_scalar,
1981                        })
1982                        .into();
1983
1984                        self.temp_list.clear();
1985                        self.temp_list.resize(size as _, zero_scalar_id);
1986
1987                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
1988
1989                        self.temp_list.fill(one_scalar_id);
1990
1991                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1992
1993                        (vec1_id, vec0_id)
1994                    }
1995                    None => (one_scalar_id, zero_scalar_id),
1996                };
1997
1998                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
1999            }
2000            // Avoid undefined behaviour when casting from a float to an integer
2001            // when the value is out of range for the target type. Additionally
2002            // ensure we clamp to the correct value as per the WGSL spec.
2003            //
2004            // https://www.w3.org/TR/WGSL/#floating-point-conversion:
2005            // * If X is exactly representable in the target type T, then the
2006            //   result is that value.
2007            // * Otherwise, the result is the value in T closest to
2008            //   truncate(X) and also exactly representable in the original
2009            //   floating point type.
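            //
            // A rough sketch of the code this emits for `i32(x)` with `x: f32`
            // (names illustrative, not verbatim output):
            //
            //     %clamped = FClamp %x %min %max   ; min/max: f32 values closest to
            //                                      ; the i32 range that convert safely
            //     %result  = OpConvertFToS %clamped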
2010            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2011                let dst_scalar = crate::Scalar { kind, width };
2012                let (min, max) =
2013                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2014                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2015
2016                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2017                    None => const_id,
2018                    Some(size) => {
2019                        let constituent_ids = [const_id; crate::VectorSize::MAX];
2020                        writer.get_constant_composite(
2021                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
2022                                size,
2023                                scalar: src_scalar,
2024                            })),
2025                            &constituent_ids[..size as usize],
2026                        )
2027                    }
2028                };
2029                let min_const_id = self.writer.get_constant_scalar(min);
2030                let min_const_id = maybe_splat_const(self.writer, min_const_id);
2031                let max_const_id = self.writer.get_constant_scalar(max);
2032                let max_const_id = maybe_splat_const(self.writer, max_const_id);
2033
2034                let clamp_id = self.gen_id();
2035                block.body.push(Instruction::ext_inst_gl_op(
2036                    self.writer.gl450_ext_inst_id,
2037                    spirv::GLOp::FClamp,
2038                    expr_type_id,
2039                    clamp_id,
2040                    &[expr_id, min_const_id, max_const_id],
2041                ));
2042
2043                let op = match dst_scalar.kind {
2044                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2045                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2046                    _ => unreachable!(),
2047                };
2048                Cast::Unary(op, clamp_id)
2049            }
2050            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2051                Cast::Unary(spirv::Op::FConvert, expr_id)
2052            }
2053            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2054            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2055                Cast::Unary(spirv::Op::SConvert, expr_id)
2056            }
2057            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2058            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2059                Cast::Unary(spirv::Op::UConvert, expr_id)
2060            }
2061            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2062                Cast::Unary(spirv::Op::SConvert, expr_id)
2063            }
2064            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2065                Cast::Unary(spirv::Op::UConvert, expr_id)
2066            }
2067            // We assume it's either an identity cast or an int<->uint bitcast of the same width.
2068            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2069        };
2070        Ok(match cast {
2071            Cast::Identity(expr) => expr,
2072            Cast::Unary(op, op1) => {
2073                let id = self.gen_id();
2074                block
2075                    .body
2076                    .push(Instruction::unary(op, result_type_id, id, op1));
2077                id
2078            }
2079            Cast::Binary(op, op1, op2) => {
2080                let id = self.gen_id();
2081                block
2082                    .body
2083                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
2084                id
2085            }
2086            Cast::Ternary(op, op1, op2, op3) => {
2087                let id = self.gen_id();
2088                block
2089                    .body
2090                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2091                id
2092            }
2093        })
2094    }
2095
2096    /// Build an `OpAccessChain` instruction.
2097    ///
2098    /// Emit any needed bounds-checking expressions to `block`.
2099    ///
2100    /// Give the `OpAccessChain` a result type based on `expr_handle`, adjusted
2101    /// according to `type_adjustment`; see the documentation for
2102    /// [`AccessTypeAdjustment`] for details.
2103    ///
2104    /// On success, the return value is an [`ExpressionPointer`] value; see the
2105    /// documentation for that type.
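    ///
    /// For example (a rough sketch, not verbatim output), a load from
    /// `arr[i].field` rooted in a local variable typically yields a single
    /// `OpAccessChain %ptr_type %result %local %clamped_i %field_index`, with
    /// the index operands collected in `temp_list` while walking from the
    /// outermost access expression down to the root, then reversed.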
2106    fn write_access_chain(
2107        &mut self,
2108        mut expr_handle: Handle<crate::Expression>,
2109        block: &mut Block,
2110        type_adjustment: AccessTypeAdjustment,
2111    ) -> Result<ExpressionPointer, Error> {
2112        let result_type_id = {
2113            let resolution = &self.fun_info[expr_handle].ty;
2114            match type_adjustment {
2115                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
2116                AccessTypeAdjustment::IntroducePointer(class) => {
2117                    self.writer.get_resolution_pointer_id(resolution, class)
2118                }
2119            }
2120        };
2121
2122        // The id of the boolean `and` of all dynamic bounds checks up to this point.
2123        //
2124        // See `extend_bounds_check_condition_chain` for a full explanation.
2125        let mut accumulated_checks = None;
2126
2127        // Is true if we are accessing into a binding array with a non-uniform index.
2128        // True if we are accessing into a binding array with a non-uniform index.
2129
2130        self.temp_list.clear();
2131        let root_id = loop {
2132            // If `expr_handle` was spilled, then the temporary variable has exactly
2133            // the value we want to start from.
2134            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
2135                // The root id of the `OpAccessChain` instruction is the temporary
2136                // variable we spilled the composite to.
2137                break spilled.id;
2138            }
2139
2140            expr_handle = match self.ir_function.expressions[expr_handle] {
2141                crate::Expression::Access { base, index } => {
2142                    is_non_uniform_binding_array |=
2143                        self.is_nonuniform_binding_array_access(base, index);
2144
2145                    let index = GuardedIndex::Expression(index);
2146                    let index_id =
2147                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
2148                    self.temp_list.push(index_id);
2149
2150                    base
2151                }
2152                crate::Expression::AccessIndex { base, index } => {
2153                    // Decide whether we're indexing a struct (bounds checks
2154                    // forbidden) or anything else (bounds checks required).
2155                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
2156                    if let crate::TypeInner::Pointer { base, .. } = *base_ty {
2157                        base_ty = &self.ir_module.types[base].inner;
2158                    }
2159                    let index_id = if let crate::TypeInner::Struct { .. } = *base_ty {
2160                        self.get_index_constant(index)
2161                    } else {
2162                        // `index` is constant, so this can't possibly require
2163                        // setting `is_nonuniform_binding_array_access`.
2164
2165                        // Even though the index value is statically known, `base`
2166                        // may be a runtime-sized array, so we still need to go
2167                        // through the bounds check process.
2168                        self.write_access_chain_index(
2169                            base,
2170                            GuardedIndex::Known(index),
2171                            &mut accumulated_checks,
2172                            block,
2173                        )?
2174                    };
2175
2176                    self.temp_list.push(index_id);
2177                    base
2178                }
2179                crate::Expression::GlobalVariable(handle) => {
2180                    let gv = &self.writer.global_variables[handle];
2181                    break gv.access_id;
2182                }
2183                crate::Expression::LocalVariable(variable) => {
2184                    let local_var = &self.function.variables[&variable];
2185                    break local_var.id;
2186                }
2187                crate::Expression::FunctionArgument(index) => {
2188                    break self.function.parameter_id(index);
2189                }
2190                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
2191            }
2192        };
2193
2194        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
2195            (
2196                root_id,
2197                ExpressionPointer::Ready {
2198                    pointer_id: root_id,
2199                },
2200            )
2201        } else {
2202            self.temp_list.reverse();
2203            let pointer_id = self.gen_id();
2204            let access =
2205                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
2206
2207            // If we generated some bounds checks, we need to leave it to our
2208            // caller to generate the branch, the access, the load or store, and
2209            // the zero value (for loads). Otherwise, we can emit the access
2210            // ourselves, and just hand them the id of the pointer.
2211            let expr_pointer = match accumulated_checks {
2212                Some(condition) => ExpressionPointer::Conditional { condition, access },
2213                None => {
2214                    block.body.push(access);
2215                    ExpressionPointer::Ready { pointer_id }
2216                }
2217            };
2218            (pointer_id, expr_pointer)
2219        };
2220        // Subsequent load, store, and atomic operations require the pointer to be
2221        // decorated as NonUniform if the binding array was accessed with a
2222        // non-uniform index; see VUID-RuntimeSpirv-NonUniform-06274.
2223        if is_non_uniform_binding_array {
2224            self.writer
2225                .decorate_non_uniform_binding_array_access(pointer_id)?;
2226        }
2227
2228        Ok(expr_pointer)
2229    }
2230
2231    fn is_nonuniform_binding_array_access(
2232        &mut self,
2233        base: Handle<crate::Expression>,
2234        index: Handle<crate::Expression>,
2235    ) -> bool {
2236        let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2237        else {
2238            return false;
2239        };
2240
2241        // The access chain needs to be decorated as NonUniform;
2242        // see VUID-RuntimeSpirv-NonUniform-06274.
2243        let gvar = &self.ir_module.global_variables[var_handle];
2244        let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2245            return false;
2246        };
2247
2248        self.fun_info[index].uniformity.non_uniform_result.is_some()
2249    }
2250
2251    /// Compute a single index operand to an `OpAccessChain` instruction.
2252    ///
2253    /// Given that we are indexing `base` with `index`, apply the appropriate
2254    /// bounds check policies, emitting code to `block` to clamp `index` or
2255    /// determine whether it's in bounds. Return the SPIR-V instruction id of
2256    /// the index value we should actually use.
2257    ///
2258    /// Extend `accumulated_checks` to include the results of any needed bounds
2259    /// checks. See [`BlockContext::extend_bounds_check_condition_chain`].
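    ///
    /// For instance, under the `Restrict` bounds-check policy this typically
    /// returns the id of the index clamped with `UMin`, while under
    /// `ReadZeroSkipWrite` it returns the original index id and records a
    /// boolean in-bounds comparison in `accumulated_checks`.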
2260    fn write_access_chain_index(
2261        &mut self,
2262        base: Handle<crate::Expression>,
2263        index: GuardedIndex,
2264        accumulated_checks: &mut Option<Word>,
2265        block: &mut Block,
2266    ) -> Result<Word, Error> {
2267        match self.write_bounds_check(base, index, block)? {
2268            BoundsCheckResult::KnownInBounds(known_index) => {
2269                // Even if the index is known, `OpAccessChain`
2270                // requires expression operands, not literals.
2271                let scalar = crate::Literal::U32(known_index);
2272                Ok(self.writer.get_constant_scalar(scalar))
2273            }
2274            BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2275            BoundsCheckResult::Conditional {
2276                condition_id: condition,
2277                index_id: index,
2278            } => {
2279                self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2280
2281                // Use the index from the `Access` expression unchanged.
2282                Ok(index)
2283            }
2284        }
2285    }
2286
2287    /// Add a condition to a chain of bounds checks.
2288    ///
2289    /// As we build an `OpAccessChain` instruction governed by
2290    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`], we accumulate a chain of
2291    /// dynamic bounds checks, one for each index in the chain, which must all
2292    /// be true for that `OpAccessChain`'s execution to be well-defined. This
2293    /// function adds the boolean instruction id `comparison_id` to `chain`.
2294    ///
2295    /// If `chain` is `None`, that means there are no bounds checks in the chain
2296    /// yet. If chain is `Some(id)`, then `id` is the conjunction of all the
2297    /// bounds checks in the chain.
2298    ///
2299    /// When we have multiple bounds checks, we combine them with
2300    /// `OpLogicalAnd`, not a short-circuit branch. This means we might do
2301    /// comparisons we don't need to, but we expect these checks to almost
2302    /// always succeed, and keeping branches to a minimum is essential.
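    ///
    /// For example, after two checked indices, `chain` holds a single boolean
    /// id equivalent to `i_in_bounds && j_in_bounds`, built up with one
    /// `OpLogicalAnd` per additional check.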
2303    ///
2304    /// [`BoundsCheckPolicy::ReadZeroSkipWrite`]: crate::proc::BoundsCheckPolicy
2305    fn extend_bounds_check_condition_chain(
2306        &mut self,
2307        chain: &mut Option<Word>,
2308        comparison_id: Word,
2309        block: &mut Block,
2310    ) {
2311        match *chain {
2312            Some(ref mut prior_checks) => {
2313                let combined = self.gen_id();
2314                block.body.push(Instruction::binary(
2315                    spirv::Op::LogicalAnd,
2316                    self.writer.get_bool_type_id(),
2317                    combined,
2318                    *prior_checks,
2319                    comparison_id,
2320                ));
2321                *prior_checks = combined;
2322            }
2323            None => {
2324                // Start a fresh chain of checks.
2325                *chain = Some(comparison_id);
2326            }
2327        }
2328    }
2329
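    /// Emit a load of the value that `pointer` points to.
    ///
    /// If the access chain for `pointer` is unconditionally available, emit a
    /// plain `OpLoad`, or an `OpAtomicLoad` when the pointee type is atomic.
    /// Otherwise, fall back to `write_conditional_indexed_load`, which guards
    /// the access with the accumulated bounds-check condition.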
2330    fn write_checked_load(
2331        &mut self,
2332        pointer: Handle<crate::Expression>,
2333        block: &mut Block,
2334        access_type_adjustment: AccessTypeAdjustment,
2335        result_type_id: Word,
2336    ) -> Result<Word, Error> {
2337        match self.write_access_chain(pointer, block, access_type_adjustment)? {
2338            ExpressionPointer::Ready { pointer_id } => {
2339                let id = self.gen_id();
2340                let atomic_space =
2341                    match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2342                        crate::TypeInner::Pointer { base, space } => {
2343                            match self.ir_module.types[base].inner {
2344                                crate::TypeInner::Atomic { .. } => Some(space),
2345                                _ => None,
2346                            }
2347                        }
2348                        _ => None,
2349                    };
2350                let instruction = if let Some(space) = atomic_space {
2351                    let (semantics, scope) = space.to_spirv_semantics_and_scope();
2352                    let scope_constant_id = self.get_scope_constant(scope as u32);
2353                    let semantics_id = self.get_index_constant(semantics.bits());
2354                    Instruction::atomic_load(
2355                        result_type_id,
2356                        id,
2357                        pointer_id,
2358                        scope_constant_id,
2359                        semantics_id,
2360                    )
2361                } else {
2362                    Instruction::load(result_type_id, id, pointer_id, None)
2363                };
2364                block.body.push(instruction);
2365                Ok(id)
2366            }
2367            ExpressionPointer::Conditional { condition, access } => {
2368                //TODO: support atomics?
2369                let value = self.write_conditional_indexed_load(
2370                    result_type_id,
2371                    condition,
2372                    block,
2373                    move |id_gen, block| {
2374                        // The in-bounds path. Perform the access and the load.
2375                        let pointer_id = access.result_id.unwrap();
2376                        let value_id = id_gen.next();
2377                        block.body.push(access);
2378                        block.body.push(Instruction::load(
2379                            result_type_id,
2380                            value_id,
2381                            pointer_id,
2382                            None,
2383                        ));
2384                        value_id
2385                    },
2386                );
2387                Ok(value)
2388            }
2389        }
2390    }
2391
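    /// Spill the value of `base` to a function-local temporary variable.
    ///
    /// Create a `Function`-storage-class spill variable for `base` if one
    /// doesn't exist yet, and in either case store `base`'s current cached
    /// value into it, so later accesses can index it with `OpAccessChain`.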
2392    fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2393        use indexmap::map::Entry;
2394
2395        // Make sure we have an internal variable to spill `base` to.
2396        let spill_variable_id = match self.function.spilled_composites.entry(base) {
2397            Entry::Occupied(preexisting) => preexisting.get().id,
2398            Entry::Vacant(vacant) => {
2399                // Generate a new internal variable of the appropriate
2400                // type for `base`.
2401                let pointer_type_id = self.writer.get_resolution_pointer_id(
2402                    &self.fun_info[base].ty,
2403                    spirv::StorageClass::Function,
2404                );
2405                let id = self.writer.id_gen.next();
2406                vacant.insert(super::LocalVariable {
2407                    id,
2408                    instruction: Instruction::variable(
2409                        pointer_type_id,
2410                        id,
2411                        spirv::StorageClass::Function,
2412                        None,
2413                    ),
2414                });
2415                id
2416            }
2417        };
2418
2419        // Perform the store even if we already had a spill variable for `base`.
2420        // Consider this code:
2421        //
2422        // var x = ...;
2423        // var y = ...;
2424        // var z = ...;
2425        // for (i = 0; i<2; i++) {
2426        //     let a = array(i, i, i);
2427        //     if (i == 0) {
2428        //         x += a[y];
2429        //     } else {
2430        //         x += a[z];
2431        //     }
2432        // }
2433        //
2434        // The value of `a` needs to be spilled so we can subscript it with `y` and `z`.
2435        //
2436        // When we generate SPIR-V for `a[y]`, we will create the spill
2437        // variable, and store `a`'s value in it.
2438        //
2439        // When we generate SPIR-V for `a[z]`, we will notice that the spill
2440        // variable for `a` has already been declared, but it is still essential
2441        // that we store `a` into it, so that `a[z]` sees this iteration's value
2442        // of `a`.
2443        let base_id = self.cached[base];
2444        block
2445            .body
2446            .push(Instruction::store(spill_variable_id, base_id, None));
2447    }
2448
2449    /// Generate an access to a spilled temporary, if necessary.
2450    ///
2451    /// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers
2452    /// to a component of a composite value that has been spilled to a temporary
2453    /// variable, determine whether other expressions are going to use
2454    /// `access`'s value:
2455    ///
2456    /// - If so, perform the access and cache that as the value of `access`.
2457    ///
2458    /// - Otherwise, generate no code and cache no value for `access`.
2459    ///
2460    /// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded the
2461    /// value, where `id` is the result id of the load instruction.
2462    ///
2463    /// [`Access`]: crate::Expression::Access
2464    /// [`AccessIndex`]: crate::Expression::AccessIndex
2465    fn maybe_access_spilled_composite(
2466        &mut self,
2467        access: Handle<crate::Expression>,
2468        block: &mut Block,
2469        result_type_id: Word,
2470    ) -> Result<Word, Error> {
2471        let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2472        if access_uses == self.fun_info[access].ref_count {
2473            // This expression is only used by other `Access` and
2474            // `AccessIndex` expressions, so we don't need to cache a
2475            // value for it yet.
2476            Ok(0)
2477        } else {
2478            // There are other expressions that are going to expect this
2479            // expression's value to be cached, not just other `Access` or
2480            // `AccessIndex` expressions. We must actually perform the
2481            // access on the spill variable now.
2482            self.write_checked_load(
2483                access,
2484                block,
2485                AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2486                result_type_id,
2487            )
2488        }
2489    }
2490
2491    /// Build the instructions for matrix-matrix column operations
2492    #[allow(clippy::too_many_arguments)]
2493    fn write_matrix_matrix_column_op(
2494        &mut self,
2495        block: &mut Block,
2496        result_id: Word,
2497        result_type_id: Word,
2498        left_id: Word,
2499        right_id: Word,
2500        columns: crate::VectorSize,
2501        rows: crate::VectorSize,
2502        width: u8,
2503        op: spirv::Op,
2504    ) {
2505        self.temp_list.clear();
2506
2507        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2508            size: rows,
2509            scalar: crate::Scalar::float(width),
2510        });
2511
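        // SPIR-V arithmetic instructions like `OpFAdd` operate on scalars and
        // vectors, not matrices, so work column by column: extract the matching
        // column from each operand, apply `op`, and rebuild the result matrix.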
2512        for index in 0..columns as u32 {
2513            let column_id_left = self.gen_id();
2514            let column_id_right = self.gen_id();
2515            let column_id_res = self.gen_id();
2516
2517            block.body.push(Instruction::composite_extract(
2518                vector_type_id,
2519                column_id_left,
2520                left_id,
2521                &[index],
2522            ));
2523            block.body.push(Instruction::composite_extract(
2524                vector_type_id,
2525                column_id_right,
2526                right_id,
2527                &[index],
2528            ));
2529            block.body.push(Instruction::binary(
2530                op,
2531                vector_type_id,
2532                column_id_res,
2533                column_id_left,
2534                column_id_right,
2535            ));
2536
2537            self.temp_list.push(column_id_res);
2538        }
2539
2540        block.body.push(Instruction::composite_construct(
2541            result_type_id,
2542            result_id,
2543            &self.temp_list,
2544        ));
2545    }
2546
2547    /// Build the instructions for vector-scalar multiplication
2548    fn write_vector_scalar_mult(
2549        &mut self,
2550        block: &mut Block,
2551        result_id: Word,
2552        result_type_id: Word,
2553        vector_id: Word,
2554        scalar_id: Word,
2555        vector: &crate::TypeInner,
2556    ) {
2557        let (size, kind) = match *vector {
2558            crate::TypeInner::Vector {
2559                size,
2560                scalar: crate::Scalar { kind, .. },
2561            } => (size, kind),
2562            _ => unreachable!(),
2563        };
2564
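        // `OpVectorTimesScalar` only accepts floating-point operands, so for
        // integer vectors we splat the scalar into a vector and multiply
        // component-wise with `OpIMul` instead.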
2565        let (op, operand_id) = match kind {
2566            crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
2567            _ => {
2568                let operand_id = self.gen_id();
2569                self.temp_list.clear();
2570                self.temp_list.resize(size as usize, scalar_id);
2571                block.body.push(Instruction::composite_construct(
2572                    result_type_id,
2573                    operand_id,
2574                    &self.temp_list,
2575                ));
2576                (spirv::Op::IMul, operand_id)
2577            }
2578        };
2579
2580        block.body.push(Instruction::binary(
2581            op,
2582            result_type_id,
2583            result_id,
2584            vector_id,
2585            operand_id,
2586        ));
2587    }
2588
2589    /// Build the instructions for the arithmetic expression of a dot product
2590    ///
2591    /// The argument `extractor` is a function that maps `(result_id,
2592    /// composite_id, index)` to an instruction that extracts the `index`th
2593    /// entry of the value with ID `composite_id` and assigns it the result
2594    /// id `result_id` (which must have type `result_type_id`).
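    ///
    /// For example, a caller extracting vector components might pass a closure
    /// along these lines (a sketch, not the exact closure used by callers of
    /// this function):
    ///
    /// ```ignore
    /// |result_id, composite_id, index| {
    ///     Instruction::composite_extract(result_type_id, result_id, composite_id, &[index])
    /// }
    /// ```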
2595    #[expect(clippy::too_many_arguments)]
2596    fn write_dot_product(
2597        &mut self,
2598        result_id: Word,
2599        result_type_id: Word,
2600        arg0_id: Word,
2601        arg1_id: Word,
2602        size: u32,
2603        block: &mut Block,
2604        extractor: impl Fn(Word, Word, Word) -> Instruction,
2605    ) {
2606        let mut partial_sum = self.writer.get_constant_null(result_type_id);
2607        let last_component = size - 1;
2608        for index in 0..=last_component {
2609            // compute the product of the current components
2610            let a_id = self.gen_id();
2611            block.body.push(extractor(a_id, arg0_id, index));
2612            let b_id = self.gen_id();
2613            block.body.push(extractor(b_id, arg1_id, index));
2614            let prod_id = self.gen_id();
2615            block.body.push(Instruction::binary(
2616                spirv::Op::IMul,
2617                result_type_id,
2618                prod_id,
2619                a_id,
2620                b_id,
2621            ));
2622
2623            // choose the id for the next sum, depending on current index
2624            let id = if index == last_component {
2625                result_id
2626            } else {
2627                self.gen_id()
2628            };
2629
2630            // sum the computed product with the partial sum
2631            block.body.push(Instruction::binary(
2632                spirv::Op::IAdd,
2633                result_type_id,
2634                id,
2635                partial_sum,
2636                prod_id,
2637            ));
2638            // set the id of the result as the previous partial sum
2639            partial_sum = id;
2640        }
2641    }
2642
2643    /// Emit code for `pack4x{I,U}8[Clamp]` when the `Int8` capability is available.
2644    fn write_pack4x8_optimized(
2645        &mut self,
2646        block: &mut Block,
2647        result_type_id: u32,
2648        arg0_id: u32,
2649        id: u32,
2650        is_signed: bool,
2651        should_clamp: bool,
2652    ) -> Instruction {
2653        let int_type = if is_signed {
2654            crate::ScalarKind::Sint
2655        } else {
2656            crate::ScalarKind::Uint
2657        };
2658        let wide_vector_type = NumericType::Vector {
2659            size: crate::VectorSize::Quad,
2660            scalar: crate::Scalar {
2661                kind: int_type,
2662                width: 4,
2663            },
2664        };
2665        let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
2666        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2667            size: crate::VectorSize::Quad,
2668            scalar: crate::Scalar {
2669                kind: crate::ScalarKind::Uint,
2670                width: 1,
2671            },
2672        });
2673
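        // Strategy: optionally clamp each component to the 8-bit range while it
        // is still 32 bits wide, narrow the vector to 8-bit components with
        // `OpUConvert`, and finally bitcast the four-byte vector to a `u32`.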
2674        let mut wide_vector = arg0_id;
2675        if should_clamp {
2676            let (min, max, clamp_op) = if is_signed {
2677                (
2678                    crate::Literal::I32(-128),
2679                    crate::Literal::I32(127),
2680                    spirv::GLOp::SClamp,
2681                )
2682            } else {
2683                (
2684                    crate::Literal::U32(0),
2685                    crate::Literal::U32(255),
2686                    spirv::GLOp::UClamp,
2687                )
2688            };
2689            let [min, max] = [min, max].map(|lit| {
2690                let scalar = self.writer.get_constant_scalar(lit);
2691                self.writer.get_constant_composite(
2692                    LookupType::Local(LocalType::Numeric(wide_vector_type)),
2693                    &[scalar; 4],
2694                )
2695            });
2696
2697            let clamp_id = self.gen_id();
2698            block.body.push(Instruction::ext_inst_gl_op(
2699                self.writer.gl450_ext_inst_id,
2700                clamp_op,
2701                wide_vector_type_id,
2702                clamp_id,
2703                &[wide_vector, min, max],
2704            ));
2705
2706            wide_vector = clamp_id;
2707        }
2708
2709        let packed_vector = self.gen_id();
2710        block.body.push(Instruction::unary(
2711            spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically.
2712            packed_vector_type_id,
2713            packed_vector,
2714            wide_vector,
2715        ));
2716
2717        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
2718        // and a scalar precisely as required by the WGSL spec [2].
2719        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
2720        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
2721        Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
2722    }
2723
2724    /// Emit code for `pack4x{I,U}8[Clamp]` when the `Int8` capability is not available.
2725    fn write_pack4x8_polyfill(
2726        &mut self,
2727        block: &mut Block,
2728        result_type_id: u32,
2729        arg0_id: u32,
2730        id: u32,
2731        is_signed: bool,
2732        should_clamp: bool,
2733    ) -> Instruction {
2734        let int_type = if is_signed {
2735            crate::ScalarKind::Sint
2736        } else {
2737            crate::ScalarKind::Uint
2738        };
2739        let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
2740        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2741            kind: int_type,
2742            width: 4,
2743        }));
2744
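        // Without `Int8`, stay in 32-bit arithmetic: extract each component,
        // optionally clamp it, and insert its low 8 bits into the packed result
        // with `OpBitFieldInsert`.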
2745        let mut last_instruction = Instruction::new(spirv::Op::Nop);
2746
2747        let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
2748        let mut preresult = zero;
2749        block
2750            .body
2751            .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
2752
2753        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2754        const VEC_LENGTH: u8 = 4;
2755        for i in 0..u32::from(VEC_LENGTH) {
2756            let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
2757            let mut extracted = self.gen_id();
2758            block.body.push(Instruction::binary(
2759                spirv::Op::CompositeExtract,
2760                int_type_id,
2761                extracted,
2762                arg0_id,
2763                i,
2764            ));
2765            if is_signed {
2766                let casted = self.gen_id();
2767                block.body.push(Instruction::unary(
2768                    spirv::Op::Bitcast,
2769                    uint_type_id,
2770                    casted,
2771                    extracted,
2772                ));
2773                extracted = casted;
2774            }
2775            if should_clamp {
2776                let (min, max, clamp_op) = if is_signed {
2777                    (
2778                        crate::Literal::I32(-128),
2779                        crate::Literal::I32(127),
2780                        spirv::GLOp::SClamp,
2781                    )
2782                } else {
2783                    (
2784                        crate::Literal::U32(0),
2785                        crate::Literal::U32(255),
2786                        spirv::GLOp::UClamp,
2787                    )
2788                };
2789                let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
2790
2791                let clamp_id = self.gen_id();
2792                block.body.push(Instruction::ext_inst_gl_op(
2793                    self.writer.gl450_ext_inst_id,
2794                    clamp_op,
2795                    result_type_id,
2796                    clamp_id,
2797                    &[extracted, min, max],
2798                ));
2799
2800                extracted = clamp_id;
2801            }
2802            let is_last = i == u32::from(VEC_LENGTH - 1);
2803            if is_last {
2804                last_instruction = Instruction::quaternary(
2805                    spirv::Op::BitFieldInsert,
2806                    result_type_id,
2807                    id,
2808                    preresult,
2809                    extracted,
2810                    offset,
2811                    eight,
2812                )
2813            } else {
2814                let new_preresult = self.gen_id();
2815                block.body.push(Instruction::quaternary(
2816                    spirv::Op::BitFieldInsert,
2817                    result_type_id,
2818                    new_preresult,
2819                    preresult,
2820                    extracted,
2821                    offset,
2822                    eight,
2823                ));
2824                preresult = new_preresult;
2825            }
2826        }
2827        last_instruction
2828    }
2829
2830    /// Emit code for `unpack4x{I,U}8` when the `Int8` capability is available.
2831    fn write_unpack4x8_optimized(
2832        &mut self,
2833        block: &mut Block,
2834        result_type_id: u32,
2835        arg0_id: u32,
2836        id: u32,
2837        is_signed: bool,
2838    ) -> Instruction {
2839        let (int_type, convert_op) = if is_signed {
2840            (crate::ScalarKind::Sint, spirv::Op::SConvert)
2841        } else {
2842            (crate::ScalarKind::Uint, spirv::Op::UConvert)
2843        };
2844
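        // Bitcast the packed `u32` into a vector of four 8-bit components, then
        // widen each component to 32 bits with `OpSConvert`/`OpUConvert`.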
2845        let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2846            size: crate::VectorSize::Quad,
2847            scalar: crate::Scalar {
2848                kind: int_type,
2849                width: 1,
2850            },
2851        });
2852
2853        // The SPIR-V spec [1] defines the bit order for bit casting between a vector
2854        // and a scalar precisely as required by the WGSL spec [2].
2855        // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
2856        // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
2857        let packed_vector = self.gen_id();
2858        block.body.push(Instruction::unary(
2859            spirv::Op::Bitcast,
2860            packed_vector_type_id,
2861            packed_vector,
2862            arg0_id,
2863        ));
2864
2865        Instruction::unary(convert_op, result_type_id, id, packed_vector)
2866    }
2867
2868    /// Emit code for `unpack4x{I,U}8` when the `Int8` capability is not available.
2869    fn write_unpack4x8_polyfill(
2870        &mut self,
2871        block: &mut Block,
2872        result_type_id: u32,
2873        arg0_id: u32,
2874        id: u32,
2875        is_signed: bool,
2876    ) -> Instruction {
2877        let (int_type, extract_op) = if is_signed {
2878            (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
2879        } else {
2880            (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
2881        };
2882
2883        let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
2884
2885        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2886        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2887            kind: int_type,
2888            width: 4,
2889        }));
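        // Without `Int8`, extract each 8-bit field straight into a 32-bit
        // component with `OpBitFieldSExtract`/`OpBitFieldUExtract`, then
        // reassemble the vector with `OpCompositeConstruct`.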
2890        block
2891            .body
2892            .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
2893        let arg_id = if is_signed {
2894            let new_arg_id = self.gen_id();
2895            block.body.push(Instruction::unary(
2896                spirv::Op::Bitcast,
2897                sint_type_id,
2898                new_arg_id,
2899                arg0_id,
2900            ));
2901            new_arg_id
2902        } else {
2903            arg0_id
2904        };
2905
2906        const VEC_LENGTH: u8 = 4;
2907        let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
2908        for (i, part_id) in parts.into_iter().enumerate() {
2909            let index = self
2910                .writer
2911                .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
2912            block.body.push(Instruction::ternary(
2913                extract_op,
2914                int_type_id,
2915                part_id,
2916                arg_id,
2917                index,
2918                eight,
2919            ));
2920        }
2921
2922        Instruction::composite_construct(result_type_id, id, &parts)
2923    }
2924
2925    /// Generate one or more SPIR-V blocks for `naga_block`.
2926    ///
2927    /// Use `label_id` as the label for the SPIR-V entry point block.
2928    ///
2929    /// If control reaches the end of the SPIR-V block, terminate it according
2930    /// to `exit`. This function's return value indicates whether it acted on
2931    /// this parameter or not; see [`BlockExitDisposition`].
2932    ///
2933    /// If the block contains [`Break`] or [`Continue`] statements,
2934    /// `loop_context` supplies the labels of the SPIR-V blocks to jump to. If
2935    /// either of these labels is `None`, then it should have been a Naga
2936    /// validation error for the corresponding statement to occur in this
2937    /// context.
2938    ///
2939    /// [`Break`]: Statement::Break
2940    /// [`Continue`]: Statement::Continue
2941    fn write_block(
2942        &mut self,
2943        label_id: Word,
2944        naga_block: &crate::Block,
2945        exit: BlockExit,
2946        loop_context: LoopContext,
2947        debug_info: Option<&DebugInfoInner>,
2948    ) -> Result<BlockExitDisposition, Error> {
2949        let mut block = Block::new(label_id);
2950        for (statement, span) in naga_block.span_iter() {
2951            if let (Some(debug_info), false) = (
2952                debug_info,
2953                matches!(
2954                    statement,
2955                    &(Statement::Block(..)
2956                        | Statement::Break
2957                        | Statement::Continue
2958                        | Statement::Kill
2959                        | Statement::Return { .. }
2960                        | Statement::Loop { .. })
2961                ),
2962            ) {
2963                let loc: crate::SourceLocation = span.location(debug_info.source_code);
2964                block.body.push(Instruction::line(
2965                    debug_info.source_file_id,
2966                    loc.line_number,
2967                    loc.line_position,
2968                ));
2969            };
2970            match *statement {
2971                Statement::Emit(ref range) => {
2972                    for handle in range.clone() {
2973                        // omit const expressions as we've already cached those
2974                        if !self.expression_constness.is_const(handle) {
2975                            self.cache_expression_value(handle, &mut block)?;
2976                        }
2977                    }
2978                }
2979                Statement::Block(ref block_statements) => {
2980                    let scope_id = self.gen_id();
2981                    self.function.consume(block, Instruction::branch(scope_id));
2982
2983                    let merge_id = self.gen_id();
2984                    let merge_used = self.write_block(
2985                        scope_id,
2986                        block_statements,
2987                        BlockExit::Branch { target: merge_id },
2988                        loop_context,
2989                        debug_info,
2990                    )?;
2991
2992                    match merge_used {
2993                        BlockExitDisposition::Used => {
2994                            block = Block::new(merge_id);
2995                        }
2996                        BlockExitDisposition::Discarded => {
2997                            return Ok(BlockExitDisposition::Discarded);
2998                        }
2999                    }
3000                }
3001                Statement::If {
3002                    condition,
3003                    ref accept,
3004                    ref reject,
3005                } => {
3006                    // In SPIR-V 1.6, the two target labels of a conditional
3007                    // branch must not be the same. If `accept` and `reject` are
3008                    // both empty (e.g. in `if (condition) {}`), the merge id
3009                    // would end up as both labels. Since both branches are
3010                    // empty, we can skip the `if` entirely.
3011                    if !(accept.is_empty() && reject.is_empty()) {
3012                        let condition_id = self.cached[condition];
3013
3014                        let merge_id = self.gen_id();
3015                        block.body.push(Instruction::selection_merge(
3016                            merge_id,
3017                            spirv::SelectionControl::NONE,
3018                        ));
3019
3020                        let accept_id = if accept.is_empty() {
3021                            None
3022                        } else {
3023                            Some(self.gen_id())
3024                        };
3025                        let reject_id = if reject.is_empty() {
3026                            None
3027                        } else {
3028                            Some(self.gen_id())
3029                        };
3030
3031                        self.function.consume(
3032                            block,
3033                            Instruction::branch_conditional(
3034                                condition_id,
3035                                accept_id.unwrap_or(merge_id),
3036                                reject_id.unwrap_or(merge_id),
3037                            ),
3038                        );
3039
3040                        if let Some(block_id) = accept_id {
3041                            // We can ignore the `BlockExitDisposition` returned here because,
3042                            // even if `merge_id` is not actually reachable, it is always
3043                            // referred to by the `OpSelectionMerge` instruction we emitted
3044                            // earlier.
3045                            let _ = self.write_block(
3046                                block_id,
3047                                accept,
3048                                BlockExit::Branch { target: merge_id },
3049                                loop_context,
3050                                debug_info,
3051                            )?;
3052                        }
3053                        if let Some(block_id) = reject_id {
3054                            // We can ignore the `BlockExitDisposition` returned here because,
3055                            // even if `merge_id` is not actually reachable, it is always
3056                            // referred to by the `OpSelectionMerge` instruction we emitted
3057                            // earlier.
3058                            let _ = self.write_block(
3059                                block_id,
3060                                reject,
3061                                BlockExit::Branch { target: merge_id },
3062                                loop_context,
3063                                debug_info,
3064                            )?;
3065                        }
3066
3067                        block = Block::new(merge_id);
3068                    }
3069                }
3070                Statement::Switch {
3071                    selector,
3072                    ref cases,
3073                } => {
3074                    let selector_id = self.cached[selector];
3075
3076                    let merge_id = self.gen_id();
3077                    block.body.push(Instruction::selection_merge(
3078                        merge_id,
3079                        spirv::SelectionControl::NONE,
3080                    ));
3081
3082                    let mut default_id = None;
3083                    // id of previous empty fall-through case
3084                    let mut last_id = None;
3085
3086                    let mut raw_cases = Vec::with_capacity(cases.len());
3087                    let mut case_ids = Vec::with_capacity(cases.len());
3088                    for case in cases.iter() {
3089                        // take id of previous empty fall-through case or generate a new one
3090                        let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3091
3092                        if case.fall_through && case.body.is_empty() {
3093                            last_id = Some(label_id);
3094                        }
3095
3096                        case_ids.push(label_id);
3097
3098                        match case.value {
3099                            crate::SwitchValue::I32(value) => {
3100                                raw_cases.push(super::instructions::Case {
3101                                    value: value as Word,
3102                                    label_id,
3103                                });
3104                            }
3105                            crate::SwitchValue::U32(value) => {
3106                                raw_cases.push(super::instructions::Case { value, label_id });
3107                            }
3108                            crate::SwitchValue::Default => {
3109                                default_id = Some(label_id);
3110                            }
3111                        }
3112                    }
3113
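                    // One of the cases must be `SwitchValue::Default`, so
                    // `default_id` has been set by the loop above.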
3114                    let default_id = default_id.unwrap();
3115
3116                    self.function.consume(
3117                        block,
3118                        Instruction::switch(selector_id, default_id, &raw_cases),
3119                    );
3120
3121                    let inner_context = LoopContext {
3122                        break_id: Some(merge_id),
3123                        ..loop_context
3124                    };
3125
3126                    for (i, (case, label_id)) in cases
3127                        .iter()
3128                        .zip(case_ids.iter())
3129                        .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3130                        .enumerate()
3131                    {
3132                        let case_finish_id = if case.fall_through {
3133                            case_ids[i + 1]
3134                        } else {
3135                            merge_id
3136                        };
3137                        // We can ignore the `BlockExitDisposition` returned here because
3138                        // `case_finish_id` is always referred to by either:
3139                        //
3140                        // - the `OpSwitch`, if it's the next case's label for a
3141                        //   fall-through, or
3142                        //
3143                        // - the `OpSelectionMerge`, if it's the switch's overall merge
3144                        //   block because there's no fall-through.
3145                        let _ = self.write_block(
3146                            *label_id,
3147                            &case.body,
3148                            BlockExit::Branch {
3149                                target: case_finish_id,
3150                            },
3151                            inner_context,
3152                            debug_info,
3153                        )?;
3154                    }
3155
3156                    block = Block::new(merge_id);
3157                }
3158                Statement::Loop {
3159                    ref body,
3160                    ref continuing,
3161                    break_if,
3162                } => {
3163                    let preamble_id = self.gen_id();
3164                    self.function
3165                        .consume(block, Instruction::branch(preamble_id));
3166
3167                    let merge_id = self.gen_id();
3168                    let body_id = self.gen_id();
3169                    let continuing_id = self.gen_id();
3170
3171                    // SPIR-V requires the continuing block to be named by the
3172                    // `OpLoopMerge` in the loop header, so we start a new block with it.
3173                    block = Block::new(preamble_id);
3174                    // HACK: the loop statement begins with a branch instruction, so we
3175                    // need to emit the `OpLine` debug info before the merge instruction.
3176                    if let Some(debug_info) = debug_info {
3177                        let loc: crate::SourceLocation = span.location(debug_info.source_code);
3178                        block.body.push(Instruction::line(
3179                            debug_info.source_file_id,
3180                            loc.line_number,
3181                            loc.line_position,
3182                        ))
3183                    }
3184                    block.body.push(Instruction::loop_merge(
3185                        merge_id,
3186                        continuing_id,
3187                        spirv::SelectionControl::NONE,
3188                    ));
3189
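                    // If requested, emit extra instructions that impose a finite
                    // bound on the number of loop iterations, as a defense
                    // against loops that never terminate.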
3190                    if self.force_loop_bounding {
3191                        block = self.write_force_bounded_loop_instructions(block, merge_id);
3192                    }
3193                    self.function.consume(block, Instruction::branch(body_id));
3194
3195                    // We can ignore the `BlockExitDisposition` returned here because,
3196                    // even if `continuing_id` is not actually reachable, it is always
3197                    // referred to by the `OpLoopMerge` instruction we emitted earlier.
3198                    let _ = self.write_block(
3199                        body_id,
3200                        body,
3201                        BlockExit::Branch {
3202                            target: continuing_id,
3203                        },
3204                        LoopContext {
3205                            continuing_id: Some(continuing_id),
3206                            break_id: Some(merge_id),
3207                        },
3208                        debug_info,
3209                    )?;
3210
3211                    let exit = match break_if {
3212                        Some(condition) => BlockExit::BreakIf {
3213                            condition,
3214                            preamble_id,
3215                        },
3216                        None => BlockExit::Branch {
3217                            target: preamble_id,
3218                        },
3219                    };
3220
3221                    // We can ignore the `BlockExitDisposition` returned here because,
3222                    // even if `merge_id` is not actually reachable, it is always referred
3223                    // to by the `OpLoopMerge` instruction we emitted earlier.
3224                    let _ = self.write_block(
3225                        continuing_id,
3226                        continuing,
3227                        exit,
3228                        LoopContext {
3229                            continuing_id: None,
3230                            break_id: Some(merge_id),
3231                        },
3232                        debug_info,
3233                    )?;
3234
3235                    block = Block::new(merge_id);
3236                }
3237                Statement::Break => {
3238                    self.function
3239                        .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3240                    return Ok(BlockExitDisposition::Discarded);
3241                }
3242                Statement::Continue => {
3243                    self.function.consume(
3244                        block,
3245                        Instruction::branch(loop_context.continuing_id.unwrap()),
3246                    );
3247                    return Ok(BlockExitDisposition::Discarded);
3248                }
3249                Statement::Return { value: Some(value) } => {
3250                    let value_id = self.cached[value];
3251                    let instruction = match self.function.entry_point_context {
3252                        // If this is an entry point, and we need to return anything,
3253                        // let's instead store the output variables and return `void`.
3254                        Some(ref context) => {
3255                            self.writer.write_entry_point_return(
3256                                value_id,
3257                                self.ir_function.result.as_ref().unwrap(),
3258                                &context.results,
3259                                &mut block.body,
3260                            )?;
3261                            Instruction::return_void()
3262                        }
3263                        None => Instruction::return_value(value_id),
3264                    };
3265                    self.function.consume(block, instruction);
3266                    return Ok(BlockExitDisposition::Discarded);
3267                }
3268                Statement::Return { value: None } => {
3269                    self.function.consume(block, Instruction::return_void());
3270                    return Ok(BlockExitDisposition::Discarded);
3271                }
3272                Statement::Kill => {
3273                    self.function.consume(block, Instruction::kill());
3274                    return Ok(BlockExitDisposition::Discarded);
3275                }
3276                Statement::ControlBarrier(flags) => {
3277                    self.writer.write_control_barrier(flags, &mut block);
3278                }
3279                Statement::MemoryBarrier(flags) => {
3280                    self.writer.write_memory_barrier(flags, &mut block);
3281                }
3282                Statement::Store { pointer, value } => {
3283                    let value_id = self.cached[value];
3284                    match self.write_access_chain(
3285                        pointer,
3286                        &mut block,
3287                        AccessTypeAdjustment::None,
3288                    )? {
3289                        ExpressionPointer::Ready { pointer_id } => {
3290                            let atomic_space = match *self.fun_info[pointer]
3291                                .ty
3292                                .inner_with(&self.ir_module.types)
3293                            {
3294                                crate::TypeInner::Pointer { base, space } => {
3295                                    match self.ir_module.types[base].inner {
3296                                        crate::TypeInner::Atomic { .. } => Some(space),
3297                                        _ => None,
3298                                    }
3299                                }
3300                                _ => None,
3301                            };
3302                            let instruction = if let Some(space) = atomic_space {
3303                                let (semantics, scope) = space.to_spirv_semantics_and_scope();
3304                                let scope_constant_id = self.get_scope_constant(scope as u32);
3305                                let semantics_id = self.get_index_constant(semantics.bits());
3306                                Instruction::atomic_store(
3307                                    pointer_id,
3308                                    scope_constant_id,
3309                                    semantics_id,
3310                                    value_id,
3311                                )
3312                            } else {
3313                                Instruction::store(pointer_id, value_id, None)
3314                            };
3315                            block.body.push(instruction);
3316                        }
3317                        ExpressionPointer::Conditional { condition, access } => {
3318                            let mut selection = Selection::start(&mut block, ());
3319                            selection.if_true(self, condition, ());
3320
3321                            // The in-bounds path. Perform the access and the store.
3322                            let pointer_id = access.result_id.unwrap();
3323                            selection.block().body.push(access);
3324                            selection
3325                                .block()
3326                                .body
3327                                .push(Instruction::store(pointer_id, value_id, None));
3328
3329                            // Finish the in-bounds block and start the merge block. This
3330                            // is the block we'll leave current on return.
3331                            selection.finish(self, ());
3332                        }
3333                    };
3334                }
3335                Statement::ImageStore {
3336                    image,
3337                    coordinate,
3338                    array_index,
3339                    value,
3340                } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3341                Statement::Call {
3342                    function: local_function,
3343                    ref arguments,
3344                    result,
3345                } => {
3346                    let id = self.gen_id();
3347                    self.temp_list.clear();
3348                    for &argument in arguments {
3349                        self.temp_list.push(self.cached[argument]);
3350                    }
3351
3352                    let type_id = match result {
3353                        Some(expr) => {
3354                            self.cached[expr] = id;
3355                            self.get_expression_type_id(&self.fun_info[expr].ty)
3356                        }
3357                        None => self.writer.void_type,
3358                    };
3359
3360                    block.body.push(Instruction::function_call(
3361                        type_id,
3362                        id,
3363                        self.writer.lookup_function[&local_function],
3364                        &self.temp_list,
3365                    ));
3366                }
3367                Statement::Atomic {
3368                    pointer,
3369                    ref fun,
3370                    value,
3371                    result,
3372                } => {
3373                    let id = self.gen_id();
3374                    // Compare-and-exchange operations produce a struct result,
3375                    // so use `result`'s type if it is available. For no-result
3376                    // operations, fall back to `value`'s type.
3377                    let result_type_id =
3378                        self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3379
3380                    if let Some(result) = result {
3381                        self.cached[result] = id;
3382                    }
3383
3384                    let pointer_id = match self.write_access_chain(
3385                        pointer,
3386                        &mut block,
3387                        AccessTypeAdjustment::None,
3388                    )? {
3389                        ExpressionPointer::Ready { pointer_id } => pointer_id,
3390                        ExpressionPointer::Conditional { .. } => {
3391                            return Err(Error::FeatureNotImplemented(
3392                                "Atomics out-of-bounds handling",
3393                            ));
3394                        }
3395                    };
3396
3397                    let space = self.fun_info[pointer]
3398                        .ty
3399                        .inner_with(&self.ir_module.types)
3400                        .pointer_space()
3401                        .unwrap();
3402                    let (semantics, scope) = space.to_spirv_semantics_and_scope();
3403                    let scope_constant_id = self.get_scope_constant(scope as u32);
3404                    let semantics_id = self.get_index_constant(semantics.bits());
3405                    let value_id = self.cached[value];
3406                    let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3407
3408                    let crate::TypeInner::Scalar(scalar) = *value_inner else {
3409                        return Err(Error::FeatureNotImplemented(
3410                            "Atomics with non-scalar values",
3411                        ));
3412                    };
3413
3414                    let instruction = match *fun {
3415                        crate::AtomicFunction::Add => {
3416                            let spirv_op = match scalar.kind {
3417                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3418                                    spirv::Op::AtomicIAdd
3419                                }
3420                                crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3421                                _ => unimplemented!(),
3422                            };
3423                            Instruction::atomic_binary(
3424                                spirv_op,
3425                                result_type_id,
3426                                id,
3427                                pointer_id,
3428                                scope_constant_id,
3429                                semantics_id,
3430                                value_id,
3431                            )
3432                        }
3433                        crate::AtomicFunction::Subtract => {
3434                            let (spirv_op, value_id) = match scalar.kind {
3435                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3436                                    (spirv::Op::AtomicISub, value_id)
3437                                }
3438                                crate::ScalarKind::Float => {
3439                                    // HACK: SPIR-V doesn't have an atomic subtraction,
3440                                    // so we add the negated value instead.
3441                                    let neg_result_id = self.gen_id();
3442                                    block.body.push(Instruction::unary(
3443                                        spirv::Op::FNegate,
3444                                        result_type_id,
3445                                        neg_result_id,
3446                                        value_id,
3447                                    ));
3448                                    (spirv::Op::AtomicFAddEXT, neg_result_id)
3449                                }
3450                                _ => unimplemented!(),
3451                            };
3452                            Instruction::atomic_binary(
3453                                spirv_op,
3454                                result_type_id,
3455                                id,
3456                                pointer_id,
3457                                scope_constant_id,
3458                                semantics_id,
3459                                value_id,
3460                            )
3461                        }
3462                        crate::AtomicFunction::And => {
3463                            let spirv_op = match scalar.kind {
3464                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3465                                    spirv::Op::AtomicAnd
3466                                }
3467                                _ => unimplemented!(),
3468                            };
3469                            Instruction::atomic_binary(
3470                                spirv_op,
3471                                result_type_id,
3472                                id,
3473                                pointer_id,
3474                                scope_constant_id,
3475                                semantics_id,
3476                                value_id,
3477                            )
3478                        }
3479                        crate::AtomicFunction::InclusiveOr => {
3480                            let spirv_op = match scalar.kind {
3481                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3482                                    spirv::Op::AtomicOr
3483                                }
3484                                _ => unimplemented!(),
3485                            };
3486                            Instruction::atomic_binary(
3487                                spirv_op,
3488                                result_type_id,
3489                                id,
3490                                pointer_id,
3491                                scope_constant_id,
3492                                semantics_id,
3493                                value_id,
3494                            )
3495                        }
3496                        crate::AtomicFunction::ExclusiveOr => {
3497                            let spirv_op = match scalar.kind {
3498                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3499                                    spirv::Op::AtomicXor
3500                                }
3501                                _ => unimplemented!(),
3502                            };
3503                            Instruction::atomic_binary(
3504                                spirv_op,
3505                                result_type_id,
3506                                id,
3507                                pointer_id,
3508                                scope_constant_id,
3509                                semantics_id,
3510                                value_id,
3511                            )
3512                        }
3513                        crate::AtomicFunction::Min => {
3514                            let spirv_op = match scalar.kind {
3515                                crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
3516                                crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
3517                                _ => unimplemented!(),
3518                            };
3519                            Instruction::atomic_binary(
3520                                spirv_op,
3521                                result_type_id,
3522                                id,
3523                                pointer_id,
3524                                scope_constant_id,
3525                                semantics_id,
3526                                value_id,
3527                            )
3528                        }
3529                        crate::AtomicFunction::Max => {
3530                            let spirv_op = match scalar.kind {
3531                                crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
3532                                crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
3533                                _ => unimplemented!(),
3534                            };
3535                            Instruction::atomic_binary(
3536                                spirv_op,
3537                                result_type_id,
3538                                id,
3539                                pointer_id,
3540                                scope_constant_id,
3541                                semantics_id,
3542                                value_id,
3543                            )
3544                        }
3545                        crate::AtomicFunction::Exchange { compare: None } => {
3546                            Instruction::atomic_binary(
3547                                spirv::Op::AtomicExchange,
3548                                result_type_id,
3549                                id,
3550                                pointer_id,
3551                                scope_constant_id,
3552                                semantics_id,
3553                                value_id,
3554                            )
3555                        }
                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
                            let scalar_type_id =
                                self.get_numeric_type_id(NumericType::Scalar(scalar));
                            let bool_type_id =
                                self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));

                            let cas_result_id = self.gen_id();
                            let equality_result_id = self.gen_id();
                            let equality_operator = match scalar.kind {
                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
                                    spirv::Op::IEqual
                                }
                                _ => unimplemented!(),
                            };

                            let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
                            cas_instr.set_type(scalar_type_id);
                            cas_instr.set_result(cas_result_id);
                            cas_instr.add_operand(pointer_id);
                            cas_instr.add_operand(scope_constant_id);
                            cas_instr.add_operand(semantics_id); // semantics if equal
                            cas_instr.add_operand(semantics_id); // semantics if not equal
                            cas_instr.add_operand(value_id);
                            cas_instr.add_operand(self.cached[cmp]);
                            block.body.push(cas_instr);
                            block.body.push(Instruction::binary(
                                equality_operator,
                                bool_type_id,
                                equality_result_id,
                                cas_result_id,
                                self.cached[cmp],
                            ));
                            Instruction::composite_construct(
                                result_type_id,
                                id,
                                &[cas_result_id, equality_result_id],
                            )
                        }
                    };

                    block.body.push(instruction);
                }
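                // Atomic operations on storage image texels are delegated to a
                // dedicated helper on the block context.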
                Statement::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun,
                    value,
                } => {
                    self.write_image_atomic(
                        image,
                        coordinate,
                        array_index,
                        fun,
                        value,
                        &mut block,
                    )?;
                }
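                // A workgroup uniform load is a plain load bracketed by workgroup
                // control barriers, so that every invocation in the workgroup
                // observes the same value.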
                Statement::WorkGroupUniformLoad { pointer, result } => {
                    self.writer
                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
                    let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
                    // Emit the access chain and the load. If the access requires a
                    // bounds check, the load is embedded in the in-bounds branch of
                    // the generated conditional.
                    match self.write_access_chain(
                        pointer,
                        &mut block,
                        AccessTypeAdjustment::None,
                    )? {
                        ExpressionPointer::Ready { pointer_id } => {
                            let id = self.gen_id();
                            block.body.push(Instruction::load(
                                result_type_id,
                                id,
                                pointer_id,
                                None,
                            ));
                            self.cached[result] = id;
                        }
                        ExpressionPointer::Conditional { condition, access } => {
                            self.cached[result] = self.write_conditional_indexed_load(
                                result_type_id,
                                condition,
                                &mut block,
                                move |id_gen, block| {
                                    // The in-bounds path. Perform the access and the load.
                                    let pointer_id = access.result_id.unwrap();
                                    let value_id = id_gen.next();
                                    block.body.push(access);
                                    block.body.push(Instruction::load(
                                        result_type_id,
                                        value_id,
                                        pointer_id,
                                        None,
                                    ));
                                    value_id
                                },
                            )
                        }
                    }
                    self.writer
                        .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
                }
                Statement::RayQuery { query, ref fun } => {
                    self.write_ray_query_function(query, fun, &mut block);
                }
                Statement::SubgroupBallot {
                    result,
                    ref predicate,
                } => {
                    self.write_subgroup_ballot(predicate, result, &mut block)?;
                }
                Statement::SubgroupCollectiveOperation {
                    ref op,
                    ref collective_op,
                    argument,
                    result,
                } => {
                    self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
                }
                Statement::SubgroupGather {
                    ref mode,
                    argument,
                    result,
                } => {
                    self.write_subgroup_gather(mode, argument, result, &mut block)?;
                }
            }
        }

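        // Every statement has been emitted; terminate the block as `exit` dictates.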
        let termination = match exit {
            // We're generating code for the top-level Block of the function, so we
            // need to end it with some kind of return instruction.
            BlockExit::Return => match self.ir_function.result {
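                // A function that declares a result type needs `OpReturnValue` here;
                // since any value-producing `Return` statements have already been
                // emitted, a null constant of the result type serves as the value
                // for this implicit return. Entry points (and functions without a
                // result) use plain `OpReturn`.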
                Some(ref result) if self.function.entry_point_context.is_none() => {
                    let type_id = self.get_handle_type_id(result.ty);
                    let null_id = self.writer.get_constant_null(type_id);
                    Instruction::return_value(null_id)
                }
                _ => Instruction::return_void(),
            },
            BlockExit::Branch { target } => Instruction::branch(target),
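            // `BreakIf` ends the `continuing` block of a loop: branch out of the
            // loop when the condition holds, otherwise jump back to the loop
            // preamble.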
            BlockExit::BreakIf {
                condition,
                preamble_id,
            } => {
                let condition_id = self.cached[condition];

                Instruction::branch_conditional(
                    condition_id,
                    loop_context.break_id.unwrap(),
                    preamble_id,
                )
            }
        };

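        // Attach the terminator and hand the finished block to the function.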
        self.function.consume(block, termination);
        Ok(BlockExitDisposition::Used)
    }

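    /// Write the SPIR-V instructions for the body of `self.ir_function`,
    /// starting with the block labeled `entry_id`.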
    pub(super) fn write_function_body(
        &mut self,
        entry_id: Word,
        debug_info: Option<&DebugInfoInner>,
    ) -> Result<(), Error> {
        // We can ignore the `BlockExitDisposition` returned here because
        // `BlockExit::Return` doesn't refer to a block.
        let _ = self.write_block(
            entry_id,
            &self.ir_function.body,
            BlockExit::Return,
            LoopContext::default(),
            debug_info,
        )?;

        Ok(())
    }
}