naga/back/spv/
writer.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use arrayvec::ArrayVec;
4use hashbrown::hash_map::Entry;
5use spirv::Word;
6
7use super::{
8    block::DebugInfoInner,
9    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
10    Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
11    EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
12    LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
13    NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
14    BITS_PER_BYTE,
15};
16use crate::{
17    arena::{Handle, HandleVec, UniqueArena},
18    back::spv::{
19        helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations},
20        BindingInfo, Std140CompatTypeInfo, WrappedFunction,
21    },
22    common::ForDebugWithTypes as _,
23    proc::{Alignment, TypeResolution},
24    valid::{FunctionInfo, ModuleInfo},
25};
26
27pub struct FunctionInterface<'a> {
28    pub varying_ids: &'a mut Vec<Word>,
29    pub stage: crate::ShaderStage,
30    pub task_payload: Option<Handle<crate::GlobalVariable>>,
31    pub mesh_info: Option<crate::MeshStageInfo>,
32    pub workgroup_size: [u32; 3],
33}
34
35impl Function {
36    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
37        self.signature.as_ref().unwrap().to_words(sink);
38        for argument in self.parameters.iter() {
39            argument.instruction.to_words(sink);
40        }
41        for (index, block) in self.blocks.iter().enumerate() {
42            Instruction::label(block.label_id).to_words(sink);
43            if index == 0 {
44                for local_var in self.variables.values() {
45                    local_var.instruction.to_words(sink);
46                }
47                for local_var in self.ray_query_initialization_tracker_variables.values() {
48                    local_var.instruction.to_words(sink);
49                }
50                for local_var in self.ray_query_t_max_tracker_variables.values() {
51                    local_var.instruction.to_words(sink);
52                }
53                for local_var in self.force_loop_bounding_vars.iter() {
54                    local_var.instruction.to_words(sink);
55                }
56                for internal_var in self.spilled_composites.values() {
57                    internal_var.instruction.to_words(sink);
58                }
59            }
60            for instruction in block.body.iter() {
61                instruction.to_words(sink);
62            }
63        }
64        Instruction::function_end().to_words(sink);
65    }
66}
67
68impl Writer {
69    pub fn new(options: &Options) -> Result<Self, Error> {
70        let (major, minor) = options.lang_version;
71        if major != 1 {
72            return Err(Error::UnsupportedVersion(major, minor));
73        }
74
75        let mut capabilities_used = crate::FastIndexSet::default();
76        capabilities_used.insert(spirv::Capability::Shader);
77
78        let mut id_gen = IdGenerator::default();
79        let gl450_ext_inst_id = id_gen.next();
80        let void_type = id_gen.next();
81
82        Ok(Writer {
83            physical_layout: PhysicalLayout::new(major, minor),
84            logical_layout: LogicalLayout::default(),
85            id_gen,
86            capabilities_available: options.capabilities.clone(),
87            capabilities_used,
88            extensions_used: crate::FastIndexSet::default(),
89            debug_strings: vec![],
90            debugs: vec![],
91            annotations: vec![],
92            flags: options.flags,
93            bounds_check_policies: options.bounds_check_policies,
94            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
95            force_loop_bounding: options.force_loop_bounding,
96            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
97            use_storage_input_output_16: options.use_storage_input_output_16,
98            void_type,
99            lookup_type: crate::FastHashMap::default(),
100            lookup_function: crate::FastHashMap::default(),
101            lookup_function_type: crate::FastHashMap::default(),
102            wrapped_functions: crate::FastHashMap::default(),
103            constant_ids: HandleVec::new(),
104            cached_constants: crate::FastHashMap::default(),
105            global_variables: HandleVec::new(),
106            std140_compat_uniform_types: crate::FastHashMap::default(),
107            fake_missing_bindings: options.fake_missing_bindings,
108            binding_map: options.binding_map.clone(),
109            saved_cached: CachedExpressions::default(),
110            gl450_ext_inst_id,
111            temp_list: Vec::new(),
112            ray_query_functions: crate::FastHashMap::default(),
113            io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new(
114                options.use_storage_input_output_16,
115            ),
116            debug_printf: None,
117        })
118    }
119
120    pub fn set_options(&mut self, options: &Options) -> Result<(), Error> {
121        let (major, minor) = options.lang_version;
122        if major != 1 {
123            return Err(Error::UnsupportedVersion(major, minor));
124        }
125        self.physical_layout = PhysicalLayout::new(major, minor);
126        self.capabilities_available = options.capabilities.clone();
127        self.flags = options.flags;
128        self.bounds_check_policies = options.bounds_check_policies;
129        self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory;
130        self.force_loop_bounding = options.force_loop_bounding;
131        self.use_storage_input_output_16 = options.use_storage_input_output_16;
132        self.binding_map = options.binding_map.clone();
133        self.io_f16_polyfills =
134            super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16);
135        Ok(())
136    }
137
138    /// Returns `(major, minor)` of the SPIR-V language version.
139    pub const fn lang_version(&self) -> (u8, u8) {
140        self.physical_layout.lang_version()
141    }
142
143    /// Reset `Writer` to its initial state, retaining any allocations.
144    ///
145    /// Why not just implement `Recyclable` for `Writer`? By design,
146    /// `Recyclable::recycle` requires ownership of the value, not just
147    /// `&mut`; see the trait documentation. But we need to use this method
148    /// from functions like `Writer::write`, which only have `&mut Writer`.
149    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
150    /// or something like a `Default` impl that returns an oddly-initialized
151    /// `Writer`, which is worse.
152    fn reset(&mut self) {
153        use super::recyclable::Recyclable;
154        use core::mem::take;
155
156        let mut id_gen = IdGenerator::default();
157        let gl450_ext_inst_id = id_gen.next();
158        let void_type = id_gen.next();
159
160        // Every field of the old writer that is not determined by the `Options`
161        // passed to `Writer::new` should be reset somehow.
162        let fresh = Writer {
163            // Copied from the old Writer:
164            flags: self.flags,
165            bounds_check_policies: self.bounds_check_policies,
166            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
167            force_loop_bounding: self.force_loop_bounding,
168            ray_query_initialization_tracking: self.ray_query_initialization_tracking,
169            use_storage_input_output_16: self.use_storage_input_output_16,
170            capabilities_available: take(&mut self.capabilities_available),
171            fake_missing_bindings: self.fake_missing_bindings,
172            binding_map: take(&mut self.binding_map),
173
174            // Initialized afresh:
175            id_gen,
176            void_type,
177            gl450_ext_inst_id,
178
179            // Recycled:
180            capabilities_used: take(&mut self.capabilities_used).recycle(),
181            extensions_used: take(&mut self.extensions_used).recycle(),
182            physical_layout: self.physical_layout.clone().recycle(),
183            logical_layout: take(&mut self.logical_layout).recycle(),
184            debug_strings: take(&mut self.debug_strings).recycle(),
185            debugs: take(&mut self.debugs).recycle(),
186            annotations: take(&mut self.annotations).recycle(),
187            lookup_type: take(&mut self.lookup_type).recycle(),
188            lookup_function: take(&mut self.lookup_function).recycle(),
189            lookup_function_type: take(&mut self.lookup_function_type).recycle(),
190            wrapped_functions: take(&mut self.wrapped_functions).recycle(),
191            constant_ids: take(&mut self.constant_ids).recycle(),
192            cached_constants: take(&mut self.cached_constants).recycle(),
193            global_variables: take(&mut self.global_variables).recycle(),
194            std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).recycle(),
195            saved_cached: take(&mut self.saved_cached).recycle(),
196            temp_list: take(&mut self.temp_list).recycle(),
197            ray_query_functions: take(&mut self.ray_query_functions).recycle(),
198            io_f16_polyfills: take(&mut self.io_f16_polyfills).recycle(),
199            debug_printf: None,
200        };
201
202        *self = fresh;
203
204        self.capabilities_used.insert(spirv::Capability::Shader);
205    }
206
207    /// Indicate that the code requires any one of the listed capabilities.
208    ///
209    /// If nothing in `capabilities` appears in the available capabilities
210    /// specified in the [`Options`] from which this `Writer` was created,
211    /// return an error. The `what` string is used in the error message to
212    /// explain what provoked the requirement. (If no available capabilities were
213    /// given, assume everything is available.)
214    ///
215    /// The first acceptable capability will be added to this `Writer`'s
216    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
217    /// result. For this reason, more specific capabilities should be listed
218    /// before more general.
219    ///
220    /// [`capabilities_used`]: Writer::capabilities_used
221    pub(super) fn require_any(
222        &mut self,
223        what: &'static str,
224        capabilities: &[spirv::Capability],
225    ) -> Result<(), Error> {
226        match *capabilities {
227            [] => Ok(()),
228            [first, ..] => {
229                // Find the first acceptable capability, or return an error if
230                // there is none.
231                let selected = match self.capabilities_available {
232                    None => first,
233                    Some(ref available) => {
234                        match capabilities
235                            .iter()
236                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
237                            .find(|cap| available.contains::<spirv::Capability>(cap))
238                        {
239                            Some(&cap) => cap,
240                            None => {
241                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
242                            }
243                        }
244                    }
245                };
246                self.capabilities_used.insert(selected);
247                Ok(())
248            }
249        }
250    }
251
252    /// Indicate that the code requires all of the listed capabilities.
253    ///
254    /// If all entries of `capabilities` appear in the available capabilities
255    /// specified in the [`Options`] from which this `Writer` was created
256    /// (including the case where [`Options::capabilities`] is `None`), add
257    /// them all to this `Writer`'s [`capabilities_used`] table, and return
258    /// `Ok(())`. If at least one of the listed capabilities is not available,
259    /// do not add anything to the `capabilities_used` table, and return the
260    /// first unavailable requested capability, wrapped in `Err()`.
261    ///
262    /// This method is does not return an [`enum@Error`] in case of failure
263    /// because it may be used in cases where the caller can recover (e.g.,
264    /// with a polyfill) if the requested capabilities are not available. In
265    /// this case, it would be unnecessary work to find *all* the unavailable
266    /// requested capabilities, and to allocate a `Vec` for them, just so we
267    /// could return an [`Error::MissingCapabilities`]).
268    ///
269    /// [`capabilities_used`]: Writer::capabilities_used
270    pub(super) fn require_all(
271        &mut self,
272        capabilities: &[spirv::Capability],
273    ) -> Result<(), spirv::Capability> {
274        if let Some(ref available) = self.capabilities_available {
275            for requested in capabilities {
276                if !available.contains(requested) {
277                    return Err(*requested);
278                }
279            }
280        }
281
282        for requested in capabilities {
283            self.capabilities_used.insert(*requested);
284        }
285
286        Ok(())
287    }
288
289    /// Indicate that the code uses the given extension.
290    pub(super) fn use_extension(&mut self, extension: &'static str) {
291        self.extensions_used.insert(extension);
292    }
293
294    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
295        match self.lookup_type.entry(lookup_ty) {
296            Entry::Occupied(e) => *e.get(),
297            Entry::Vacant(e) => {
298                let local = match lookup_ty {
299                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
300                    LookupType::Local(local) => local,
301                };
302
303                let id = self.id_gen.next();
304                e.insert(id);
305                self.write_type_declaration_local(id, local);
306                id
307            }
308        }
309    }
310
311    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
312        self.get_type_id(LookupType::Handle(handle))
313    }
314
315    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
316        match *tr {
317            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
318            TypeResolution::Value(ref inner) => {
319                let inner_local_type = self.localtype_from_inner(inner).unwrap();
320                LookupType::Local(inner_local_type)
321            }
322        }
323    }
324
325    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
326        let lookup_ty = self.get_expression_lookup_type(tr);
327        self.get_type_id(lookup_ty)
328    }
329
330    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
331        self.get_type_id(LookupType::Local(local))
332    }
333
334    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
335        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
336    }
337
338    pub(super) fn get_handle_pointer_type_id(
339        &mut self,
340        base: Handle<crate::Type>,
341        class: spirv::StorageClass,
342    ) -> Word {
343        let base_id = self.get_handle_type_id(base);
344        self.get_pointer_type_id(base_id, class)
345    }
346
347    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
348        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
349        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
350    }
351
352    /// Return a SPIR-V type for a pointer to `resolution`.
353    ///
354    /// The given `resolution` must be one that we can represent
355    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
356    pub(super) fn get_resolution_pointer_id(
357        &mut self,
358        resolution: &TypeResolution,
359        class: spirv::StorageClass,
360    ) -> Word {
361        let resolution_type_id = self.get_expression_type_id(resolution);
362        self.get_pointer_type_id(resolution_type_id, class)
363    }
364
365    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
366        self.get_type_id(LocalType::Numeric(numeric).into())
367    }
368
369    pub(super) fn get_u32_type_id(&mut self) -> Word {
370        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
371    }
372
373    pub(super) fn get_f32_type_id(&mut self) -> Word {
374        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
375    }
376
377    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
378        self.get_numeric_type_id(NumericType::Vector {
379            size: crate::VectorSize::Bi,
380            scalar: crate::Scalar::U32,
381        })
382    }
383
384    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
385        self.get_numeric_type_id(NumericType::Vector {
386            size: crate::VectorSize::Bi,
387            scalar: crate::Scalar::F32,
388        })
389    }
390
391    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
392        self.get_numeric_type_id(NumericType::Vector {
393            size: crate::VectorSize::Tri,
394            scalar: crate::Scalar::U32,
395        })
396    }
397
398    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
399        let f32_id = self.get_f32_type_id();
400        self.get_pointer_type_id(f32_id, class)
401    }
402
403    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
404        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
405            size: crate::VectorSize::Bi,
406            scalar: crate::Scalar::U32,
407        });
408        self.get_pointer_type_id(vec2u_id, class)
409    }
410
411    pub(super) fn get_vec3u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
412        let vec3u_id = self.get_numeric_type_id(NumericType::Vector {
413            size: crate::VectorSize::Tri,
414            scalar: crate::Scalar::U32,
415        });
416        self.get_pointer_type_id(vec3u_id, class)
417    }
418
419    pub(super) fn get_bool_type_id(&mut self) -> Word {
420        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
421    }
422
423    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
424        self.get_numeric_type_id(NumericType::Vector {
425            size: crate::VectorSize::Bi,
426            scalar: crate::Scalar::BOOL,
427        })
428    }
429
430    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
431        self.get_numeric_type_id(NumericType::Vector {
432            size: crate::VectorSize::Tri,
433            scalar: crate::Scalar::BOOL,
434        })
435    }
436
437    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
438        self.annotations
439            .push(Instruction::decorate(id, decoration, operands));
440    }
441
442    /// Return `inner` as a `LocalType`, if that's possible.
443    ///
444    /// If `inner` can be represented as a `LocalType`, return
445    /// `Some(local_type)`.
446    ///
447    /// Otherwise, return `None`. In this case, the type must always be looked
448    /// up using a `LookupType::Handle`.
449    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
450        Some(match *inner {
451            crate::TypeInner::Scalar(_)
452            | crate::TypeInner::Atomic(_)
453            | crate::TypeInner::Vector { .. }
454            | crate::TypeInner::Matrix { .. } => {
455                // We expect `NumericType::from_inner` to handle all
456                // these cases, so unwrap.
457                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
458            }
459            crate::TypeInner::CooperativeMatrix { .. } => {
460                LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
461            }
462            crate::TypeInner::Pointer { base, space } => {
463                let base_type_id = self.get_handle_type_id(base);
464                LocalType::Pointer {
465                    base: base_type_id,
466                    class: map_storage_class(space),
467                }
468            }
469            crate::TypeInner::ValuePointer {
470                size,
471                scalar,
472                space,
473            } => {
474                let base_numeric_type = match size {
475                    Some(size) => NumericType::Vector { size, scalar },
476                    None => NumericType::Scalar(scalar),
477                };
478                LocalType::Pointer {
479                    base: self.get_numeric_type_id(base_numeric_type),
480                    class: map_storage_class(space),
481                }
482            }
483            crate::TypeInner::Image {
484                dim,
485                arrayed,
486                class,
487            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
488            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
489            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
490            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
491            crate::TypeInner::Array { .. }
492            | crate::TypeInner::Struct { .. }
493            | crate::TypeInner::BindingArray { .. } => return None,
494        })
495    }
496
497    /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the
498    /// provided [`Writer::binding_map`].
499    ///
500    /// If the specified resource is not present in the binding map this will
501    /// return an error, unless [`Writer::fake_missing_bindings`] is set.
502    fn resolve_resource_binding(
503        &self,
504        res_binding: &crate::ResourceBinding,
505    ) -> Result<BindingInfo, Error> {
506        match self.binding_map.get(res_binding) {
507            Some(target) => Ok(*target),
508            None if self.fake_missing_bindings => Ok(BindingInfo {
509                descriptor_set: res_binding.group,
510                binding: res_binding.binding,
511                binding_array_size: None,
512            }),
513            None => Err(Error::MissingBinding(*res_binding)),
514        }
515    }
516
517    /// Emits code for any wrapper functions required by the expressions in ir_function.
518    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
519    fn write_wrapped_functions(
520        &mut self,
521        ir_function: &crate::Function,
522        info: &FunctionInfo,
523        ir_module: &crate::Module,
524    ) -> Result<(), Error> {
525        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
526
527        for (expr_handle, expr) in ir_function.expressions.iter() {
528            match *expr {
529                crate::Expression::Binary { op, left, right } => {
530                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
531                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
532                        match (op, expr_ty.scalar().kind) {
533                            // Division and modulo are undefined behaviour when the
534                            // dividend is the minimum representable value and the divisor
535                            // is negative one, or when the divisor is zero. These wrapped
536                            // functions override the divisor to one in these cases,
537                            // matching the WGSL spec.
538                            (
539                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
540                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
541                            ) => {
542                                self.write_wrapped_binary_op(
543                                    op,
544                                    expr_ty,
545                                    &info[left].ty,
546                                    &info[right].ty,
547                                )?;
548                            }
549                            _ => {}
550                        }
551                    }
552                }
553                crate::Expression::Load { pointer } => {
554                    if let crate::TypeInner::Pointer {
555                        base: pointer_type,
556                        space: crate::AddressSpace::Uniform,
557                    } = *info[pointer].ty.inner_with(&ir_module.types)
558                    {
559                        if self.std140_compat_uniform_types.contains_key(&pointer_type) {
560                            // Loading a std140 compat type requires the wrapper function
561                            // to convert to the regular type.
562                            self.write_wrapped_convert_from_std140_compat_type(
563                                ir_module,
564                                pointer_type,
565                            )?;
566                        }
567                    }
568                }
569                crate::Expression::Access { base, .. } => {
570                    if let crate::TypeInner::Pointer {
571                        base: base_type,
572                        space: crate::AddressSpace::Uniform,
573                    } = *info[base].ty.inner_with(&ir_module.types)
574                    {
575                        // Dynamic accesses of a two-row matrix's columns require a
576                        // wrapper function.
577                        if let crate::TypeInner::Matrix {
578                            rows: crate::VectorSize::Bi,
579                            ..
580                        } = ir_module.types[base_type].inner
581                        {
582                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
583                            // If the matrix is *not* directly a member of a struct, then
584                            // we additionally require a wrapper function to convert from
585                            // the std140 compat type to the regular type.
586                            if !is_uniform_matcx2_struct_member_access(
587                                ir_function,
588                                info,
589                                ir_module,
590                                base,
591                            ) {
592                                self.write_wrapped_convert_from_std140_compat_type(
593                                    ir_module, base_type,
594                                )?;
595                            }
596                        }
597                    }
598                }
599                _ => {}
600            }
601        }
602
603        Ok(())
604    }
605
606    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
607    ///
608    /// Define a function that performs an integer division or modulo operation,
609    /// except that using a divisor of zero or causing signed overflow with a
610    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
611    /// undefined behavior.
612    ///
613    /// Store the generated function's id in the [`wrapped_functions`] table.
614    ///
615    /// The operator `op` must be either [`Divide`] or [`Modulo`].
616    ///
617    /// # Panics
618    ///
619    /// The `return_type`, `left_type` or `right_type` arguments must all be
620    /// integer scalars or vectors. If not, this function panics.
621    ///
622    /// [`wrapped_functions`]: Writer::wrapped_functions
623    /// [`Divide`]: crate::BinaryOperator::Divide
624    /// [`Modulo`]: crate::BinaryOperator::Modulo
625    fn write_wrapped_binary_op(
626        &mut self,
627        op: crate::BinaryOperator,
628        return_type: NumericType,
629        left_type: &TypeResolution,
630        right_type: &TypeResolution,
631    ) -> Result<(), Error> {
632        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
633        let left_type_id = self.get_expression_type_id(left_type);
634        let right_type_id = self.get_expression_type_id(right_type);
635
636        // Check if we've already emitted this function.
637        let wrapped = WrappedFunction::BinaryOp {
638            op,
639            left_type_id,
640            right_type_id,
641        };
642        let function_id = match self.wrapped_functions.entry(wrapped) {
643            Entry::Occupied(_) => return Ok(()),
644            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
645        };
646
647        let scalar = return_type.scalar();
648
649        if self.flags.contains(WriterFlags::DEBUG) {
650            let function_name = match op {
651                crate::BinaryOperator::Divide => "naga_div",
652                crate::BinaryOperator::Modulo => "naga_mod",
653                _ => unreachable!(),
654            };
655            self.debugs
656                .push(Instruction::name(function_id, function_name));
657        }
658        let mut function = Function::default();
659
660        let function_type_id = self.get_function_type(LookupFunctionType {
661            parameter_type_ids: vec![left_type_id, right_type_id],
662            return_type_id,
663        });
664        function.signature = Some(Instruction::function(
665            return_type_id,
666            function_id,
667            spirv::FunctionControl::empty(),
668            function_type_id,
669        ));
670
671        let lhs_id = self.id_gen.next();
672        let rhs_id = self.id_gen.next();
673        if self.flags.contains(WriterFlags::DEBUG) {
674            self.debugs.push(Instruction::name(lhs_id, "lhs"));
675            self.debugs.push(Instruction::name(rhs_id, "rhs"));
676        }
677        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
678        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
679        for instruction in [left_par, right_par] {
680            function.parameters.push(FunctionArgument {
681                instruction,
682                handle_id: 0,
683            });
684        }
685
686        let label_id = self.id_gen.next();
687        let mut block = Block::new(label_id);
688
689        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
690        let bool_type_id = self.get_numeric_type_id(bool_type);
691
692        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
693            NumericType::Scalar(_) => const_id,
694            NumericType::Vector { size, .. } => {
695                let constituent_ids = [const_id; crate::VectorSize::MAX];
696                writer.get_constant_composite(
697                    LookupType::Local(LocalType::Numeric(return_type)),
698                    &constituent_ids[..size as usize],
699                )
700            }
701            NumericType::Matrix { .. } => unreachable!(),
702        };
703
704        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
705        let composite_zero_id = maybe_splat_const(self, const_zero_id);
706        let rhs_eq_zero_id = self.id_gen.next();
707        block.body.push(Instruction::binary(
708            spirv::Op::IEqual,
709            bool_type_id,
710            rhs_eq_zero_id,
711            rhs_id,
712            composite_zero_id,
713        ));
714        let divisor_selector_id = match scalar.kind {
715            crate::ScalarKind::Sint => {
716                let (const_min_id, const_neg_one_id) = match scalar.width {
717                    4 => Ok((
718                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
719                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
720                    )),
721                    8 => Ok((
722                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
723                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
724                    )),
725                    _ => Err(Error::Validation("Unexpected scalar width")),
726                }?;
727                let composite_min_id = maybe_splat_const(self, const_min_id);
728                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
729
730                let lhs_eq_int_min_id = self.id_gen.next();
731                block.body.push(Instruction::binary(
732                    spirv::Op::IEqual,
733                    bool_type_id,
734                    lhs_eq_int_min_id,
735                    lhs_id,
736                    composite_min_id,
737                ));
738                let rhs_eq_neg_one_id = self.id_gen.next();
739                block.body.push(Instruction::binary(
740                    spirv::Op::IEqual,
741                    bool_type_id,
742                    rhs_eq_neg_one_id,
743                    rhs_id,
744                    composite_neg_one_id,
745                ));
746                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
747                block.body.push(Instruction::binary(
748                    spirv::Op::LogicalAnd,
749                    bool_type_id,
750                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
751                    lhs_eq_int_min_id,
752                    rhs_eq_neg_one_id,
753                ));
754                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
755                block.body.push(Instruction::binary(
756                    spirv::Op::LogicalOr,
757                    bool_type_id,
758                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
759                    rhs_eq_zero_id,
760                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
761                ));
762                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
763            }
764            crate::ScalarKind::Uint => rhs_eq_zero_id,
765            _ => unreachable!(),
766        };
767
768        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
769        let composite_one_id = maybe_splat_const(self, const_one_id);
770        let divisor_id = self.id_gen.next();
771        block.body.push(Instruction::select(
772            right_type_id,
773            divisor_id,
774            divisor_selector_id,
775            composite_one_id,
776            rhs_id,
777        ));
778        let op = match (op, scalar.kind) {
779            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
780            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
781            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
782            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
783            _ => unreachable!(),
784        };
785        let return_id = self.id_gen.next();
786        block.body.push(Instruction::binary(
787            op,
788            return_type_id,
789            return_id,
790            lhs_id,
791            divisor_id,
792        ));
793
794        function.consume(block, Instruction::return_value(return_id));
795        function.to_words(&mut self.logical_layout.function_definitions);
796        Ok(())
797    }
798
799    /// Writes a wrapper function to convert from a std140 compat type to its
800    /// corresponding regular type.
801    ///
802    /// See [`Self::write_std140_compat_type_declaration`] for more details.
803    fn write_wrapped_convert_from_std140_compat_type(
804        &mut self,
805        ir_module: &crate::Module,
806        r#type: Handle<crate::Type>,
807    ) -> Result<(), Error> {
808        // Check if we've already emitted this function.
809        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
810        let function_id = match self.wrapped_functions.entry(wrapped) {
811            Entry::Occupied(_) => return Ok(()),
812            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
813        };
814        if self.flags.contains(WriterFlags::DEBUG) {
815            self.debugs.push(Instruction::name(
816                function_id,
817                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
818            ));
819        }
820        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
821        let return_type_id = self.get_handle_type_id(r#type);
822
823        let mut function = Function::default();
824        let function_type_id = self.get_function_type(LookupFunctionType {
825            parameter_type_ids: vec![param_type_id],
826            return_type_id,
827        });
828        function.signature = Some(Instruction::function(
829            return_type_id,
830            function_id,
831            spirv::FunctionControl::empty(),
832            function_type_id,
833        ));
834        let param_id = self.id_gen.next();
835        function.parameters.push(FunctionArgument {
836            instruction: Instruction::function_parameter(param_type_id, param_id),
837            handle_id: 0,
838        });
839
840        let label_id = self.id_gen.next();
841        let mut block = Block::new(label_id);
842
843        let result_id = match ir_module.types[r#type].inner {
844            // Param is struct containing a vector member for each of the
845            // matrix's columns. Extract each column from the struct then
846            // composite into a matrix.
847            crate::TypeInner::Matrix {
848                columns,
849                rows: rows @ crate::VectorSize::Bi,
850                scalar,
851            } => {
852                let column_type_id =
853                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
854
855                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
856                for column in 0..columns as u32 {
857                    let column_id = self.id_gen.next();
858                    block.body.push(Instruction::composite_extract(
859                        column_type_id,
860                        column_id,
861                        param_id,
862                        &[column],
863                    ));
864                    column_ids.push(column_id);
865                }
866                let result_id = self.id_gen.next();
867                block.body.push(Instruction::composite_construct(
868                    return_type_id,
869                    result_id,
870                    &column_ids,
871                ));
872                result_id
873            }
874            // Param is an array where the base type is the std140 compatible
875            // type corresponding to `base`. Iterate through each element and
876            // call its conversion function, then composite into a new array.
877            crate::TypeInner::Array { base, size, .. } => {
878                // Ensure the conversion function for the array's base type is
879                // declared.
880                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;
881
882                let element_type_id = self.get_handle_type_id(base);
883                let std140_element_type_id = self.std140_compat_uniform_types[&base].type_id;
884                let element_conversion_function_id = self.wrapped_functions
885                    [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
886                let mut element_ids = Vec::new();
887                let size = match size.resolve(ir_module.to_ctx())? {
888                    crate::proc::IndexableLength::Known(size) => size,
889                    crate::proc::IndexableLength::Dynamic => {
890                        return Err(Error::Validation(
891                            "Uniform buffers cannot contain dynamic arrays",
892                        ))
893                    }
894                };
895                for i in 0..size {
896                    let std140_element_id = self.id_gen.next();
897                    block.body.push(Instruction::composite_extract(
898                        std140_element_type_id,
899                        std140_element_id,
900                        param_id,
901                        &[i],
902                    ));
903                    let element_id = self.id_gen.next();
904                    block.body.push(Instruction::function_call(
905                        element_type_id,
906                        element_id,
907                        element_conversion_function_id,
908                        &[std140_element_id],
909                    ));
910                    element_ids.push(element_id);
911                }
912                let result_id = self.id_gen.next();
913                block.body.push(Instruction::composite_construct(
914                    return_type_id,
915                    result_id,
916                    &element_ids,
917                ));
918                result_id
919            }
920            // Param is a struct where each two-row matrix member has been
921            // decomposed in to separate vector members for each column.
922            // Other members use their std140 compatible type if one exists, or
923            // else their regular type. Iterate through each member, converting
924            // or composing any matrices if required, then finally compose into
925            // the struct.
926            crate::TypeInner::Struct { ref members, .. } => {
927                let mut member_ids = Vec::new();
928                let mut next_index = 0;
929                for member in members {
930                    let member_id = self.id_gen.next();
931                    let member_type_id = self.get_handle_type_id(member.ty);
932                    match ir_module.types[member.ty].inner {
933                        crate::TypeInner::Matrix {
934                            columns,
935                            rows: rows @ crate::VectorSize::Bi,
936                            scalar,
937                        } => {
938                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
939                            let column_type_id = self
940                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
941                            for _ in 0..columns as u32 {
942                                let column_id = self.id_gen.next();
943                                block.body.push(Instruction::composite_extract(
944                                    column_type_id,
945                                    column_id,
946                                    param_id,
947                                    &[next_index],
948                                ));
949                                column_ids.push(column_id);
950                                next_index += 1;
951                            }
952                            block.body.push(Instruction::composite_construct(
953                                member_type_id,
954                                member_id,
955                                &column_ids,
956                            ));
957                        }
958                        _ => {
959                            // Ensure the conversion function for the member's
960                            // type is declared.
961                            self.write_wrapped_convert_from_std140_compat_type(
962                                ir_module, member.ty,
963                            )?;
964                            match self.std140_compat_uniform_types.get(&member.ty) {
965                                Some(std140_type_info) => {
966                                    let std140_member_id = self.id_gen.next();
967                                    block.body.push(Instruction::composite_extract(
968                                        std140_type_info.type_id,
969                                        std140_member_id,
970                                        param_id,
971                                        &[next_index],
972                                    ));
973                                    let function_id = self.wrapped_functions
974                                        [&WrappedFunction::ConvertFromStd140CompatType {
975                                            r#type: member.ty,
976                                        }];
977                                    block.body.push(Instruction::function_call(
978                                        member_type_id,
979                                        member_id,
980                                        function_id,
981                                        &[std140_member_id],
982                                    ));
983                                    next_index += 1;
984                                }
985                                None => {
986                                    let member_id = self.id_gen.next();
987                                    block.body.push(Instruction::composite_extract(
988                                        member_type_id,
989                                        member_id,
990                                        param_id,
991                                        &[next_index],
992                                    ));
993                                    next_index += 1;
994                                }
995                            }
996                        }
997                    }
998                    member_ids.push(member_id);
999                }
1000                let result_id = self.id_gen.next();
1001                block.body.push(Instruction::composite_construct(
1002                    return_type_id,
1003                    result_id,
1004                    &member_ids,
1005                ));
1006                result_id
1007            }
1008            _ => unreachable!(),
1009        };
1010
1011        function.consume(block, Instruction::return_value(result_id));
1012        function.to_words(&mut self.logical_layout.function_definitions);
1013        Ok(())
1014    }
1015
1016    /// Writes a wrapper function to get an `OpTypeVector` column from an
1017    /// `OpTypeMatrix` with a dynamic index.
1018    ///
1019    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
1020    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
1021    /// been declared in SPIR-V using an alternative type where each column is a
1022    /// member of a containing struct. SPIR-V is unable to dynamically access
1023    /// struct members, so instead we load the matrix then call this function to
1024    /// access a column from the loaded value.
1025    ///
1026    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
1027    /// [`Uniform`]: crate::AddressSpace::Uniform
1028    fn write_wrapped_matcx2_get_column(
1029        &mut self,
1030        ir_module: &crate::Module,
1031        r#type: Handle<crate::Type>,
1032    ) -> Result<(), Error> {
1033        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
1034        let function_id = match self.wrapped_functions.entry(wrapped) {
1035            Entry::Occupied(_) => return Ok(()),
1036            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
1037        };
1038        if self.flags.contains(WriterFlags::DEBUG) {
1039            self.debugs.push(Instruction::name(
1040                function_id,
1041                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
1042            ));
1043        }
1044
1045        let crate::TypeInner::Matrix {
1046            columns,
1047            rows: rows @ crate::VectorSize::Bi,
1048            scalar,
1049        } = ir_module.types[r#type].inner
1050        else {
1051            unreachable!();
1052        };
1053
1054        let mut function = Function::default();
1055        let matrix_type_id = self.get_handle_type_id(r#type);
1056        let column_index_type_id = self.get_u32_type_id();
1057        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1058        let matrix_param_id = self.id_gen.next();
1059        let column_index_param_id = self.id_gen.next();
1060        function.parameters.push(FunctionArgument {
1061            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
1062            handle_id: 0,
1063        });
1064        function.parameters.push(FunctionArgument {
1065            instruction: Instruction::function_parameter(
1066                column_index_type_id,
1067                column_index_param_id,
1068            ),
1069            handle_id: 0,
1070        });
1071        let function_type_id = self.get_function_type(LookupFunctionType {
1072            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
1073            return_type_id: column_type_id,
1074        });
1075        function.signature = Some(Instruction::function(
1076            column_type_id,
1077            function_id,
1078            spirv::FunctionControl::empty(),
1079            function_type_id,
1080        ));
1081
1082        let label_id = self.id_gen.next();
1083        let mut block = Block::new(label_id);
1084
1085        // Create a switch case for each column in the matrix, where each case
1086        // extracts its column from the matrix. Finally we use OpPhi to return
1087        // the correct column.
1088        let merge_id = self.id_gen.next();
1089        block.body.push(Instruction::selection_merge(
1090            merge_id,
1091            spirv::SelectionControl::NONE,
1092        ));
1093        let cases = (0..columns as u32)
1094            .map(|i| super::instructions::Case {
1095                value: i,
1096                label_id: self.id_gen.next(),
1097            })
1098            .collect::<ArrayVec<_, 4>>();
1099
1100        // Which label we branch to in the default (column index out-of-bounds)
1101        // case depends on our bounds check policy.
1102        let default_id = match self.bounds_check_policies.index {
1103            // For `Restrict`, treat the same as the final column.
1104            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
1105            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
1106            // will be handled in the `OpPhi` below to produce a zero value.
1107            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
1108            // For `Unchecked` we create a new block containing an
1109            // `OpUnreachable`.
1110            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
1111        };
1112        function.consume(
1113            block,
1114            Instruction::switch(column_index_param_id, default_id, &cases),
1115        );
1116
1117        // Emit a block for each case, and produce a list of variable and parent
1118        // block IDs that will be used in an `OpPhi` below to select the right
1119        // value.
1120        let mut var_parent_pairs = cases
1121            .into_iter()
1122            .map(|case| {
1123                let mut block = Block::new(case.label_id);
1124                let column_id = self.id_gen.next();
1125                block.body.push(Instruction::composite_extract(
1126                    column_type_id,
1127                    column_id,
1128                    matrix_param_id,
1129                    &[case.value],
1130                ));
1131                function.consume(block, Instruction::branch(merge_id));
1132                (column_id, case.label_id)
1133            })
1134            // Need capacity for up to 4 columns plus possibly a default case.
1135            .collect::<ArrayVec<_, 5>>();
1136
1137        // Emit a block or append the variable and parent `OpPhi` pair for the
1138        // column index out-of-bounds case, if required.
1139        match self.bounds_check_policies.index {
1140            // Don't need to do anything for `Restrict` as we have branched from
1141            // the final column case's block.
1142            crate::proc::BoundsCheckPolicy::Restrict => {}
1143            // For `ReadZeroSkipWrite` we have branched directly from the block
1144            // containing the `OpSwitch`. The `OpPhi` should produce a zero
1145            // value.
1146            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1147                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
1148            }
1149            // For `Unchecked` create a new block containing `OpUnreachable`.
1150            // This does not need to be handled by the `OpPhi`.
1151            crate::proc::BoundsCheckPolicy::Unchecked => {
1152                function.consume(
1153                    Block::new(default_id),
1154                    Instruction::new(spirv::Op::Unreachable),
1155                );
1156            }
1157        }
1158
1159        let mut block = Block::new(merge_id);
1160        let result_id = self.id_gen.next();
1161        block.body.push(Instruction::phi(
1162            column_type_id,
1163            result_id,
1164            &var_parent_pairs,
1165        ));
1166
1167        function.consume(block, Instruction::return_value(result_id));
1168        function.to_words(&mut self.logical_layout.function_definitions);
1169        Ok(())
1170    }
1171
1172    fn write_function(
1173        &mut self,
1174        ir_function: &crate::Function,
1175        info: &FunctionInfo,
1176        ir_module: &crate::Module,
1177        mut interface: Option<FunctionInterface>,
1178        debug_info: &Option<DebugInfoInner>,
1179    ) -> Result<Word, Error> {
1180        self.write_wrapped_functions(ir_function, info, ir_module)?;
1181
1182        log::trace!("Generating code for {:?}", ir_function.name);
1183        let mut function = Function::default();
1184
1185        let prelude_id = self.id_gen.next();
1186        let mut prelude = Block::new(prelude_id);
1187        let mut ep_context = EntryPointContext {
1188            argument_ids: Vec::new(),
1189            results: Vec::new(),
1190            task_payload_variable_id: if let Some(ref i) = interface {
1191                i.task_payload.map(|a| self.global_variables[a].var_id)
1192            } else {
1193                None
1194            },
1195            mesh_state: None,
1196        };
1197
1198        let mut local_invocation_id = None;
1199
1200        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
1201
1202        let mut local_invocation_index_id = None;
1203
1204        for argument in ir_function.arguments.iter() {
1205            let class = spirv::StorageClass::Input;
1206            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
1207            let argument_type_id = if handle_ty {
1208                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
1209            } else {
1210                self.get_handle_type_id(argument.ty)
1211            };
1212
1213            if let Some(ref mut iface) = interface {
1214                let id = if let Some(ref binding) = argument.binding {
1215                    let name = argument.name.as_deref();
1216
1217                    let varying_id = self.write_varying(
1218                        ir_module,
1219                        iface.stage,
1220                        class,
1221                        name,
1222                        argument.ty,
1223                        binding,
1224                    )?;
1225                    iface.varying_ids.push(varying_id);
1226                    let id = self.load_io_with_f16_polyfill(
1227                        &mut prelude.body,
1228                        varying_id,
1229                        argument_type_id,
1230                    );
1231
1232                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
1233                        local_invocation_id = Some(id);
1234                    } else if binding
1235                        == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1236                    {
1237                        local_invocation_index_id = Some(id);
1238                    }
1239
1240                    id
1241                } else if let crate::TypeInner::Struct { ref members, .. } =
1242                    ir_module.types[argument.ty].inner
1243                {
1244                    let struct_id = self.id_gen.next();
1245                    let mut constituent_ids = Vec::with_capacity(members.len());
1246                    for member in members {
1247                        let type_id = self.get_handle_type_id(member.ty);
1248                        let name = member.name.as_deref();
1249                        let binding = member.binding.as_ref().unwrap();
1250                        let varying_id = self.write_varying(
1251                            ir_module,
1252                            iface.stage,
1253                            class,
1254                            name,
1255                            member.ty,
1256                            binding,
1257                        )?;
1258                        iface.varying_ids.push(varying_id);
1259                        let id =
1260                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
1261                        constituent_ids.push(id);
1262
1263                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
1264                            local_invocation_id = Some(id);
1265                        } else if binding
1266                            == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1267                        {
1268                            local_invocation_index_id = Some(id);
1269                        }
1270                    }
1271                    prelude.body.push(Instruction::composite_construct(
1272                        argument_type_id,
1273                        struct_id,
1274                        &constituent_ids,
1275                    ));
1276                    struct_id
1277                } else {
1278                    unreachable!("Missing argument binding on an entry point");
1279                };
1280                ep_context.argument_ids.push(id);
1281            } else {
1282                let argument_id = self.id_gen.next();
1283                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
1284                if self.flags.contains(WriterFlags::DEBUG) {
1285                    if let Some(ref name) = argument.name {
1286                        self.debugs.push(Instruction::name(argument_id, name));
1287                    }
1288                }
1289                function.parameters.push(FunctionArgument {
1290                    instruction,
1291                    handle_id: if handle_ty {
1292                        let id = self.id_gen.next();
1293                        prelude.body.push(Instruction::load(
1294                            self.get_handle_type_id(argument.ty),
1295                            id,
1296                            argument_id,
1297                            None,
1298                        ));
1299                        id
1300                    } else {
1301                        0
1302                    },
1303                });
1304                parameter_type_ids.push(argument_type_id);
1305            };
1306        }
1307
1308        let return_type_id = match ir_function.result {
1309            Some(ref result) => {
1310                if let Some(ref mut iface) = interface {
1311                    let mut has_point_size = false;
1312                    let class = spirv::StorageClass::Output;
1313                    if let Some(ref binding) = result.binding {
1314                        has_point_size |=
1315                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1316                        let type_id = self.get_handle_type_id(result.ty);
1317                        let varying_id =
1318                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
1319                                0
1320                            } else {
1321                                let varying_id = self.write_varying(
1322                                    ir_module,
1323                                    iface.stage,
1324                                    class,
1325                                    None,
1326                                    result.ty,
1327                                    binding,
1328                                )?;
1329                                iface.varying_ids.push(varying_id);
1330                                varying_id
1331                            };
1332                        ep_context.results.push(ResultMember {
1333                            id: varying_id,
1334                            type_id,
1335                            built_in: binding.to_built_in(),
1336                        });
1337                    } else if let crate::TypeInner::Struct { ref members, .. } =
1338                        ir_module.types[result.ty].inner
1339                    {
1340                        for member in members {
1341                            let type_id = self.get_handle_type_id(member.ty);
1342                            let name = member.name.as_deref();
1343                            let binding = member.binding.as_ref().unwrap();
1344                            has_point_size |=
1345                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1346                            // This isn't an actual builtin in SPIR-V. It can only appear as the
1347                            // output of a task shader and the output is used when writing the
1348                            // entry point return, in which case the id is ignored anyway.
1349                            let varying_id = if *binding
1350                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
1351                            {
1352                                0
1353                            } else {
1354                                let varying_id = self.write_varying(
1355                                    ir_module,
1356                                    iface.stage,
1357                                    class,
1358                                    name,
1359                                    member.ty,
1360                                    binding,
1361                                )?;
1362                                iface.varying_ids.push(varying_id);
1363                                varying_id
1364                            };
1365                            ep_context.results.push(ResultMember {
1366                                id: varying_id,
1367                                type_id,
1368                                built_in: binding.to_built_in(),
1369                            });
1370                        }
1371                    } else {
1372                        unreachable!("Missing result binding on an entry point");
1373                    }
1374
1375                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
1376                        && iface.stage == crate::ShaderStage::Vertex
1377                        && !has_point_size
1378                    {
1379                        // add point size artificially
1380                        let varying_id = self.id_gen.next();
1381                        let pointer_type_id = self.get_f32_pointer_type_id(class);
1382                        Instruction::variable(pointer_type_id, varying_id, class, None)
1383                            .to_words(&mut self.logical_layout.declarations);
1384                        self.decorate(
1385                            varying_id,
1386                            spirv::Decoration::BuiltIn,
1387                            &[spirv::BuiltIn::PointSize as u32],
1388                        );
1389                        iface.varying_ids.push(varying_id);
1390
1391                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
1392                        prelude
1393                            .body
1394                            .push(Instruction::store(varying_id, default_value_id, None));
1395                    }
1396                    self.void_type
1397                } else {
1398                    self.get_handle_type_id(result.ty)
1399                }
1400            }
1401            None => self.void_type,
1402        };
1403
1404        if let Some(ref mut iface) = interface {
1405            if let Some(task_payload) = iface.task_payload {
1406                iface
1407                    .varying_ids
1408                    .push(self.global_variables[task_payload].var_id);
1409            }
1410            self.write_entry_point_mesh_shader_info(
1411                iface,
1412                local_invocation_index_id,
1413                ir_module,
1414                &mut prelude,
1415                &mut ep_context,
1416            )?;
1417        }
1418
1419        let lookup_function_type = LookupFunctionType {
1420            parameter_type_ids,
1421            return_type_id,
1422        };
1423
1424        let function_id = self.id_gen.next();
1425        if self.flags.contains(WriterFlags::DEBUG) {
1426            if let Some(ref name) = ir_function.name {
1427                self.debugs.push(Instruction::name(function_id, name));
1428            }
1429        }
1430
1431        let function_type = self.get_function_type(lookup_function_type);
1432        function.signature = Some(Instruction::function(
1433            return_type_id,
1434            function_id,
1435            spirv::FunctionControl::empty(),
1436            function_type,
1437        ));
1438
1439        if interface.is_some() {
1440            function.entry_point_context = Some(ep_context);
1441        }
1442
1443        // fill up the `GlobalVariable::access_id`
1444        for gv in self.global_variables.iter_mut() {
1445            gv.reset_for_function();
1446        }
1447        for (handle, var) in ir_module.global_variables.iter() {
1448            if info[handle].is_empty() {
1449                continue;
1450            }
1451
1452            let mut gv = self.global_variables[handle].clone();
1453            if let Some(ref mut iface) = interface {
1454                // Have to include global variables in the interface
1455                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
1456                    iface.varying_ids.push(gv.var_id);
1457                }
1458            }
1459
1460            match ir_module.types[var.ty].inner {
1461                // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
1462                crate::TypeInner::BindingArray { .. } => {
1463                    gv.access_id = gv.var_id;
1464                }
1465                _ => {
1466                    // Handle globals are pre-emitted and should be loaded automatically.
1467                    if var.space == crate::AddressSpace::Handle {
1468                        let var_type_id = self.get_handle_type_id(var.ty);
1469                        let id = self.id_gen.next();
1470                        prelude
1471                            .body
1472                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
1473                        gv.access_id = gv.var_id;
1474                        gv.handle_id = id;
1475                    } else if global_needs_wrapper(ir_module, var) {
1476                        let class = map_storage_class(var.space);
1477                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
1478                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
1479                                self.get_pointer_type_id(std140_type_info.type_id, class)
1480                            }
1481                            _ => self.get_handle_pointer_type_id(var.ty, class),
1482                        };
1483                        let index_id = self.get_index_constant(0);
1484                        let id = self.id_gen.next();
1485                        prelude.body.push(Instruction::access_chain(
1486                            pointer_type_id,
1487                            id,
1488                            gv.var_id,
1489                            &[index_id],
1490                        ));
1491                        gv.access_id = id;
1492                    } else {
1493                        // by default, the variable ID is accessed as is
1494                        gv.access_id = gv.var_id;
1495                    };
1496                }
1497            }
1498
1499            // work around borrow checking in the presence of `self.xxx()` calls
1500            self.global_variables[handle] = gv;
1501        }
1502
1503        // Create a `BlockContext` for generating SPIR-V for the function's
1504        // body.
1505        let mut context = BlockContext {
1506            ir_module,
1507            ir_function,
1508            fun_info: info,
1509            function: &mut function,
1510            // Re-use the cached expression table from prior functions.
1511            cached: core::mem::take(&mut self.saved_cached),
1512
1513            // Steal the Writer's temp list for a bit.
1514            temp_list: core::mem::take(&mut self.temp_list),
1515            force_loop_bounding: self.force_loop_bounding,
1516            writer: self,
1517            expression_constness: super::ExpressionConstnessTracker::from_arena(
1518                &ir_function.expressions,
1519            ),
1520            ray_query_tracker_expr: crate::FastHashMap::default(),
1521        };
1522
1523        // fill up the pre-emitted and const expressions
1524        context.cached.reset(ir_function.expressions.len());
1525        for (handle, expr) in ir_function.expressions.iter() {
1526            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
1527                || context.expression_constness.is_const(handle)
1528            {
1529                context.cache_expression_value(handle, &mut prelude)?;
1530            }
1531        }
1532
1533        for (handle, variable) in ir_function.local_variables.iter() {
1534            let id = context.gen_id();
1535
1536            if context.writer.flags.contains(WriterFlags::DEBUG) {
1537                if let Some(ref name) = variable.name {
1538                    context.writer.debugs.push(Instruction::name(id, name));
1539                }
1540            }
1541
1542            let init_word = variable.init.map(|constant| context.cached[constant]);
1543            let pointer_type_id = context
1544                .writer
1545                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1546            let instruction = Instruction::variable(
1547                pointer_type_id,
1548                id,
1549                spirv::StorageClass::Function,
1550                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1551                    crate::TypeInner::RayQuery { .. } => None,
1552                    _ => {
1553                        let type_id = context.get_handle_type_id(variable.ty);
1554                        Some(context.writer.write_constant_null(type_id))
1555                    }
1556                }),
1557            );
1558
1559            context
1560                .function
1561                .variables
1562                .insert(handle, LocalVariable { id, instruction });
1563
1564            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
1565                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
1566                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
1567                // confusing bugs.
1568                let u32_type_id = context.writer.get_u32_type_id();
1569                let ptr_u32_type_id = context
1570                    .writer
1571                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
1572                let tracker_id = context.gen_id();
1573                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
1574                    crate::back::RayQueryPoint::empty().bits(),
1575                ));
1576                let tracker_instruction = Instruction::variable(
1577                    ptr_u32_type_id,
1578                    tracker_id,
1579                    spirv::StorageClass::Function,
1580                    Some(tracker_init_id),
1581                );
1582
1583                context
1584                    .function
1585                    .ray_query_initialization_tracker_variables
1586                    .insert(
1587                        handle,
1588                        LocalVariable {
1589                            id: tracker_id,
1590                            instruction: tracker_instruction,
1591                        },
1592                    );
1593                let f32_type_id = context.writer.get_f32_type_id();
1594                let ptr_f32_type_id = context
1595                    .writer
1596                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1597                let t_max_tracker_id = context.gen_id();
1598                let t_max_tracker_init_id =
1599                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
1600                let t_max_tracker_instruction = Instruction::variable(
1601                    ptr_f32_type_id,
1602                    t_max_tracker_id,
1603                    spirv::StorageClass::Function,
1604                    Some(t_max_tracker_init_id),
1605                );
1606
1607                context.function.ray_query_t_max_tracker_variables.insert(
1608                    handle,
1609                    LocalVariable {
1610                        id: t_max_tracker_id,
1611                        instruction: t_max_tracker_instruction,
1612                    },
1613                );
1614            }
1615        }
1616
1617        for (handle, expr) in ir_function.expressions.iter() {
1618            match *expr {
1619                crate::Expression::LocalVariable(_) => {
1620                    // Cache the `OpVariable` instruction we generated above as
1621                    // the value of this expression.
1622                    context.cache_expression_value(handle, &mut prelude)?;
1623                }
1624                crate::Expression::Access { base, .. }
1625                | crate::Expression::AccessIndex { base, .. } => {
1626                    // Count references to `base` by `Access` and `AccessIndex`
1627                    // instructions. See `access_uses` for details.
1628                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1629                }
1630                _ => {}
1631            }
1632        }
1633
1634        let next_id = context.gen_id();
1635
1636        context
1637            .function
1638            .consume(prelude, Instruction::branch(next_id));
1639
1640        let workgroup_vars_init_exit_block_id =
1641            match (context.writer.zero_initialize_workgroup_memory, interface) {
1642                (
1643                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1644                    Some(
1645                        ref mut interface @ FunctionInterface {
1646                            stage:
1647                                crate::ShaderStage::Compute
1648                                | crate::ShaderStage::Mesh
1649                                | crate::ShaderStage::Task,
1650                            ..
1651                        },
1652                    ),
1653                ) => context.writer.generate_workgroup_vars_init_block(
1654                    next_id,
1655                    ir_module,
1656                    info,
1657                    local_invocation_id,
1658                    interface,
1659                    context.function,
1660                ),
1661                _ => None,
1662            };
1663
1664        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1665            exit_id
1666        } else {
1667            next_id
1668        };
1669
1670        context.write_function_body(main_id, debug_info.as_ref())?;
1671
1672        // Consume the `BlockContext`, ending its borrows and letting the
1673        // `Writer` steal back its cached expression table and temp_list.
1674        let BlockContext {
1675            cached, temp_list, ..
1676        } = context;
1677        self.saved_cached = cached;
1678        self.temp_list = temp_list;
1679
1680        function.to_words(&mut self.logical_layout.function_definitions);
1681
1682        Ok(function_id)
1683    }
1684
1685    fn write_execution_mode(
1686        &mut self,
1687        function_id: Word,
1688        mode: spirv::ExecutionMode,
1689    ) -> Result<(), Error> {
1690        //self.check(mode.required_capabilities())?;
1691        Instruction::execution_mode(function_id, mode, &[])
1692            .to_words(&mut self.logical_layout.execution_modes);
1693        Ok(())
1694    }
1695
1696    // TODO Move to instructions module
1697    fn write_entry_point(
1698        &mut self,
1699        entry_point: &crate::EntryPoint,
1700        info: &FunctionInfo,
1701        ir_module: &crate::Module,
1702        debug_info: &Option<DebugInfoInner>,
1703    ) -> Result<Instruction, Error> {
1704        let mut interface_ids = Vec::new();
1705        let function_id = self.write_function(
1706            &entry_point.function,
1707            info,
1708            ir_module,
1709            Some(FunctionInterface {
1710                varying_ids: &mut interface_ids,
1711                stage: entry_point.stage,
1712                task_payload: entry_point.task_payload,
1713                mesh_info: entry_point.mesh_info.clone(),
1714                workgroup_size: entry_point.workgroup_size,
1715            }),
1716            debug_info,
1717        )?;
1718
1719        let exec_model = match entry_point.stage {
1720            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1721            crate::ShaderStage::Fragment => {
1722                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1723                match entry_point.early_depth_test {
1724                    Some(crate::EarlyDepthTest::Force) => {
1725                        self.write_execution_mode(
1726                            function_id,
1727                            spirv::ExecutionMode::EarlyFragmentTests,
1728                        )?;
1729                    }
1730                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1731                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1732                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1733                        // This permits early depth tests even if the shader writes to a storage
1734                        // binding
1735                        match conservative {
1736                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1737                                function_id,
1738                                spirv::ExecutionMode::DepthGreater,
1739                            )?,
1740                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1741                                function_id,
1742                                spirv::ExecutionMode::DepthLess,
1743                            )?,
1744                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1745                                function_id,
1746                                spirv::ExecutionMode::DepthUnchanged,
1747                            )?,
1748                        }
1749                    }
1750                    None => {}
1751                }
1752                if let Some(ref result) = entry_point.function.result {
1753                    if contains_builtin(
1754                        result.binding.as_ref(),
1755                        result.ty,
1756                        &ir_module.types,
1757                        crate::BuiltIn::FragDepth,
1758                    ) {
1759                        self.write_execution_mode(
1760                            function_id,
1761                            spirv::ExecutionMode::DepthReplacing,
1762                        )?;
1763                    }
1764                }
1765                spirv::ExecutionModel::Fragment
1766            }
1767            crate::ShaderStage::Compute => {
1768                let execution_mode = spirv::ExecutionMode::LocalSize;
1769                Instruction::execution_mode(
1770                    function_id,
1771                    execution_mode,
1772                    &entry_point.workgroup_size,
1773                )
1774                .to_words(&mut self.logical_layout.execution_modes);
1775                spirv::ExecutionModel::GLCompute
1776            }
1777            crate::ShaderStage::Task => {
1778                let execution_mode = spirv::ExecutionMode::LocalSize;
1779                Instruction::execution_mode(
1780                    function_id,
1781                    execution_mode,
1782                    &entry_point.workgroup_size,
1783                )
1784                .to_words(&mut self.logical_layout.execution_modes);
1785                spirv::ExecutionModel::TaskEXT
1786            }
1787            crate::ShaderStage::Mesh => {
1788                let execution_mode = spirv::ExecutionMode::LocalSize;
1789                Instruction::execution_mode(
1790                    function_id,
1791                    execution_mode,
1792                    &entry_point.workgroup_size,
1793                )
1794                .to_words(&mut self.logical_layout.execution_modes);
1795                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
1796                Instruction::execution_mode(
1797                    function_id,
1798                    match mesh_info.topology {
1799                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
1800                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
1801                        crate::MeshOutputTopology::Triangles => {
1802                            spirv::ExecutionMode::OutputTrianglesEXT
1803                        }
1804                    },
1805                    &[],
1806                )
1807                .to_words(&mut self.logical_layout.execution_modes);
1808                Instruction::execution_mode(
1809                    function_id,
1810                    spirv::ExecutionMode::OutputVertices,
1811                    core::slice::from_ref(&mesh_info.max_vertices),
1812                )
1813                .to_words(&mut self.logical_layout.execution_modes);
1814                Instruction::execution_mode(
1815                    function_id,
1816                    spirv::ExecutionMode::OutputPrimitivesEXT,
1817                    core::slice::from_ref(&mesh_info.max_primitives),
1818                )
1819                .to_words(&mut self.logical_layout.execution_modes);
1820                spirv::ExecutionModel::MeshEXT
1821            }
1822        };
1823        //self.check(exec_model.required_capabilities())?;
1824
1825        Ok(Instruction::entry_point(
1826            exec_model,
1827            function_id,
1828            &entry_point.name,
1829            interface_ids.as_slice(),
1830        ))
1831    }
1832
1833    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1834        use crate::ScalarKind as Sk;
1835
1836        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1837        match scalar.kind {
1838            Sk::Sint | Sk::Uint => {
1839                let signedness = if scalar.kind == Sk::Sint {
1840                    super::instructions::Signedness::Signed
1841                } else {
1842                    super::instructions::Signedness::Unsigned
1843                };
1844                let cap = match bits {
1845                    8 => Some(spirv::Capability::Int8),
1846                    16 => Some(spirv::Capability::Int16),
1847                    64 => Some(spirv::Capability::Int64),
1848                    _ => None,
1849                };
1850                if let Some(cap) = cap {
1851                    self.capabilities_used.insert(cap);
1852                }
1853                Instruction::type_int(id, bits, signedness)
1854            }
1855            Sk::Float => {
1856                if bits == 64 {
1857                    self.capabilities_used.insert(spirv::Capability::Float64);
1858                }
1859                if bits == 16 {
1860                    self.capabilities_used.insert(spirv::Capability::Float16);
1861                    self.capabilities_used
1862                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1863                    self.capabilities_used
1864                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1865                    if self.use_storage_input_output_16 {
1866                        self.capabilities_used
1867                            .insert(spirv::Capability::StorageInputOutput16);
1868                    }
1869                }
1870                Instruction::type_float(id, bits)
1871            }
1872            Sk::Bool => Instruction::type_bool(id),
1873            Sk::AbstractInt | Sk::AbstractFloat => {
1874                unreachable!("abstract types should never reach the backend");
1875            }
1876        }
1877    }
1878
1879    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1880        match *inner {
1881            crate::TypeInner::Image {
1882                dim,
1883                arrayed,
1884                class,
1885            } => {
1886                let sampled = match class {
1887                    crate::ImageClass::Sampled { .. } => true,
1888                    crate::ImageClass::Depth { .. } => true,
1889                    crate::ImageClass::Storage { format, .. } => {
1890                        self.request_image_format_capabilities(format.into())?;
1891                        false
1892                    }
1893                    crate::ImageClass::External => unimplemented!(),
1894                };
1895
1896                match dim {
1897                    crate::ImageDimension::D1 => {
1898                        if sampled {
1899                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1900                        } else {
1901                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1902                        }
1903                    }
1904                    crate::ImageDimension::Cube if arrayed => {
1905                        if sampled {
1906                            self.require_any(
1907                                "sampled cube array images",
1908                                &[spirv::Capability::SampledCubeArray],
1909                            )?;
1910                        } else {
1911                            self.require_any(
1912                                "cube array storage images",
1913                                &[spirv::Capability::ImageCubeArray],
1914                            )?;
1915                        }
1916                    }
1917                    _ => {}
1918                }
1919            }
1920            crate::TypeInner::AccelerationStructure { .. } => {
1921                self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
1922            }
1923            crate::TypeInner::RayQuery { .. } => {
1924                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
1925            }
1926            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
1927                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
1928            }
1929            crate::TypeInner::Atomic(crate::Scalar {
1930                width: 4,
1931                kind: crate::ScalarKind::Float,
1932            }) => {
1933                self.require_any(
1934                    "32 bit floating-point atomics",
1935                    &[spirv::Capability::AtomicFloat32AddEXT],
1936                )?;
1937                self.use_extension("SPV_EXT_shader_atomic_float_add");
1938            }
1939            // 16 bit floating-point support requires Float16 capability
1940            crate::TypeInner::Matrix {
1941                scalar: crate::Scalar::F16,
1942                ..
1943            }
1944            | crate::TypeInner::Vector {
1945                scalar: crate::Scalar::F16,
1946                ..
1947            }
1948            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
1949                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
1950                self.use_extension("SPV_KHR_16bit_storage");
1951            }
1952            // Cooperative types and ops
1953            crate::TypeInner::CooperativeMatrix { .. } => {
1954                self.require_any(
1955                    "cooperative matrix",
1956                    &[spirv::Capability::CooperativeMatrixKHR],
1957                )?;
1958                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
1959                self.use_extension("SPV_KHR_cooperative_matrix");
1960                self.use_extension("SPV_KHR_vulkan_memory_model");
1961            }
1962            _ => {}
1963        }
1964        Ok(())
1965    }
1966
1967    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
1968        let instruction = match numeric {
1969            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
1970            NumericType::Vector { size, scalar } => {
1971                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
1972                Instruction::type_vector(id, scalar_id, size)
1973            }
1974            NumericType::Matrix {
1975                columns,
1976                rows,
1977                scalar,
1978            } => {
1979                let column_id =
1980                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1981                Instruction::type_matrix(id, column_id, columns)
1982            }
1983        };
1984
1985        instruction.to_words(&mut self.logical_layout.declarations);
1986    }
1987
1988    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
1989        let instruction = match coop {
1990            CooperativeType::Matrix {
1991                columns,
1992                rows,
1993                scalar,
1994                role,
1995            } => {
1996                let scalar_id =
1997                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
1998                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
1999                let columns_id = self.get_index_constant(columns as u32);
2000                let rows_id = self.get_index_constant(rows as u32);
2001                let role_id =
2002                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
2003                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
2004            }
2005        };
2006
2007        instruction.to_words(&mut self.logical_layout.declarations);
2008    }
2009
2010    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
2011        let instruction = match local_ty {
2012            LocalType::Numeric(numeric) => {
2013                self.write_numeric_type_declaration_local(id, numeric);
2014                return;
2015            }
2016            LocalType::Cooperative(coop) => {
2017                self.write_cooperative_type_declaration_local(id, coop);
2018                return;
2019            }
2020            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
2021            LocalType::Image(image) => {
2022                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
2023                let type_id = self.get_localtype_id(local_type);
2024                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
2025            }
2026            LocalType::Sampler => Instruction::type_sampler(id),
2027            LocalType::SampledImage { image_type_id } => {
2028                Instruction::type_sampled_image(id, image_type_id)
2029            }
2030            LocalType::BindingArray { base, size } => {
2031                let inner_ty = self.get_handle_type_id(base);
2032                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
2033                Instruction::type_array(id, inner_ty, scalar_id)
2034            }
2035            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
2036            LocalType::RayQuery => Instruction::type_ray_query(id),
2037        };
2038
2039        instruction.to_words(&mut self.logical_layout.declarations);
2040    }
2041
2042    fn write_type_declaration_arena(
2043        &mut self,
2044        module: &crate::Module,
2045        handle: Handle<crate::Type>,
2046    ) -> Result<Word, Error> {
2047        let ty = &module.types[handle];
2048        // If it's a type that needs SPIR-V capabilities, request them now.
2049        // This needs to happen regardless of the LocalType lookup succeeding,
2050        // because some types which map to the same LocalType have different
2051        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
2052        self.request_type_capabilities(&ty.inner)?;
2053        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
2054            // This type can be represented as a `LocalType`, so check if we've
2055            // already written an instruction for it. If not, do so now, with
2056            // `write_type_declaration_local`.
2057            match self.lookup_type.entry(LookupType::Local(local)) {
2058                // We already have an id for this `LocalType`.
2059                Entry::Occupied(e) => *e.get(),
2060
2061                // It's a type we haven't seen before.
2062                Entry::Vacant(e) => {
2063                    let id = self.id_gen.next();
2064                    e.insert(id);
2065
2066                    self.write_type_declaration_local(id, local);
2067
2068                    id
2069                }
2070            }
2071        } else {
2072            use spirv::Decoration;
2073
2074            let id = self.id_gen.next();
2075            let instruction = match ty.inner {
2076                crate::TypeInner::Array { base, size, stride } => {
2077                    self.decorate(id, Decoration::ArrayStride, &[stride]);
2078
2079                    let type_id = self.get_handle_type_id(base);
2080                    match size.resolve(module.to_ctx())? {
2081                        crate::proc::IndexableLength::Known(length) => {
2082                            let length_id = self.get_index_constant(length);
2083                            Instruction::type_array(id, type_id, length_id)
2084                        }
2085                        crate::proc::IndexableLength::Dynamic => {
2086                            Instruction::type_runtime_array(id, type_id)
2087                        }
2088                    }
2089                }
2090                crate::TypeInner::BindingArray { base, size } => {
2091                    let type_id = self.get_handle_type_id(base);
2092                    match size.resolve(module.to_ctx())? {
2093                        crate::proc::IndexableLength::Known(length) => {
2094                            let length_id = self.get_index_constant(length);
2095                            Instruction::type_array(id, type_id, length_id)
2096                        }
2097                        crate::proc::IndexableLength::Dynamic => {
2098                            Instruction::type_runtime_array(id, type_id)
2099                        }
2100                    }
2101                }
2102                crate::TypeInner::Struct {
2103                    ref members,
2104                    span: _,
2105                } => {
2106                    let mut has_runtime_array = false;
2107                    let mut member_ids = Vec::with_capacity(members.len());
2108                    for (index, member) in members.iter().enumerate() {
2109                        let member_ty = &module.types[member.ty];
2110                        match member_ty.inner {
2111                            crate::TypeInner::Array {
2112                                base: _,
2113                                size: crate::ArraySize::Dynamic,
2114                                stride: _,
2115                            } => {
2116                                has_runtime_array = true;
2117                            }
2118                            _ => (),
2119                        }
2120                        self.decorate_struct_member(id, index, member, &module.types)?;
2121                        let member_id = self.get_handle_type_id(member.ty);
2122                        member_ids.push(member_id);
2123                    }
2124                    if has_runtime_array {
2125                        self.decorate(id, Decoration::Block, &[]);
2126                    }
2127                    Instruction::type_struct(id, member_ids.as_slice())
2128                }
2129
2130                // These all have TypeLocal representations, so they should have been
2131                // handled by `write_type_declaration_local` above.
2132                crate::TypeInner::Scalar(_)
2133                | crate::TypeInner::Atomic(_)
2134                | crate::TypeInner::Vector { .. }
2135                | crate::TypeInner::Matrix { .. }
2136                | crate::TypeInner::CooperativeMatrix { .. }
2137                | crate::TypeInner::Pointer { .. }
2138                | crate::TypeInner::ValuePointer { .. }
2139                | crate::TypeInner::Image { .. }
2140                | crate::TypeInner::Sampler { .. }
2141                | crate::TypeInner::AccelerationStructure { .. }
2142                | crate::TypeInner::RayQuery { .. } => unreachable!(),
2143            };
2144
2145            instruction.to_words(&mut self.logical_layout.declarations);
2146            id
2147        };
2148
2149        // Add this handle as a new alias for that type.
2150        self.lookup_type.insert(LookupType::Handle(handle), id);
2151
2152        if self.flags.contains(WriterFlags::DEBUG) {
2153            if let Some(ref name) = ty.name {
2154                self.debugs.push(Instruction::name(id, name));
2155            }
2156        }
2157
2158        Ok(id)
2159    }
2160
2161    /// Writes a std140 layout compatible type declaration for a type. Returns
2162    /// the ID of the declared type, or None if no declaration is required.
2163    ///
2164    /// This should be called for any type for which there exists a
2165    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
2166    /// adheres to std140 layout rules it will return without declaring any
2167    /// types. If the type contains another type which requires a std140
2168    /// compatible type declaration, it will recursively call itself.
2169    ///
2170    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
2171    /// declared type will be an `OpTypeStruct` containing an `OpVector` for
2172    /// each of the matrix's columns.
2173    ///
2174    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
2175    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
2176    /// type is the matrix's corresponding std140 compatible type.
2177    ///
2178    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
2179    /// require a std140 compatible type declaration, this will declare a new
2180    /// struct with the following rules:
2181    /// * Struct or array members will be declared with their std140 compatible
2182    ///   type declaration, if one is required.
2183    /// * Two-row matrix members will have each of their columns hoisted
2184    ///   directly into the struct as 2-component vector members.
2185    /// * All other members will be declared with their normal type.
2186    ///
2187    /// Note that this means the Naga IR index of a struct member may not match
2188    /// the index in the generated SPIR-V. The mapping can be obtained via
2189    /// `Std140TypeInfo::member_indices`.
2190    ///
2191    /// [`GlobalVariable`]: crate::GlobalVariable
2192    /// [`Uniform`]: crate::AddressSpace::Uniform
2193    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
2194    /// [`TypeInner::Array`]: crate::TypeInner::Array
2195    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
2196    fn write_std140_compat_type_declaration(
2197        &mut self,
2198        module: &crate::Module,
2199        handle: Handle<crate::Type>,
2200    ) -> Result<Option<Word>, Error> {
2201        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
2202            return Ok(Some(std140_type_info.type_id));
2203        }
2204
2205        let type_inner = &module.types[handle].inner;
2206        let std140_type_id = match *type_inner {
2207            crate::TypeInner::Matrix {
2208                columns,
2209                rows: rows @ crate::VectorSize::Bi,
2210                scalar,
2211            } => {
2212                let std140_type_id = self.id_gen.next();
2213                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
2214                let column_type_id =
2215                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2216                for column in 0..columns as u32 {
2217                    member_type_ids.push(column_type_id);
2218                    self.annotations.push(Instruction::member_decorate(
2219                        std140_type_id,
2220                        column,
2221                        spirv::Decoration::Offset,
2222                        &[column * rows as u32 * scalar.width as u32],
2223                    ));
2224                    if self.flags.contains(WriterFlags::DEBUG) {
2225                        self.debugs.push(Instruction::member_name(
2226                            std140_type_id,
2227                            column,
2228                            &format!("col{column}"),
2229                        ));
2230                    }
2231                }
2232                Instruction::type_struct(std140_type_id, &member_type_ids)
2233                    .to_words(&mut self.logical_layout.declarations);
2234                self.std140_compat_uniform_types.insert(
2235                    handle,
2236                    Std140CompatTypeInfo {
2237                        type_id: std140_type_id,
2238                        member_indices: Vec::new(),
2239                    },
2240                );
2241                Some(std140_type_id)
2242            }
2243            crate::TypeInner::Array { base, size, stride } => {
2244                match self.write_std140_compat_type_declaration(module, base)? {
2245                    Some(std140_base_type_id) => {
2246                        let std140_type_id = self.id_gen.next();
2247                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
2248                        let instruction = match size.resolve(module.to_ctx())? {
2249                            crate::proc::IndexableLength::Known(length) => {
2250                                let length_id = self.get_index_constant(length);
2251                                Instruction::type_array(
2252                                    std140_type_id,
2253                                    std140_base_type_id,
2254                                    length_id,
2255                                )
2256                            }
2257                            crate::proc::IndexableLength::Dynamic => {
2258                                unreachable!()
2259                            }
2260                        };
2261                        instruction.to_words(&mut self.logical_layout.declarations);
2262                        self.std140_compat_uniform_types.insert(
2263                            handle,
2264                            Std140CompatTypeInfo {
2265                                type_id: std140_type_id,
2266                                member_indices: Vec::new(),
2267                            },
2268                        );
2269                        Some(std140_type_id)
2270                    }
2271                    None => None,
2272                }
2273            }
2274            crate::TypeInner::Struct { ref members, .. } => {
2275                let mut needs_std140_type = false;
2276                for member in members {
2277                    match module.types[member.ty].inner {
2278                        // We don't need to write a std140 type for the matrix itself as
2279                        // it will be decomposed into the parent struct. As a result, the
2280                        // struct does need a std140 type, however.
2281                        crate::TypeInner::Matrix {
2282                            rows: crate::VectorSize::Bi,
2283                            ..
2284                        } => needs_std140_type = true,
2285                        // If an array member needs a std140 type, because it is an array
2286                        // (of an array, etc) of `matCx2`s, then the struct also needs
2287                        // a std140 type which uses the std140 type for this member.
2288                        crate::TypeInner::Array { .. }
2289                            if self
2290                                .write_std140_compat_type_declaration(module, member.ty)?
2291                                .is_some() =>
2292                        {
2293                            needs_std140_type = true;
2294                        }
2295                        _ => {}
2296                    }
2297                }
2298
2299                if needs_std140_type {
2300                    let std140_type_id = self.id_gen.next();
2301                    let mut member_ids = Vec::new();
2302                    let mut member_indices = Vec::new();
2303                    let mut next_index = 0;
2304
2305                    for member in members {
2306                        member_indices.push(next_index);
2307                        match module.types[member.ty].inner {
2308                            crate::TypeInner::Matrix {
2309                                columns,
2310                                rows: rows @ crate::VectorSize::Bi,
2311                                scalar,
2312                            } => {
2313                                let vector_type_id =
2314                                    self.get_numeric_type_id(NumericType::Vector {
2315                                        size: rows,
2316                                        scalar,
2317                                    });
2318                                for column in 0..columns as u32 {
2319                                    self.annotations.push(Instruction::member_decorate(
2320                                        std140_type_id,
2321                                        next_index,
2322                                        spirv::Decoration::Offset,
2323                                        &[member.offset
2324                                            + column * rows as u32 * scalar.width as u32],
2325                                    ));
2326                                    if self.flags.contains(WriterFlags::DEBUG) {
2327                                        if let Some(ref name) = member.name {
2328                                            self.debugs.push(Instruction::member_name(
2329                                                std140_type_id,
2330                                                next_index,
2331                                                &format!("{name}_col{column}"),
2332                                            ));
2333                                        }
2334                                    }
2335                                    member_ids.push(vector_type_id);
2336                                    next_index += 1;
2337                                }
2338                            }
2339                            _ => {
2340                                let member_id =
2341                                    match self.std140_compat_uniform_types.get(&member.ty) {
2342                                        Some(std140_member_type_info) => {
2343                                            self.annotations.push(Instruction::member_decorate(
2344                                                std140_type_id,
2345                                                next_index,
2346                                                spirv::Decoration::Offset,
2347                                                &[member.offset],
2348                                            ));
2349                                            if self.flags.contains(WriterFlags::DEBUG) {
2350                                                if let Some(ref name) = member.name {
2351                                                    self.debugs.push(Instruction::member_name(
2352                                                        std140_type_id,
2353                                                        next_index,
2354                                                        name,
2355                                                    ));
2356                                                }
2357                                            }
2358                                            std140_member_type_info.type_id
2359                                        }
2360                                        None => {
2361                                            self.decorate_struct_member(
2362                                                std140_type_id,
2363                                                next_index as usize,
2364                                                member,
2365                                                &module.types,
2366                                            )?;
2367                                            self.get_handle_type_id(member.ty)
2368                                        }
2369                                    };
2370                                member_ids.push(member_id);
2371                                next_index += 1;
2372                            }
2373                        }
2374                    }
2375
2376                    Instruction::type_struct(std140_type_id, &member_ids)
2377                        .to_words(&mut self.logical_layout.declarations);
2378                    self.std140_compat_uniform_types.insert(
2379                        handle,
2380                        Std140CompatTypeInfo {
2381                            type_id: std140_type_id,
2382                            member_indices,
2383                        },
2384                    );
2385                    Some(std140_type_id)
2386                } else {
2387                    None
2388                }
2389            }
2390            _ => None,
2391        };
2392
2393        if let Some(std140_type_id) = std140_type_id {
2394            if self.flags.contains(WriterFlags::DEBUG) {
2395                let name = format!("std140_{:?}", handle.for_debug(&module.types));
2396                self.debugs.push(Instruction::name(std140_type_id, &name));
2397            }
2398        }
2399        Ok(std140_type_id)
2400    }
2401
2402    fn request_image_format_capabilities(
2403        &mut self,
2404        format: spirv::ImageFormat,
2405    ) -> Result<(), Error> {
2406        use spirv::ImageFormat as If;
2407        match format {
2408            If::Rg32f
2409            | If::Rg16f
2410            | If::R11fG11fB10f
2411            | If::R16f
2412            | If::Rgba16
2413            | If::Rgb10A2
2414            | If::Rg16
2415            | If::Rg8
2416            | If::R16
2417            | If::R8
2418            | If::Rgba16Snorm
2419            | If::Rg16Snorm
2420            | If::Rg8Snorm
2421            | If::R16Snorm
2422            | If::R8Snorm
2423            | If::Rg32i
2424            | If::Rg16i
2425            | If::Rg8i
2426            | If::R16i
2427            | If::R8i
2428            | If::Rgb10a2ui
2429            | If::Rg32ui
2430            | If::Rg16ui
2431            | If::Rg8ui
2432            | If::R16ui
2433            | If::R8ui => self.require_any(
2434                "storage image format",
2435                &[spirv::Capability::StorageImageExtendedFormats],
2436            ),
2437            If::R64ui | If::R64i => {
2438                self.use_extension("SPV_EXT_shader_image_int64");
2439                self.require_any(
2440                    "64-bit integer storage image format",
2441                    &[spirv::Capability::Int64ImageEXT],
2442                )
2443            }
2444            If::Unknown
2445            | If::Rgba32f
2446            | If::Rgba16f
2447            | If::R32f
2448            | If::Rgba8
2449            | If::Rgba8Snorm
2450            | If::Rgba32i
2451            | If::Rgba16i
2452            | If::Rgba8i
2453            | If::R32i
2454            | If::Rgba32ui
2455            | If::Rgba16ui
2456            | If::Rgba8ui
2457            | If::R32ui => Ok(()),
2458        }
2459    }
2460
2461    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
2462        self.get_constant_scalar(crate::Literal::U32(index))
2463    }
2464
2465    pub(super) fn get_constant_scalar_with(
2466        &mut self,
2467        value: u8,
2468        scalar: crate::Scalar,
2469    ) -> Result<Word, Error> {
2470        Ok(
2471            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
2472                Error::Validation("Unexpected kind and/or width for Literal"),
2473            )?),
2474        )
2475    }
2476
2477    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
2478        let scalar = CachedConstant::Literal(value.into());
2479        if let Some(&id) = self.cached_constants.get(&scalar) {
2480            return id;
2481        }
2482        let id = self.id_gen.next();
2483        self.write_constant_scalar(id, &value, None);
2484        self.cached_constants.insert(scalar, id);
2485        id
2486    }
2487
2488    fn write_constant_scalar(
2489        &mut self,
2490        id: Word,
2491        value: &crate::Literal,
2492        debug_name: Option<&String>,
2493    ) {
2494        if self.flags.contains(WriterFlags::DEBUG) {
2495            if let Some(name) = debug_name {
2496                self.debugs.push(Instruction::name(id, name));
2497            }
2498        }
2499        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
2500        let instruction = match *value {
2501            crate::Literal::F64(value) => {
2502                let bits = value.to_bits();
2503                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
2504            }
2505            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
2506            crate::Literal::F16(value) => {
2507                let low = value.to_bits();
2508                Instruction::constant_16bit(type_id, id, low as u32)
2509            }
2510            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
2511            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
2512            crate::Literal::U64(value) => {
2513                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2514            }
2515            crate::Literal::I64(value) => {
2516                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2517            }
2518            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
2519            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
2520            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2521                unreachable!("Abstract types should not appear in IR presented to backends");
2522            }
2523        };
2524
2525        instruction.to_words(&mut self.logical_layout.declarations);
2526    }
2527
2528    pub(super) fn get_constant_composite(
2529        &mut self,
2530        ty: LookupType,
2531        constituent_ids: &[Word],
2532    ) -> Word {
2533        let composite = CachedConstant::Composite {
2534            ty,
2535            constituent_ids: constituent_ids.to_vec(),
2536        };
2537        if let Some(&id) = self.cached_constants.get(&composite) {
2538            return id;
2539        }
2540        let id = self.id_gen.next();
2541        self.write_constant_composite(id, ty, constituent_ids, None);
2542        self.cached_constants.insert(composite, id);
2543        id
2544    }
2545
2546    fn write_constant_composite(
2547        &mut self,
2548        id: Word,
2549        ty: LookupType,
2550        constituent_ids: &[Word],
2551        debug_name: Option<&String>,
2552    ) {
2553        if self.flags.contains(WriterFlags::DEBUG) {
2554            if let Some(name) = debug_name {
2555                self.debugs.push(Instruction::name(id, name));
2556            }
2557        }
2558        let type_id = self.get_type_id(ty);
2559        Instruction::constant_composite(type_id, id, constituent_ids)
2560            .to_words(&mut self.logical_layout.declarations);
2561    }
2562
2563    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
2564        let null = CachedConstant::ZeroValue(type_id);
2565        if let Some(&id) = self.cached_constants.get(&null) {
2566            return id;
2567        }
2568        let id = self.write_constant_null(type_id);
2569        self.cached_constants.insert(null, id);
2570        id
2571    }
2572
2573    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
2574        let null_id = self.id_gen.next();
2575        Instruction::constant_null(type_id, null_id)
2576            .to_words(&mut self.logical_layout.declarations);
2577        null_id
2578    }
2579
2580    fn write_constant_expr(
2581        &mut self,
2582        handle: Handle<crate::Expression>,
2583        ir_module: &crate::Module,
2584        mod_info: &ModuleInfo,
2585    ) -> Result<Word, Error> {
2586        let id = match ir_module.global_expressions[handle] {
2587            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
2588            crate::Expression::Constant(constant) => {
2589                let constant = &ir_module.constants[constant];
2590                self.constant_ids[constant.init]
2591            }
2592            crate::Expression::ZeroValue(ty) => {
2593                let type_id = self.get_handle_type_id(ty);
2594                self.get_constant_null(type_id)
2595            }
2596            crate::Expression::Compose { ty, ref components } => {
2597                let component_ids: Vec<_> = crate::proc::flatten_compose(
2598                    ty,
2599                    components,
2600                    &ir_module.global_expressions,
2601                    &ir_module.types,
2602                )
2603                .map(|component| self.constant_ids[component])
2604                .collect();
2605                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
2606            }
2607            crate::Expression::Splat { size, value } => {
2608                let value_id = self.constant_ids[value];
2609                let component_ids = &[value_id; 4][..size as usize];
2610
2611                let ty = self.get_expression_lookup_type(&mod_info[handle]);
2612
2613                self.get_constant_composite(ty, component_ids)
2614            }
2615            _ => {
2616                return Err(Error::Override);
2617            }
2618        };
2619
2620        self.constant_ids[handle] = id;
2621
2622        Ok(id)
2623    }
2624
2625    pub(super) fn write_control_barrier(
2626        &mut self,
2627        flags: crate::Barrier,
2628        body: &mut Vec<Instruction>,
2629    ) {
2630        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
2631            spirv::Scope::Device
2632        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2633            spirv::Scope::Subgroup
2634        } else {
2635            spirv::Scope::Workgroup
2636        };
2637        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2638        semantics.set(
2639            spirv::MemorySemantics::UNIFORM_MEMORY,
2640            flags.contains(crate::Barrier::STORAGE),
2641        );
2642        semantics.set(
2643            spirv::MemorySemantics::WORKGROUP_MEMORY,
2644            flags.contains(crate::Barrier::WORK_GROUP),
2645        );
2646        semantics.set(
2647            spirv::MemorySemantics::SUBGROUP_MEMORY,
2648            flags.contains(crate::Barrier::SUB_GROUP),
2649        );
2650        semantics.set(
2651            spirv::MemorySemantics::IMAGE_MEMORY,
2652            flags.contains(crate::Barrier::TEXTURE),
2653        );
2654        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
2655            self.get_index_constant(spirv::Scope::Subgroup as u32)
2656        } else {
2657            self.get_index_constant(spirv::Scope::Workgroup as u32)
2658        };
2659        let mem_scope_id = self.get_index_constant(memory_scope as u32);
2660        let semantics_id = self.get_index_constant(semantics.bits());
2661        body.push(Instruction::control_barrier(
2662            exec_scope_id,
2663            mem_scope_id,
2664            semantics_id,
2665        ));
2666    }
2667
2668    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
2669        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2670        semantics.set(
2671            spirv::MemorySemantics::UNIFORM_MEMORY,
2672            flags.contains(crate::Barrier::STORAGE),
2673        );
2674        semantics.set(
2675            spirv::MemorySemantics::WORKGROUP_MEMORY,
2676            flags.contains(crate::Barrier::WORK_GROUP),
2677        );
2678        semantics.set(
2679            spirv::MemorySemantics::SUBGROUP_MEMORY,
2680            flags.contains(crate::Barrier::SUB_GROUP),
2681        );
2682        semantics.set(
2683            spirv::MemorySemantics::IMAGE_MEMORY,
2684            flags.contains(crate::Barrier::TEXTURE),
2685        );
2686        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
2687            self.get_index_constant(spirv::Scope::Device as u32)
2688        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2689            self.get_index_constant(spirv::Scope::Subgroup as u32)
2690        } else {
2691            self.get_index_constant(spirv::Scope::Workgroup as u32)
2692        };
2693        let semantics_id = self.get_index_constant(semantics.bits());
2694        block
2695            .body
2696            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
2697    }
2698
    /// Emits code that zero-initializes the workgroup variables used by an
    /// entry point.
    ///
    /// Every global in the [`WorkGroup`] address space whose usage in `info`
    /// is non-empty gets an `OpStore` of an `OpConstantNull` of its type.
    /// These stores are placed in a selection that only executes when the
    /// local invocation id equals `(0, 0, 0)`. The merge block then issues a
    /// workgroup control barrier (presumably so that every invocation sees
    /// the zeroed values before continuing).
    ///
    /// * `entry_id` - label id of the first block emitted here.
    /// * `local_invocation_id` - if `Some`, an id used directly as the
    ///   `vec3<u32>` local invocation id value. If `None`, a new `Input`
    ///   variable decorated `BuiltIn LocalInvocationId` is declared, pushed
    ///   onto `interface.varying_ids`, and loaded in the first block.
    /// * `function` - receives the emitted blocks via `Function::consume`.
    ///
    /// Returns `None`, emitting nothing, when no workgroup variables are used.
    /// Otherwise returns the label id the caller must use for the block that
    /// follows the barrier.
    ///
    /// [`WorkGroup`]: crate::AddressSpace::WorkGroup
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_id: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // One null-constant store per workgroup global that this function uses.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        // Nothing to initialize: emit no blocks at all.
        if body.is_empty() {
            return None;
        }

        let uint3_type_id = self.get_vec3u_type_id();

        // Block `entry_id`: decides whether this invocation performs the stores.
        let mut pre_if_block = Block::new(entry_id);

        let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
            local_invocation_id
        } else {
            // The caller has no local invocation id: declare the builtin input
            // ourselves, register it in the interface, and load it.
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let pointer_type_id = self.get_vec3u_pointer_type_id(class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationId as u32],
            );

            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(uint3_type_id, id, varying_id, None));

            id
        };

        // Condition: all(local_invocation_id == vec3(0u)).
        let zero_id = self.get_constant_null(uint3_type_id);
        let bool3_type_id = self.get_vec3_bool_type_id();

        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool3_type_id,
            eq_id,
            local_invocation_id,
            zero_id,
        ));

        let condition_id = self.id_gen.next();
        let bool_type_id = self.get_bool_type_id();
        pre_if_block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            condition_id,
            eq_id,
        ));

        // Structured selection: `accept_id` does the stores, `merge_id` rejoins.
        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(condition_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        // Merge block: every invocation reaches the workgroup barrier here.
        let mut post_if_block = Block::new(merge_id);

        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);

        // The caller continues emitting the entry point body at `next_id`.
        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
2805
2806    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
2807    ///
2808    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
2809    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
2810    /// interface is represented by global variables in the `Input` and `Output`
2811    /// storage classes, with decorations indicating which builtin or location
2812    /// each variable corresponds to.
2813    ///
2814    /// This function emits a single global `OpVariable` for a single value from
2815    /// the interface, and adds appropriate decorations to indicate which
2816    /// builtin or location it represents, how it should be interpolated, and so
2817    /// on. The `class` argument gives the variable's SPIR-V storage class,
2818    /// which should be either [`Input`] or [`Output`].
2819    ///
2820    /// [`Binding`]: crate::Binding
2821    /// [`Function`]: crate::Function
2822    /// [`EntryPoint`]: crate::EntryPoint
2823    /// [`Input`]: spirv::StorageClass::Input
2824    /// [`Output`]: spirv::StorageClass::Output
2825    fn write_varying(
2826        &mut self,
2827        ir_module: &crate::Module,
2828        stage: crate::ShaderStage,
2829        class: spirv::StorageClass,
2830        debug_name: Option<&str>,
2831        ty: Handle<crate::Type>,
2832        binding: &crate::Binding,
2833    ) -> Result<Word, Error> {
2834        let id = self.id_gen.next();
2835        let ty_inner = &ir_module.types[ty].inner;
2836        let needs_polyfill = self.needs_f16_polyfill(ty_inner);
2837
2838        let pointer_type_id = if needs_polyfill {
2839            let f32_value_local =
2840                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
2841                    .expect("needs_polyfill returned true but create_polyfill_type returned None");
2842
2843            let f32_type_id = self.get_localtype_id(f32_value_local);
2844            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
2845            self.io_f16_polyfills.register_io_var(id, f32_type_id);
2846
2847            ptr_id
2848        } else {
2849            self.get_handle_pointer_type_id(ty, class)
2850        };
2851
2852        Instruction::variable(pointer_type_id, id, class, None)
2853            .to_words(&mut self.logical_layout.declarations);
2854
2855        if self
2856            .flags
2857            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
2858        {
2859            if let Some(name) = debug_name {
2860                self.debugs.push(Instruction::name(id, name));
2861            }
2862        }
2863
2864        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
2865        self.write_binding(id, binding);
2866
2867        Ok(id)
2868    }
2869
2870    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
2871        match binding {
2872            BindingDecorations::None => (),
2873            BindingDecorations::BuiltIn(bi, others) => {
2874                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
2875                for other in others {
2876                    self.decorate(id, other, &[]);
2877                }
2878            }
2879            BindingDecorations::Location {
2880                location,
2881                others,
2882                blend_src,
2883            } => {
2884                self.decorate(id, spirv::Decoration::Location, &[location]);
2885                for other in others {
2886                    self.decorate(id, other, &[]);
2887                }
2888                if let Some(blend_src) = blend_src {
2889                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
2890                }
2891            }
2892        }
2893    }
2894
2895    pub fn write_binding_struct_member(
2896        &mut self,
2897        struct_id: Word,
2898        member_idx: Word,
2899        binding_info: BindingDecorations,
2900    ) {
2901        match binding_info {
2902            BindingDecorations::None => (),
2903            BindingDecorations::BuiltIn(bi, others) => {
2904                self.annotations.push(Instruction::member_decorate(
2905                    struct_id,
2906                    member_idx,
2907                    spirv::Decoration::BuiltIn,
2908                    &[bi as Word],
2909                ));
2910                for other in others {
2911                    self.annotations.push(Instruction::member_decorate(
2912                        struct_id,
2913                        member_idx,
2914                        other,
2915                        &[],
2916                    ));
2917                }
2918            }
2919            BindingDecorations::Location {
2920                location,
2921                others,
2922                blend_src,
2923            } => {
2924                self.annotations.push(Instruction::member_decorate(
2925                    struct_id,
2926                    member_idx,
2927                    spirv::Decoration::Location,
2928                    &[location],
2929                ));
2930                for other in others {
2931                    self.annotations.push(Instruction::member_decorate(
2932                        struct_id,
2933                        member_idx,
2934                        other,
2935                        &[],
2936                    ));
2937                }
2938                if let Some(blend_src) = blend_src {
2939                    self.annotations.push(Instruction::member_decorate(
2940                        struct_id,
2941                        member_idx,
2942                        spirv::Decoration::Index,
2943                        &[blend_src],
2944                    ));
2945                }
2946            }
2947        }
2948    }
2949
2950    pub fn map_binding(
2951        &mut self,
2952        ir_module: &crate::Module,
2953        stage: crate::ShaderStage,
2954        class: spirv::StorageClass,
2955        ty: Handle<crate::Type>,
2956        binding: &crate::Binding,
2957    ) -> Result<BindingDecorations, Error> {
2958        use spirv::BuiltIn;
2959        use spirv::Decoration;
2960        match *binding {
2961            crate::Binding::Location {
2962                location,
2963                interpolation,
2964                sampling,
2965                blend_src,
2966                per_primitive,
2967            } => {
2968                let mut others = ArrayVec::new();
2969
2970                let no_decorations =
2971                    // VUID-StandaloneSpirv-Flat-06202
2972                    // > The Flat, NoPerspective, Sample, and Centroid decorations
2973                    // > must not be used on variables with the Input storage class in a vertex shader
2974                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
2975                    // VUID-StandaloneSpirv-Flat-06201
2976                    // > The Flat, NoPerspective, Sample, and Centroid decorations
2977                    // > must not be used on variables with the Output storage class in a fragment shader
2978                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
2979
2980                if !no_decorations {
2981                    match interpolation {
2982                        // Perspective-correct interpolation is the default in SPIR-V.
2983                        None | Some(crate::Interpolation::Perspective) => (),
2984                        Some(crate::Interpolation::Flat) => {
2985                            others.push(Decoration::Flat);
2986                        }
2987                        Some(crate::Interpolation::Linear) => {
2988                            others.push(Decoration::NoPerspective);
2989                        }
2990                        Some(crate::Interpolation::PerVertex) => {
2991                            others.push(Decoration::PerVertexKHR);
2992                            self.require_any(
2993                                "`per_vertex` interpolation",
2994                                &[spirv::Capability::FragmentBarycentricKHR],
2995                            )?;
2996                            self.use_extension("SPV_KHR_fragment_shader_barycentric");
2997                        }
2998                    }
2999                    match sampling {
3000                        // Center sampling is the default in SPIR-V.
3001                        None
3002                        | Some(
3003                            crate::Sampling::Center
3004                            | crate::Sampling::First
3005                            | crate::Sampling::Either,
3006                        ) => (),
3007                        Some(crate::Sampling::Centroid) => {
3008                            others.push(Decoration::Centroid);
3009                        }
3010                        Some(crate::Sampling::Sample) => {
3011                            self.require_any(
3012                                "per-sample interpolation",
3013                                &[spirv::Capability::SampleRateShading],
3014                            )?;
3015                            others.push(Decoration::Sample);
3016                        }
3017                    }
3018                }
3019                if per_primitive && stage == crate::ShaderStage::Fragment {
3020                    others.push(Decoration::PerPrimitiveEXT);
3021                    self.require_mesh_shaders()?;
3022                }
3023                Ok(BindingDecorations::Location {
3024                    location,
3025                    others,
3026                    blend_src,
3027                })
3028            }
3029            crate::Binding::BuiltIn(built_in) => {
3030                use crate::BuiltIn as Bi;
3031                let mut others = ArrayVec::new();
3032
3033                if matches!(
3034                    built_in,
3035                    Bi::CullPrimitive | Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices
3036                ) {
3037                    self.require_mesh_shaders()?;
3038                }
3039
3040                let built_in = match built_in {
3041                    Bi::Position { invariant } => {
3042                        if invariant {
3043                            others.push(Decoration::Invariant);
3044                        }
3045
3046                        if class == spirv::StorageClass::Output {
3047                            BuiltIn::Position
3048                        } else {
3049                            BuiltIn::FragCoord
3050                        }
3051                    }
3052                    Bi::ViewIndex => {
3053                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
3054                        BuiltIn::ViewIndex
3055                    }
3056                    // vertex
3057                    Bi::BaseInstance => BuiltIn::BaseInstance,
3058                    Bi::BaseVertex => BuiltIn::BaseVertex,
3059                    Bi::ClipDistance => {
3060                        self.require_any(
3061                            "`clip_distance` built-in",
3062                            &[spirv::Capability::ClipDistance],
3063                        )?;
3064                        BuiltIn::ClipDistance
3065                    }
3066                    Bi::CullDistance => {
3067                        self.require_any(
3068                            "`cull_distance` built-in",
3069                            &[spirv::Capability::CullDistance],
3070                        )?;
3071                        BuiltIn::CullDistance
3072                    }
3073                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
3074                    Bi::PointSize => BuiltIn::PointSize,
3075                    Bi::VertexIndex => BuiltIn::VertexIndex,
3076                    Bi::DrawID => BuiltIn::DrawIndex,
3077                    // fragment
3078                    Bi::FragDepth => BuiltIn::FragDepth,
3079                    Bi::PointCoord => BuiltIn::PointCoord,
3080                    Bi::FrontFacing => BuiltIn::FrontFacing,
3081                    Bi::PrimitiveIndex => {
3082                        self.require_any(
3083                            "`primitive_index` built-in",
3084                            &[spirv::Capability::Geometry],
3085                        )?;
3086                        if stage == crate::ShaderStage::Mesh {
3087                            others.push(Decoration::PerPrimitiveEXT);
3088                        }
3089                        BuiltIn::PrimitiveId
3090                    }
3091                    Bi::Barycentric { perspective } => {
3092                        self.require_any(
3093                            "`barycentric` built-in",
3094                            &[spirv::Capability::FragmentBarycentricKHR],
3095                        )?;
3096                        self.use_extension("SPV_KHR_fragment_shader_barycentric");
3097                        if perspective {
3098                            BuiltIn::BaryCoordKHR
3099                        } else {
3100                            BuiltIn::BaryCoordNoPerspKHR
3101                        }
3102                    }
3103                    Bi::SampleIndex => {
3104                        self.require_any(
3105                            "`sample_index` built-in",
3106                            &[spirv::Capability::SampleRateShading],
3107                        )?;
3108
3109                        BuiltIn::SampleId
3110                    }
3111                    Bi::SampleMask => BuiltIn::SampleMask,
3112                    // compute
3113                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
3114                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
3115                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
3116                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
3117                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
3118                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
3119                    // Subgroup
3120                    Bi::NumSubgroups => {
3121                        self.require_any(
3122                            "`num_subgroups` built-in",
3123                            &[spirv::Capability::GroupNonUniform],
3124                        )?;
3125                        BuiltIn::NumSubgroups
3126                    }
3127                    Bi::SubgroupId => {
3128                        self.require_any(
3129                            "`subgroup_id` built-in",
3130                            &[spirv::Capability::GroupNonUniform],
3131                        )?;
3132                        BuiltIn::SubgroupId
3133                    }
3134                    Bi::SubgroupSize => {
3135                        self.require_any(
3136                            "`subgroup_size` built-in",
3137                            &[
3138                                spirv::Capability::GroupNonUniform,
3139                                spirv::Capability::SubgroupBallotKHR,
3140                            ],
3141                        )?;
3142                        BuiltIn::SubgroupSize
3143                    }
3144                    Bi::SubgroupInvocationId => {
3145                        self.require_any(
3146                            "`subgroup_invocation_id` built-in",
3147                            &[
3148                                spirv::Capability::GroupNonUniform,
3149                                spirv::Capability::SubgroupBallotKHR,
3150                            ],
3151                        )?;
3152                        BuiltIn::SubgroupLocalInvocationId
3153                    }
3154                    Bi::CullPrimitive => {
3155                        self.require_mesh_shaders()?;
3156                        others.push(Decoration::PerPrimitiveEXT);
3157                        BuiltIn::CullPrimitiveEXT
3158                    }
3159                    Bi::PointIndex => {
3160                        self.require_mesh_shaders()?;
3161                        BuiltIn::PrimitivePointIndicesEXT
3162                    }
3163                    Bi::LineIndices => {
3164                        self.require_mesh_shaders()?;
3165                        BuiltIn::PrimitiveLineIndicesEXT
3166                    }
3167                    Bi::TriangleIndices => {
3168                        self.require_mesh_shaders()?;
3169                        BuiltIn::PrimitiveTriangleIndicesEXT
3170                    }
3171                    // No decoration, this EmitMeshTasksEXT is called at function return
3172                    Bi::MeshTaskSize => return Ok(BindingDecorations::None),
3173                    // These aren't normal builtins and don't occur in function output
3174                    Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => {
3175                        unreachable!()
3176                    }
3177                };
3178
3179                use crate::ScalarKind as Sk;
3180
3181                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
3182                //
3183                // > Any variable with integer or double-precision floating-
3184                // > point type and with Input storage class in a fragment
3185                // > shader, must be decorated Flat
3186                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
3187                    let is_flat = match ir_module.types[ty].inner {
3188                        crate::TypeInner::Scalar(scalar)
3189                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
3190                            Sk::Uint | Sk::Sint | Sk::Bool => true,
3191                            Sk::Float => false,
3192                            Sk::AbstractInt | Sk::AbstractFloat => {
3193                                return Err(Error::Validation(
3194                                    "Abstract types should not appear in IR presented to backends",
3195                                ))
3196                            }
3197                        },
3198                        _ => false,
3199                    };
3200
3201                    if is_flat {
3202                        others.push(Decoration::Flat);
3203                    }
3204                }
3205                Ok(BindingDecorations::BuiltIn(built_in, others))
3206            }
3207        }
3208    }
3209
3210    /// Load an IO variable, converting from `f32` to `f16` if polyfill is active.
3211    /// Returns the id of the loaded value matching `target_type_id`.
3212    pub(super) fn load_io_with_f16_polyfill(
3213        &mut self,
3214        body: &mut Vec<Instruction>,
3215        varying_id: Word,
3216        target_type_id: Word,
3217    ) -> Word {
3218        let tmp = self.id_gen.next();
3219        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3220            body.push(Instruction::load(f32_ty, tmp, varying_id, None));
3221            let converted = self.id_gen.next();
3222            super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion(
3223                tmp,
3224                target_type_id,
3225                converted,
3226                body,
3227            );
3228            converted
3229        } else {
3230            body.push(Instruction::load(target_type_id, tmp, varying_id, None));
3231            tmp
3232        }
3233    }
3234
3235    /// Store an IO variable, converting from `f16` to `f32` if polyfill is active.
3236    pub(super) fn store_io_with_f16_polyfill(
3237        &mut self,
3238        body: &mut Vec<Instruction>,
3239        varying_id: Word,
3240        value_id: Word,
3241    ) {
3242        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3243            let converted = self.id_gen.next();
3244            super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion(
3245                value_id, f32_ty, converted, body,
3246            );
3247            body.push(Instruction::store(varying_id, converted, None));
3248        } else {
3249            body.push(Instruction::store(varying_id, value_id, None));
3250        }
3251    }
3252
    /// Declare the module-scope `OpVariable` for `global_variable` and return its id.
    ///
    /// Besides the variable itself, this emits:
    ///
    /// - an `OpName` when `WriterFlags::DEBUG` is set and the global is named;
    /// - `NonReadable` / `NonWritable` decorations for storage buffers and
    ///   storage images whose access lacks `LOAD` / `STORE`;
    /// - `DescriptorSet` / `Binding` decorations for resource bindings, with any
    ///   binding-array size override from the resolved binding applied to the
    ///   variable's type;
    /// - a `Block`-decorated wrapper struct, plus a pointer type to it, when
    ///   `global_needs_wrapper` says so (for `Uniform` globals, the
    ///   std140-compatible type is used when one was registered);
    /// - a `Block` decoration on the element type of `Storage` binding arrays,
    ///   unless that element is a struct whose last member is a runtime-sized
    ///   array.
    ///
    /// `Private` globals always get an initializer: the explicit one, or a null
    /// constant. `WorkGroup` globals get a null initializer only in
    /// `ZeroInitializeWorkgroupMemoryMode::Native`. All other globals keep
    /// whatever explicit initializer they have.
    ///
    /// # Errors
    ///
    /// Propagates errors from resolving the resource binding and from
    /// decorating the wrapper struct's member.
    fn write_global_variable(
        &mut self,
        ir_module: &crate::Module,
        global_variable: &crate::GlobalVariable,
    ) -> Result<Word, Error> {
        use spirv::Decoration;

        let id = self.id_gen.next();
        let class = map_storage_class(global_variable.space);

        //self.check(class.required_capabilities())?;

        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = global_variable.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        // Storage buffers carry access flags on their address space; storage
        // images carry them on their image class. Missing access is expressed
        // through SPIR-V's negative `NonReadable` / `NonWritable` decorations.
        let storage_access = match global_variable.space {
            crate::AddressSpace::Storage { access } => Some(access),
            _ => match ir_module.types[global_variable.ty].inner {
                crate::TypeInner::Image {
                    class: crate::ImageClass::Storage { access, .. },
                    ..
                } => Some(access),
                _ => None,
            },
        };
        if let Some(storage_access) = storage_access {
            if !storage_access.contains(crate::StorageAccess::LOAD) {
                self.decorate(id, Decoration::NonReadable, &[]);
            }
            if !storage_access.contains(crate::StorageAccess::STORE) {
                self.decorate(id, Decoration::NonWritable, &[]);
            }
        }

        // Note: we should be able to substitute `binding_array<Foo, 0>`,
        // but there is still code that tries to register the pre-substituted type,
        // and it is failing on 0.
        let mut substitute_inner_type_lookup = None;
        if let Some(ref res_binding) = global_variable.binding {
            let bind_target = self.resolve_resource_binding(res_binding)?;
            self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]);
            self.decorate(id, Decoration::Binding, &[bind_target.binding]);

            // A binding-array size override replaces the variable's type with a
            // pointer to a binding array of the overridden size.
            if let Some(remapped_binding_array_size) = bind_target.binding_array_size {
                if let crate::TypeInner::BindingArray { base, .. } =
                    ir_module.types[global_variable.ty].inner
                {
                    let binding_array_type_id =
                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
                            base,
                            size: remapped_binding_array_size,
                        }));
                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
                        base: binding_array_type_id,
                        class,
                    }));
                }
            }
        };

        let init_word = global_variable
            .init
            .map(|constant| self.constant_ids[constant]);
        // NOTE: when a substitute lookup was chosen above, this is the id of a
        // *pointer* type, not of the value type.
        let inner_type_id = self.get_type_id(
            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
        );

        // generate the wrapping structure if needed
        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
            let wrapper_type_id = self.id_gen.next();

            self.decorate(wrapper_type_id, Decoration::Block, &[]);

            match self.std140_compat_uniform_types.get(&global_variable.ty) {
                // Uniform globals with a registered std140-compatible type wrap
                // that type instead of the original one.
                Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => {
                    self.annotations.push(Instruction::member_decorate(
                        wrapper_type_id,
                        0,
                        Decoration::Offset,
                        &[0],
                    ));
                    Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id])
                        .to_words(&mut self.logical_layout.declarations);
                }
                _ => {
                    let member = crate::StructMember {
                        name: None,
                        ty: global_variable.ty,
                        binding: None,
                        offset: 0,
                    };
                    self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;

                    Instruction::type_struct(wrapper_type_id, &[inner_type_id])
                        .to_words(&mut self.logical_layout.declarations);
                }
            }

            let pointer_type_id = self.id_gen.next();
            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
                .to_words(&mut self.logical_layout.declarations);

            pointer_type_id
        } else {
            // This is a global variable in the Storage address space. The only
            // way it could have `global_needs_wrapper() == false` is if it has
            // a runtime-sized or binding array.
            // Runtime-sized arrays were decorated when iterating through struct content.
            // Now binding arrays require Block decorating.
            if let crate::AddressSpace::Storage { .. } = global_variable.space {
                match ir_module.types[global_variable.ty].inner {
                    crate::TypeInner::BindingArray { base, .. } => {
                        let ty = &ir_module.types[base];
                        let mut should_decorate = true;
                        // Check if the type has a runtime array.
                        // A normal runtime array gets validated out,
                        // so only structs can be with runtime arrays
                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
                            // only the last member in a struct can be dynamically sized
                            if let Some(last_member) = members.last() {
                                if let &crate::TypeInner::Array {
                                    size: crate::ArraySize::Dynamic,
                                    ..
                                } = &ir_module.types[last_member.ty].inner
                                {
                                    should_decorate = false;
                                }
                            }
                        }
                        if should_decorate {
                            let decorated_id = self.get_handle_type_id(base);
                            self.decorate(decorated_id, Decoration::Block, &[]);
                        }
                    }
                    _ => (),
                };
            }
            // The substituted lookup is already a pointer type (see above).
            if substitute_inner_type_lookup.is_some() {
                inner_type_id
            } else {
                self.get_handle_pointer_type_id(global_variable.ty, class)
            }
        };

        // `Private` globals are always initialized. `WorkGroup` globals are
        // null-initialized here only when the native zero-initialization mode
        // was requested.
        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
            (crate::AddressSpace::Private, _)
            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
            }
            _ => init_word,
        };

        Instruction::variable(pointer_type_id, id, class, init_word)
            .to_words(&mut self.logical_layout.declarations);
        Ok(id)
    }
3412
3413    /// Write the necessary decorations for a struct member.
3414    ///
3415    /// Emit decorations for the `index`'th member of the struct type
3416    /// designated by `struct_id`, described by `member`.
3417    fn decorate_struct_member(
3418        &mut self,
3419        struct_id: Word,
3420        index: usize,
3421        member: &crate::StructMember,
3422        arena: &UniqueArena<crate::Type>,
3423    ) -> Result<(), Error> {
3424        use spirv::Decoration;
3425
3426        self.annotations.push(Instruction::member_decorate(
3427            struct_id,
3428            index as u32,
3429            Decoration::Offset,
3430            &[member.offset],
3431        ));
3432
3433        if self.flags.contains(WriterFlags::DEBUG) {
3434            if let Some(ref name) = member.name {
3435                self.debugs
3436                    .push(Instruction::member_name(struct_id, index as u32, name));
3437            }
3438        }
3439
3440        // Matrices and (potentially nested) arrays of matrices both require decorations,
3441        // so "see through" any arrays to determine if they're needed.
3442        let mut member_array_subty_inner = &arena[member.ty].inner;
3443        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
3444            member_array_subty_inner = &arena[base].inner;
3445        }
3446
3447        if let crate::TypeInner::Matrix {
3448            columns: _,
3449            rows,
3450            scalar,
3451        } = *member_array_subty_inner
3452        {
3453            let byte_stride = Alignment::from(rows) * scalar.width as u32;
3454            self.annotations.push(Instruction::member_decorate(
3455                struct_id,
3456                index as u32,
3457                Decoration::ColMajor,
3458                &[],
3459            ));
3460            self.annotations.push(Instruction::member_decorate(
3461                struct_id,
3462                index as u32,
3463                Decoration::MatrixStride,
3464                &[byte_stride],
3465            ));
3466        }
3467
3468        Ok(())
3469    }
3470
3471    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
3472        match self
3473            .lookup_function_type
3474            .entry(lookup_function_type.clone())
3475        {
3476            Entry::Occupied(e) => *e.get(),
3477            Entry::Vacant(_) => {
3478                let id = self.id_gen.next();
3479                let instruction = Instruction::type_function(
3480                    id,
3481                    lookup_function_type.return_type_id,
3482                    &lookup_function_type.parameter_type_ids,
3483                );
3484                instruction.to_words(&mut self.logical_layout.declarations);
3485                self.lookup_function_type.insert(lookup_function_type, id);
3486                id
3487            }
3488        }
3489    }
3490
3491    fn write_physical_layout(&mut self) {
3492        self.physical_layout.bound = self.id_gen.0 + 1;
3493    }
3494
3495    fn write_logical_layout(
3496        &mut self,
3497        ir_module: &crate::Module,
3498        mod_info: &ModuleInfo,
3499        ep_index: Option<usize>,
3500        debug_info: &Option<DebugInfo>,
3501    ) -> Result<(), Error> {
3502        fn has_view_index_check(
3503            ir_module: &crate::Module,
3504            binding: Option<&crate::Binding>,
3505            ty: Handle<crate::Type>,
3506        ) -> bool {
3507            match ir_module.types[ty].inner {
3508                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
3509                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
3510                }),
3511                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
3512            }
3513        }
3514
3515        let has_storage_buffers =
3516            ir_module
3517                .global_variables
3518                .iter()
3519                .any(|(_, var)| match var.space {
3520                    crate::AddressSpace::Storage { .. } => true,
3521                    _ => false,
3522                });
3523        let has_view_index = ir_module
3524            .entry_points
3525            .iter()
3526            .flat_map(|entry| entry.function.arguments.iter())
3527            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
3528        let mut has_ray_query = ir_module.special_types.ray_desc.is_some()
3529            | ir_module.special_types.ray_intersection.is_some();
3530        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
3531
3532        // Ways mesh shaders are required:
3533        // * Mesh entry point used - checked for
3534        // * Mesh function like setVertex used outside mesh entry point, this is handled when those are written
3535        // * Fragment shader with per primitive data - handled in `map_binding`
3536        let has_mesh_shaders = ir_module.entry_points.iter().any(|entry| {
3537            entry.stage == crate::ShaderStage::Mesh || entry.stage == crate::ShaderStage::Task
3538        }) || ir_module
3539            .global_variables
3540            .iter()
3541            .any(|gvar| gvar.1.space == crate::AddressSpace::TaskPayload);
3542
3543        for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() {
3544            // spirv does not know whether these have vertex return - that is done by us
3545            if let &crate::TypeInner::AccelerationStructure { .. }
3546            | &crate::TypeInner::RayQuery { .. } = inner
3547            {
3548                has_ray_query = true
3549            }
3550        }
3551
3552        if self.physical_layout.version < 0x10300 && has_storage_buffers {
3553            // enable the storage buffer class on < SPV-1.3
3554            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
3555                .to_words(&mut self.logical_layout.extensions);
3556        }
3557        if has_view_index {
3558            Instruction::extension("SPV_KHR_multiview")
3559                .to_words(&mut self.logical_layout.extensions)
3560        }
3561        if has_ray_query {
3562            Instruction::extension("SPV_KHR_ray_query")
3563                .to_words(&mut self.logical_layout.extensions)
3564        }
3565        if has_vertex_return {
3566            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
3567                .to_words(&mut self.logical_layout.extensions);
3568        }
3569        if has_mesh_shaders {
3570            self.require_mesh_shaders()?;
3571        }
3572        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
3573        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
3574            .to_words(&mut self.logical_layout.ext_inst_imports);
3575
3576        let mut debug_info_inner = None;
3577        if self.flags.contains(WriterFlags::DEBUG) {
3578            if let Some(debug_info) = debug_info.as_ref() {
3579                let source_file_id = self.id_gen.next();
3580                self.debugs
3581                    .push(Instruction::string(debug_info.file_name, source_file_id));
3582
3583                debug_info_inner = Some(DebugInfoInner {
3584                    source_code: debug_info.source_code,
3585                    source_file_id,
3586                });
3587                self.debugs.append(&mut Instruction::source_auto_continued(
3588                    debug_info.language,
3589                    0,
3590                    &debug_info_inner,
3591                ));
3592            }
3593        }
3594
3595        // write all types
3596        for (handle, _) in ir_module.types.iter() {
3597            self.write_type_declaration_arena(ir_module, handle)?;
3598        }
3599
3600        // write std140 layout compatible types required by uniforms
3601        for (_, var) in ir_module.global_variables.iter() {
3602            if var.space == crate::AddressSpace::Uniform {
3603                self.write_std140_compat_type_declaration(ir_module, var.ty)?;
3604            }
3605        }
3606
3607        // write all const-expressions as constants
3608        self.constant_ids
3609            .resize(ir_module.global_expressions.len(), 0);
3610        for (handle, _) in ir_module.global_expressions.iter() {
3611            self.write_constant_expr(handle, ir_module, mod_info)?;
3612        }
3613        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
3614
3615        // write the name of constants on their respective const-expression initializer
3616        if self.flags.contains(WriterFlags::DEBUG) {
3617            for (_, constant) in ir_module.constants.iter() {
3618                if let Some(ref name) = constant.name {
3619                    let id = self.constant_ids[constant.init];
3620                    self.debugs.push(Instruction::name(id, name));
3621                }
3622            }
3623        }
3624
3625        // write all global variables
3626        for (handle, var) in ir_module.global_variables.iter() {
3627            // If a single entry point was specified, only write `OpVariable` instructions
3628            // for the globals it actually uses. Emit dummies for the others,
3629            // to preserve the indices in `global_variables`.
3630            let gvar = match ep_index {
3631                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
3632                    GlobalVariable::dummy()
3633                }
3634                _ => {
3635                    let id = self.write_global_variable(ir_module, var)?;
3636                    GlobalVariable::new(id)
3637                }
3638            };
3639            self.global_variables.insert(handle, gvar);
3640        }
3641
3642        // write all functions
3643        for (handle, ir_function) in ir_module.functions.iter() {
3644            let info = &mod_info[handle];
3645            if let Some(index) = ep_index {
3646                let ep_info = mod_info.get_entry_point(index);
3647                // If this function uses globals that we omitted from the SPIR-V
3648                // because the entry point and its callees didn't use them,
3649                // then we must skip it.
3650                if !ep_info.dominates_global_use(info) {
3651                    log::debug!("Skip function {:?}", ir_function.name);
3652                    continue;
3653                }
3654
3655                // Skip functions that that are not compatible with this entry point's stage.
3656                //
3657                // When validation is enabled, it rejects modules whose entry points try to call
3658                // incompatible functions, so if we got this far, then any functions incompatible
3659                // with our selected entry point must not be used.
3660                //
3661                // When validation is disabled, `fun_info.available_stages` is always just
3662                // `ShaderStages::all()`, so this will write all functions in the module, and
3663                // the downstream GLSL compiler will catch any problems.
3664                if !info.available_stages.contains(ep_info.available_stages) {
3665                    continue;
3666                }
3667            }
3668            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
3669            self.lookup_function.insert(handle, id);
3670        }
3671
3672        // write all or one entry points
3673        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
3674            if ep_index.is_some() && ep_index != Some(index) {
3675                continue;
3676            }
3677            let info = mod_info.get_entry_point(index);
3678            let ep_instruction =
3679                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
3680            ep_instruction.to_words(&mut self.logical_layout.entry_points);
3681        }
3682
3683        for capability in self.capabilities_used.iter() {
3684            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
3685        }
3686        for extension in self.extensions_used.iter() {
3687            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
3688        }
3689        if ir_module.entry_points.is_empty() {
3690            // SPIR-V doesn't like modules without entry points
3691            Instruction::capability(spirv::Capability::Linkage)
3692                .to_words(&mut self.logical_layout.capabilities);
3693        }
3694
3695        let addressing_model = spirv::AddressingModel::Logical;
3696        let memory_model = if self
3697            .capabilities_used
3698            .contains(&spirv::Capability::VulkanMemoryModel)
3699        {
3700            spirv::MemoryModel::Vulkan
3701        } else {
3702            spirv::MemoryModel::GLSL450
3703        };
3704        //self.check(addressing_model.required_capabilities())?;
3705        //self.check(memory_model.required_capabilities())?;
3706
3707        Instruction::memory_model(addressing_model, memory_model)
3708            .to_words(&mut self.logical_layout.memory_model);
3709
3710        for debug_string in self.debug_strings.iter() {
3711            debug_string.to_words(&mut self.logical_layout.debugs);
3712        }
3713
3714        if self.flags.contains(WriterFlags::DEBUG) {
3715            for debug in self.debugs.iter() {
3716                debug.to_words(&mut self.logical_layout.debugs);
3717            }
3718        }
3719
3720        for annotation in self.annotations.iter() {
3721            annotation.to_words(&mut self.logical_layout.annotations);
3722        }
3723
3724        Ok(())
3725    }
3726
3727    pub fn write(
3728        &mut self,
3729        ir_module: &crate::Module,
3730        info: &ModuleInfo,
3731        pipeline_options: Option<&PipelineOptions>,
3732        debug_info: &Option<DebugInfo>,
3733        words: &mut Vec<Word>,
3734    ) -> Result<(), Error> {
3735        self.reset();
3736
3737        // Try to find the entry point and corresponding index
3738        let ep_index = match pipeline_options {
3739            Some(po) => {
3740                let index = ir_module
3741                    .entry_points
3742                    .iter()
3743                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
3744                    .ok_or(Error::EntryPointNotFound)?;
3745                Some(index)
3746            }
3747            None => None,
3748        };
3749
3750        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
3751        self.write_physical_layout();
3752
3753        self.physical_layout.in_words(words);
3754        self.logical_layout.in_words(words);
3755        Ok(())
3756    }
3757
    /// Return the set of capabilities the last module written used.
    ///
    /// These are the capabilities emitted as `OpCapability` instructions from the
    /// writer's tracked set. Note that the `Linkage` capability, which is added when
    /// the module has no entry points, is written directly and is not part of
    /// this set.
    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
        &self.capabilities_used
    }
3762
3763    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
3764        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
3765        self.use_extension("SPV_EXT_descriptor_indexing");
3766        self.decorate(id, spirv::Decoration::NonUniform, &[]);
3767        Ok(())
3768    }
3769
    /// Report whether values of type `ty_inner` need the f16 I/O polyfill.
    ///
    /// This simply delegates to `self.io_f16_polyfills`. Presumably it is true for
    /// f16-based types used as shader inputs/outputs; confirm against
    /// `needs_polyfill`.
    pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool {
        self.io_f16_polyfills.needs_polyfill(ty_inner)
    }
3773
3774    pub(super) fn write_debug_printf(
3775        &mut self,
3776        block: &mut Block,
3777        string: &str,
3778        format_params: &[Word],
3779    ) {
3780        if self.debug_printf.is_none() {
3781            self.use_extension("SPV_KHR_non_semantic_info");
3782            let import_id = self.id_gen.next();
3783            Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf")
3784                .to_words(&mut self.logical_layout.ext_inst_imports);
3785            self.debug_printf = Some(import_id)
3786        }
3787
3788        let import_id = self.debug_printf.unwrap();
3789
3790        let string_id = self.id_gen.next();
3791        self.debug_strings
3792            .push(Instruction::string(string, string_id));
3793
3794        let mut operands = Vec::with_capacity(1 + format_params.len());
3795        operands.push(string_id);
3796        operands.extend(format_params.iter());
3797
3798        let print_id = self.id_gen.next();
3799        block.body.push(Instruction::ext_inst(
3800            import_id,
3801            1,
3802            self.void_type,
3803            print_id,
3804            &operands,
3805        ));
3806    }
3807}
3808
/// A freshly constructed writer reports an id bound of 0 until its physical
/// layout is written. After that, the bound reflects the ids the constructor
/// already allocated, which yields 3.
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();

    let bound_before = writer.physical_layout.bound;
    assert_eq!(bound_before, 0);

    writer.write_physical_layout();

    let bound_after = writer.physical_layout.bound;
    assert_eq!(bound_after, 3);
}