naga/back/spv/
writer.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use arrayvec::ArrayVec;
4use hashbrown::hash_map::Entry;
5use spirv::Word;
6
7use super::{
8    block::DebugInfoInner,
9    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
10    Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
11    EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
12    LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
13    NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
14    BITS_PER_BYTE,
15};
16use crate::{
17    arena::{Handle, HandleVec, UniqueArena},
18    back::spv::{
19        helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations},
20        BindingInfo, Std140CompatTypeInfo, WrappedFunction,
21    },
22    common::ForDebugWithTypes as _,
23    proc::{Alignment, TypeResolution},
24    valid::{FunctionInfo, ModuleInfo},
25};
26
/// Interface information handed to the writer while it emits a single function.
///
/// NOTE(review): this struct is not used in the visible part of this file; the
/// field descriptions below are based on names and types only — confirm them
/// against the code that constructs it.
pub struct FunctionInterface<'a> {
    /// Mutable list of SPIR-V ids; presumably collects the ids of the interface
    /// (varying) variables the function uses — verify against callers.
    pub varying_ids: &'a mut Vec<Word>,
    /// Shader stage the function is being written for.
    pub stage: crate::ShaderStage,
    /// Global variable serving as the task payload, if any (presumably only
    /// meaningful for task/mesh stages — confirm).
    pub task_payload: Option<Handle<crate::GlobalVariable>>,
    /// Mesh-stage information, if any.
    pub mesh_info: Option<crate::MeshStageInfo>,
    /// Workgroup size; assumed to be `[x, y, z]` — confirm at the construction site.
    pub workgroup_size: [u32; 3],
}
34
35impl Function {
36    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
37        self.signature.as_ref().unwrap().to_words(sink);
38        for argument in self.parameters.iter() {
39            argument.instruction.to_words(sink);
40        }
41        for (index, block) in self.blocks.iter().enumerate() {
42            Instruction::label(block.label_id).to_words(sink);
43            if index == 0 {
44                for local_var in self.variables.values() {
45                    local_var.instruction.to_words(sink);
46                }
47                for local_var in self.ray_query_initialization_tracker_variables.values() {
48                    local_var.instruction.to_words(sink);
49                }
50                for local_var in self.ray_query_t_max_tracker_variables.values() {
51                    local_var.instruction.to_words(sink);
52                }
53                for local_var in self.force_loop_bounding_vars.iter() {
54                    local_var.instruction.to_words(sink);
55                }
56                for internal_var in self.spilled_composites.values() {
57                    internal_var.instruction.to_words(sink);
58                }
59            }
60            for instruction in block.body.iter() {
61                instruction.to_words(sink);
62            }
63        }
64        Instruction::function_end().to_words(sink);
65    }
66}
67
68impl Writer {
69    pub fn new(options: &Options) -> Result<Self, Error> {
70        let (major, minor) = options.lang_version;
71        if major != 1 {
72            return Err(Error::UnsupportedVersion(major, minor));
73        }
74
75        let mut capabilities_used = crate::FastIndexSet::default();
76        capabilities_used.insert(spirv::Capability::Shader);
77
78        let mut id_gen = IdGenerator::default();
79        let gl450_ext_inst_id = id_gen.next();
80        let void_type = id_gen.next();
81
82        Ok(Writer {
83            physical_layout: PhysicalLayout::new(major, minor),
84            logical_layout: LogicalLayout::default(),
85            id_gen,
86            capabilities_available: options.capabilities.clone(),
87            capabilities_used,
88            extensions_used: crate::FastIndexSet::default(),
89            debug_strings: vec![],
90            debugs: vec![],
91            annotations: vec![],
92            flags: options.flags,
93            bounds_check_policies: options.bounds_check_policies,
94            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
95            force_loop_bounding: options.force_loop_bounding,
96            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
97            use_storage_input_output_16: options.use_storage_input_output_16,
98            void_type,
99            tuple_of_u32s_ty_id: None,
100            lookup_type: crate::FastHashMap::default(),
101            lookup_function: crate::FastHashMap::default(),
102            lookup_function_type: crate::FastHashMap::default(),
103            wrapped_functions: crate::FastHashMap::default(),
104            constant_ids: HandleVec::new(),
105            cached_constants: crate::FastHashMap::default(),
106            global_variables: HandleVec::new(),
107            std140_compat_uniform_types: crate::FastHashMap::default(),
108            fake_missing_bindings: options.fake_missing_bindings,
109            binding_map: options.binding_map.clone(),
110            saved_cached: CachedExpressions::default(),
111            gl450_ext_inst_id,
112            temp_list: Vec::new(),
113            ray_query_functions: crate::FastHashMap::default(),
114            io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new(
115                options.use_storage_input_output_16,
116            ),
117            debug_printf: None,
118            task_dispatch_limits: options.task_dispatch_limits,
119            mesh_shader_primitive_indices_clamp: options.mesh_shader_primitive_indices_clamp,
120        })
121    }
122
123    pub fn set_options(&mut self, options: &Options) -> Result<(), Error> {
124        let (major, minor) = options.lang_version;
125        if major != 1 {
126            return Err(Error::UnsupportedVersion(major, minor));
127        }
128        self.physical_layout = PhysicalLayout::new(major, minor);
129        self.capabilities_available = options.capabilities.clone();
130        self.flags = options.flags;
131        self.bounds_check_policies = options.bounds_check_policies;
132        self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory;
133        self.force_loop_bounding = options.force_loop_bounding;
134        self.use_storage_input_output_16 = options.use_storage_input_output_16;
135        self.binding_map = options.binding_map.clone();
136        self.io_f16_polyfills =
137            super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16);
138        self.task_dispatch_limits = options.task_dispatch_limits;
139        self.mesh_shader_primitive_indices_clamp = options.mesh_shader_primitive_indices_clamp;
140        Ok(())
141    }
142
143    /// Returns `(major, minor)` of the SPIR-V language version.
144    pub const fn lang_version(&self) -> (u8, u8) {
145        self.physical_layout.lang_version()
146    }
147
148    /// Reset `Writer` to its initial state, retaining any allocations.
149    ///
150    /// Why not just implement `Reclaimable` for `Writer`? By design,
151    /// `Reclaimable::reclaim` requires ownership of the value, not just
152    /// `&mut`; see the trait documentation. But we need to use this method
153    /// from functions like `Writer::write`, which only have `&mut Writer`.
154    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
155    /// or something like a `Default` impl that returns an oddly-initialized
156    /// `Writer`, which is worse.
157    fn reset(&mut self) {
158        use super::reclaimable::Reclaimable;
159        use core::mem::take;
160
161        let mut id_gen = IdGenerator::default();
162        let gl450_ext_inst_id = id_gen.next();
163        let void_type = id_gen.next();
164
165        // Every field of the old writer that is not determined by the `Options`
166        // passed to `Writer::new` should be reset somehow.
167        let fresh = Writer {
168            // Copied from the old Writer:
169            flags: self.flags,
170            bounds_check_policies: self.bounds_check_policies,
171            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
172            force_loop_bounding: self.force_loop_bounding,
173            ray_query_initialization_tracking: self.ray_query_initialization_tracking,
174            use_storage_input_output_16: self.use_storage_input_output_16,
175            capabilities_available: take(&mut self.capabilities_available),
176            fake_missing_bindings: self.fake_missing_bindings,
177            binding_map: take(&mut self.binding_map),
178            task_dispatch_limits: self.task_dispatch_limits,
179            mesh_shader_primitive_indices_clamp: self.mesh_shader_primitive_indices_clamp,
180
181            // Initialized afresh:
182            id_gen,
183            void_type,
184            tuple_of_u32s_ty_id: None,
185            gl450_ext_inst_id,
186
187            // Reclaimed:
188            capabilities_used: take(&mut self.capabilities_used).reclaim(),
189            extensions_used: take(&mut self.extensions_used).reclaim(),
190            physical_layout: self.physical_layout.clone().reclaim(),
191            logical_layout: take(&mut self.logical_layout).reclaim(),
192            debug_strings: take(&mut self.debug_strings).reclaim(),
193            debugs: take(&mut self.debugs).reclaim(),
194            annotations: take(&mut self.annotations).reclaim(),
195            lookup_type: take(&mut self.lookup_type).reclaim(),
196            lookup_function: take(&mut self.lookup_function).reclaim(),
197            lookup_function_type: take(&mut self.lookup_function_type).reclaim(),
198            wrapped_functions: take(&mut self.wrapped_functions).reclaim(),
199            constant_ids: take(&mut self.constant_ids).reclaim(),
200            cached_constants: take(&mut self.cached_constants).reclaim(),
201            global_variables: take(&mut self.global_variables).reclaim(),
202            std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).reclaim(),
203            saved_cached: take(&mut self.saved_cached).reclaim(),
204            temp_list: take(&mut self.temp_list).reclaim(),
205            ray_query_functions: take(&mut self.ray_query_functions).reclaim(),
206            io_f16_polyfills: take(&mut self.io_f16_polyfills).reclaim(),
207            debug_printf: None,
208        };
209
210        *self = fresh;
211
212        self.capabilities_used.insert(spirv::Capability::Shader);
213    }
214
215    /// Indicate that the code requires any one of the listed capabilities.
216    ///
217    /// If nothing in `capabilities` appears in the available capabilities
218    /// specified in the [`Options`] from which this `Writer` was created,
219    /// return an error. The `what` string is used in the error message to
220    /// explain what provoked the requirement. (If no available capabilities were
221    /// given, assume everything is available.)
222    ///
223    /// The first acceptable capability will be added to this `Writer`'s
224    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
225    /// result. For this reason, more specific capabilities should be listed
226    /// before more general.
227    ///
228    /// [`capabilities_used`]: Writer::capabilities_used
229    pub(super) fn require_any(
230        &mut self,
231        what: &'static str,
232        capabilities: &[spirv::Capability],
233    ) -> Result<(), Error> {
234        match *capabilities {
235            [] => Ok(()),
236            [first, ..] => {
237                // Find the first acceptable capability, or return an error if
238                // there is none.
239                let selected = match self.capabilities_available {
240                    None => first,
241                    Some(ref available) => {
242                        match capabilities
243                            .iter()
244                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
245                            .find(|cap| available.contains::<spirv::Capability>(cap))
246                        {
247                            Some(&cap) => cap,
248                            None => {
249                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
250                            }
251                        }
252                    }
253                };
254                self.capabilities_used.insert(selected);
255                Ok(())
256            }
257        }
258    }
259
260    /// Indicate that the code requires all of the listed capabilities.
261    ///
262    /// If all entries of `capabilities` appear in the available capabilities
263    /// specified in the [`Options`] from which this `Writer` was created
264    /// (including the case where [`Options::capabilities`] is `None`), add
265    /// them all to this `Writer`'s [`capabilities_used`] table, and return
266    /// `Ok(())`. If at least one of the listed capabilities is not available,
267    /// do not add anything to the `capabilities_used` table, and return the
268    /// first unavailable requested capability, wrapped in `Err()`.
269    ///
270    /// This method is does not return an [`enum@Error`] in case of failure
271    /// because it may be used in cases where the caller can recover (e.g.,
272    /// with a polyfill) if the requested capabilities are not available. In
273    /// this case, it would be unnecessary work to find *all* the unavailable
274    /// requested capabilities, and to allocate a `Vec` for them, just so we
275    /// could return an [`Error::MissingCapabilities`]).
276    ///
277    /// [`capabilities_used`]: Writer::capabilities_used
278    pub(super) fn require_all(
279        &mut self,
280        capabilities: &[spirv::Capability],
281    ) -> Result<(), spirv::Capability> {
282        if let Some(ref available) = self.capabilities_available {
283            for requested in capabilities {
284                if !available.contains(requested) {
285                    return Err(*requested);
286                }
287            }
288        }
289
290        for requested in capabilities {
291            self.capabilities_used.insert(*requested);
292        }
293
294        Ok(())
295    }
296
297    /// Indicate that the code uses the given extension.
298    pub(super) fn use_extension(&mut self, extension: &'static str) {
299        self.extensions_used.insert(extension);
300    }
301
302    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
303        match self.lookup_type.entry(lookup_ty) {
304            Entry::Occupied(e) => *e.get(),
305            Entry::Vacant(e) => {
306                let local = match lookup_ty {
307                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
308                    LookupType::Local(local) => local,
309                };
310
311                let id = self.id_gen.next();
312                e.insert(id);
313                self.write_type_declaration_local(id, local);
314                id
315            }
316        }
317    }
318
319    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
320        self.get_type_id(LookupType::Handle(handle))
321    }
322
323    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
324        match *tr {
325            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
326            TypeResolution::Value(ref inner) => {
327                let inner_local_type = self.localtype_from_inner(inner).unwrap();
328                LookupType::Local(inner_local_type)
329            }
330        }
331    }
332
333    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
334        let lookup_ty = self.get_expression_lookup_type(tr);
335        self.get_type_id(lookup_ty)
336    }
337
338    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
339        self.get_type_id(LookupType::Local(local))
340    }
341
342    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
343        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
344    }
345
346    pub(super) fn get_handle_pointer_type_id(
347        &mut self,
348        base: Handle<crate::Type>,
349        class: spirv::StorageClass,
350    ) -> Word {
351        let base_id = self.get_handle_type_id(base);
352        self.get_pointer_type_id(base_id, class)
353    }
354
355    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
356        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
357        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
358    }
359
360    /// Return a SPIR-V type for a pointer to `resolution`.
361    ///
362    /// The given `resolution` must be one that we can represent
363    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
364    pub(super) fn get_resolution_pointer_id(
365        &mut self,
366        resolution: &TypeResolution,
367        class: spirv::StorageClass,
368    ) -> Word {
369        let resolution_type_id = self.get_expression_type_id(resolution);
370        self.get_pointer_type_id(resolution_type_id, class)
371    }
372
373    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
374        self.get_type_id(LocalType::Numeric(numeric).into())
375    }
376
377    pub(super) fn get_u32_type_id(&mut self) -> Word {
378        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
379    }
380
381    pub(super) fn get_f32_type_id(&mut self) -> Word {
382        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
383    }
384
385    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
386        self.get_numeric_type_id(NumericType::Vector {
387            size: crate::VectorSize::Bi,
388            scalar: crate::Scalar::U32,
389        })
390    }
391
392    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
393        self.get_numeric_type_id(NumericType::Vector {
394            size: crate::VectorSize::Bi,
395            scalar: crate::Scalar::F32,
396        })
397    }
398
399    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
400        self.get_numeric_type_id(NumericType::Vector {
401            size: crate::VectorSize::Tri,
402            scalar: crate::Scalar::U32,
403        })
404    }
405
406    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
407        let f32_id = self.get_f32_type_id();
408        self.get_pointer_type_id(f32_id, class)
409    }
410
411    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
412        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
413            size: crate::VectorSize::Bi,
414            scalar: crate::Scalar::U32,
415        });
416        self.get_pointer_type_id(vec2u_id, class)
417    }
418
419    pub(super) fn get_bool_type_id(&mut self) -> Word {
420        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
421    }
422
423    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
424        self.get_numeric_type_id(NumericType::Vector {
425            size: crate::VectorSize::Bi,
426            scalar: crate::Scalar::BOOL,
427        })
428    }
429
430    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
431        self.get_numeric_type_id(NumericType::Vector {
432            size: crate::VectorSize::Tri,
433            scalar: crate::Scalar::BOOL,
434        })
435    }
436
437    /// Used for "mulhi" to get the upper bits of multiplication.
438    ///
439    /// More specifically, `OpUMulExtended` multiplies 2 numbers and returns the lower and upper bits of the result
440    /// as a user-defined struct type with 2 u32s. This defines that struct.
441    pub(super) fn get_tuple_of_u32s_ty_id(&mut self) -> Word {
442        if let Some(val) = self.tuple_of_u32s_ty_id {
443            val
444        } else {
445            let id = self.id_gen.next();
446            let u32_id = self.get_u32_type_id();
447            let ins = Instruction::type_struct(id, &[u32_id, u32_id]);
448            ins.to_words(&mut self.logical_layout.declarations);
449            self.tuple_of_u32s_ty_id = Some(id);
450            id
451        }
452    }
453
454    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
455        self.annotations
456            .push(Instruction::decorate(id, decoration, operands));
457    }
458
459    /// Return `inner` as a `LocalType`, if that's possible.
460    ///
461    /// If `inner` can be represented as a `LocalType`, return
462    /// `Some(local_type)`.
463    ///
464    /// Otherwise, return `None`. In this case, the type must always be looked
465    /// up using a `LookupType::Handle`.
466    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
467        Some(match *inner {
468            crate::TypeInner::Scalar(_)
469            | crate::TypeInner::Atomic(_)
470            | crate::TypeInner::Vector { .. }
471            | crate::TypeInner::Matrix { .. } => {
472                // We expect `NumericType::from_inner` to handle all
473                // these cases, so unwrap.
474                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
475            }
476            crate::TypeInner::CooperativeMatrix { .. } => {
477                LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
478            }
479            crate::TypeInner::Pointer { base, space } => {
480                let base_type_id = self.get_handle_type_id(base);
481                LocalType::Pointer {
482                    base: base_type_id,
483                    class: map_storage_class(space),
484                }
485            }
486            crate::TypeInner::ValuePointer {
487                size,
488                scalar,
489                space,
490            } => {
491                let base_numeric_type = match size {
492                    Some(size) => NumericType::Vector { size, scalar },
493                    None => NumericType::Scalar(scalar),
494                };
495                LocalType::Pointer {
496                    base: self.get_numeric_type_id(base_numeric_type),
497                    class: map_storage_class(space),
498                }
499            }
500            crate::TypeInner::Image {
501                dim,
502                arrayed,
503                class,
504            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
505            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
506            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
507            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
508            crate::TypeInner::Array { .. }
509            | crate::TypeInner::Struct { .. }
510            | crate::TypeInner::BindingArray { .. } => return None,
511        })
512    }
513
514    /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the
515    /// provided [`Writer::binding_map`].
516    ///
517    /// If the specified resource is not present in the binding map this will
518    /// return an error, unless [`Writer::fake_missing_bindings`] is set.
519    fn resolve_resource_binding(
520        &self,
521        res_binding: &crate::ResourceBinding,
522    ) -> Result<BindingInfo, Error> {
523        match self.binding_map.get(res_binding) {
524            Some(target) => Ok(*target),
525            None if self.fake_missing_bindings => Ok(BindingInfo {
526                descriptor_set: res_binding.group,
527                binding: res_binding.binding,
528                binding_array_size: None,
529            }),
530            None => Err(Error::MissingBinding(*res_binding)),
531        }
532    }
533
534    /// Emits code for any wrapper functions required by the expressions in ir_function.
535    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
536    fn write_wrapped_functions(
537        &mut self,
538        ir_function: &crate::Function,
539        info: &FunctionInfo,
540        ir_module: &crate::Module,
541    ) -> Result<(), Error> {
542        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
543
544        for (expr_handle, expr) in ir_function.expressions.iter() {
545            match *expr {
546                crate::Expression::Binary { op, left, right } => {
547                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
548                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
549                        match (op, expr_ty.scalar().kind) {
550                            // Division and modulo are undefined behaviour when the
551                            // dividend is the minimum representable value and the divisor
552                            // is negative one, or when the divisor is zero. These wrapped
553                            // functions override the divisor to one in these cases,
554                            // matching the WGSL spec.
555                            (
556                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
557                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
558                            ) => {
559                                self.write_wrapped_binary_op(
560                                    op,
561                                    expr_ty,
562                                    &info[left].ty,
563                                    &info[right].ty,
564                                )?;
565                            }
566                            _ => {}
567                        }
568                    }
569                }
570                crate::Expression::Load { pointer } => {
571                    if let crate::TypeInner::Pointer {
572                        base: pointer_type,
573                        space: crate::AddressSpace::Uniform,
574                    } = *info[pointer].ty.inner_with(&ir_module.types)
575                    {
576                        if self.std140_compat_uniform_types.contains_key(&pointer_type) {
577                            // Loading a std140 compat type requires the wrapper function
578                            // to convert to the regular type.
579                            self.write_wrapped_convert_from_std140_compat_type(
580                                ir_module,
581                                pointer_type,
582                            )?;
583                        }
584                    }
585                }
586                crate::Expression::Access { base, .. } => {
587                    if let crate::TypeInner::Pointer {
588                        base: base_type,
589                        space: crate::AddressSpace::Uniform,
590                    } = *info[base].ty.inner_with(&ir_module.types)
591                    {
592                        // Dynamic accesses of a two-row matrix's columns require a
593                        // wrapper function.
594                        if let crate::TypeInner::Matrix {
595                            rows: crate::VectorSize::Bi,
596                            ..
597                        } = ir_module.types[base_type].inner
598                        {
599                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
600                            // If the matrix is *not* directly a member of a struct, then
601                            // we additionally require a wrapper function to convert from
602                            // the std140 compat type to the regular type.
603                            if !is_uniform_matcx2_struct_member_access(
604                                ir_function,
605                                info,
606                                ir_module,
607                                base,
608                            ) {
609                                self.write_wrapped_convert_from_std140_compat_type(
610                                    ir_module, base_type,
611                                )?;
612                            }
613                        }
614                    }
615                }
616                _ => {}
617            }
618        }
619
620        Ok(())
621    }
622
623    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
624    ///
    /// Define a function that performs an integer division or modulo operation,
    /// except that a divisor of zero, or a divisor of -1 that would cause signed
    /// overflow, is replaced by one rather than exhibiting undefined behavior.
    /// In those cases division returns the numerator unchanged and modulo returns zero.
629    ///
630    /// Store the generated function's id in the [`wrapped_functions`] table.
631    ///
632    /// The operator `op` must be either [`Divide`] or [`Modulo`].
633    ///
634    /// # Panics
635    ///
636    /// The `return_type`, `left_type` or `right_type` arguments must all be
637    /// integer scalars or vectors. If not, this function panics.
638    ///
639    /// [`wrapped_functions`]: Writer::wrapped_functions
640    /// [`Divide`]: crate::BinaryOperator::Divide
641    /// [`Modulo`]: crate::BinaryOperator::Modulo
642    fn write_wrapped_binary_op(
643        &mut self,
644        op: crate::BinaryOperator,
645        return_type: NumericType,
646        left_type: &TypeResolution,
647        right_type: &TypeResolution,
648    ) -> Result<(), Error> {
649        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
650        let left_type_id = self.get_expression_type_id(left_type);
651        let right_type_id = self.get_expression_type_id(right_type);
652
653        // Check if we've already emitted this function.
654        let wrapped = WrappedFunction::BinaryOp {
655            op,
656            left_type_id,
657            right_type_id,
658        };
659        let function_id = match self.wrapped_functions.entry(wrapped) {
660            Entry::Occupied(_) => return Ok(()),
661            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
662        };
663
664        let scalar = return_type.scalar();
665
666        if self.flags.contains(WriterFlags::DEBUG) {
667            let function_name = match op {
668                crate::BinaryOperator::Divide => "naga_div",
669                crate::BinaryOperator::Modulo => "naga_mod",
670                _ => unreachable!(),
671            };
672            self.debugs
673                .push(Instruction::name(function_id, function_name));
674        }
675        let mut function = Function::default();
676
677        let function_type_id = self.get_function_type(LookupFunctionType {
678            parameter_type_ids: vec![left_type_id, right_type_id],
679            return_type_id,
680        });
681        function.signature = Some(Instruction::function(
682            return_type_id,
683            function_id,
684            spirv::FunctionControl::empty(),
685            function_type_id,
686        ));
687
688        let lhs_id = self.id_gen.next();
689        let rhs_id = self.id_gen.next();
690        if self.flags.contains(WriterFlags::DEBUG) {
691            self.debugs.push(Instruction::name(lhs_id, "lhs"));
692            self.debugs.push(Instruction::name(rhs_id, "rhs"));
693        }
694        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
695        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
696        for instruction in [left_par, right_par] {
697            function.parameters.push(FunctionArgument {
698                instruction,
699                handle_id: 0,
700            });
701        }
702
703        let label_id = self.id_gen.next();
704        let mut block = Block::new(label_id);
705
706        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
707        let bool_type_id = self.get_numeric_type_id(bool_type);
708
709        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
710            NumericType::Scalar(_) => const_id,
711            NumericType::Vector { size, .. } => {
712                let constituent_ids = [const_id; crate::VectorSize::MAX];
713                writer.get_constant_composite(
714                    LookupType::Local(LocalType::Numeric(return_type)),
715                    &constituent_ids[..size as usize],
716                )
717            }
718            NumericType::Matrix { .. } => unreachable!(),
719        };
720
721        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
722        let composite_zero_id = maybe_splat_const(self, const_zero_id);
723        let rhs_eq_zero_id = self.id_gen.next();
724        block.body.push(Instruction::binary(
725            spirv::Op::IEqual,
726            bool_type_id,
727            rhs_eq_zero_id,
728            rhs_id,
729            composite_zero_id,
730        ));
731        let divisor_selector_id = match scalar.kind {
732            crate::ScalarKind::Sint => {
733                let (const_min_id, const_neg_one_id) = match scalar.width {
734                    4 => Ok((
735                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
736                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
737                    )),
738                    8 => Ok((
739                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
740                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
741                    )),
742                    _ => Err(Error::Validation("Unexpected scalar width")),
743                }?;
744                let composite_min_id = maybe_splat_const(self, const_min_id);
745                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
746
747                let lhs_eq_int_min_id = self.id_gen.next();
748                block.body.push(Instruction::binary(
749                    spirv::Op::IEqual,
750                    bool_type_id,
751                    lhs_eq_int_min_id,
752                    lhs_id,
753                    composite_min_id,
754                ));
755                let rhs_eq_neg_one_id = self.id_gen.next();
756                block.body.push(Instruction::binary(
757                    spirv::Op::IEqual,
758                    bool_type_id,
759                    rhs_eq_neg_one_id,
760                    rhs_id,
761                    composite_neg_one_id,
762                ));
763                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
764                block.body.push(Instruction::binary(
765                    spirv::Op::LogicalAnd,
766                    bool_type_id,
767                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
768                    lhs_eq_int_min_id,
769                    rhs_eq_neg_one_id,
770                ));
771                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
772                block.body.push(Instruction::binary(
773                    spirv::Op::LogicalOr,
774                    bool_type_id,
775                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
776                    rhs_eq_zero_id,
777                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
778                ));
779                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
780            }
781            crate::ScalarKind::Uint => rhs_eq_zero_id,
782            _ => unreachable!(),
783        };
784
785        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
786        let composite_one_id = maybe_splat_const(self, const_one_id);
787        let divisor_id = self.id_gen.next();
788        block.body.push(Instruction::select(
789            right_type_id,
790            divisor_id,
791            divisor_selector_id,
792            composite_one_id,
793            rhs_id,
794        ));
795        let op = match (op, scalar.kind) {
796            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
797            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
798            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
799            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
800            _ => unreachable!(),
801        };
802        let return_id = self.id_gen.next();
803        block.body.push(Instruction::binary(
804            op,
805            return_type_id,
806            return_id,
807            lhs_id,
808            divisor_id,
809        ));
810
811        function.consume(block, Instruction::return_value(return_id));
812        function.to_words(&mut self.logical_layout.function_definitions);
813        Ok(())
814    }
815
816    /// Writes a wrapper function to convert from a std140 compat type to its
817    /// corresponding regular type.
818    ///
819    /// See [`Self::write_std140_compat_type_declaration`] for more details.
820    fn write_wrapped_convert_from_std140_compat_type(
821        &mut self,
822        ir_module: &crate::Module,
823        r#type: Handle<crate::Type>,
824    ) -> Result<(), Error> {
825        // Check if we've already emitted this function.
826        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
827        let function_id = match self.wrapped_functions.entry(wrapped) {
828            Entry::Occupied(_) => return Ok(()),
829            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
830        };
831        if self.flags.contains(WriterFlags::DEBUG) {
832            self.debugs.push(Instruction::name(
833                function_id,
834                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
835            ));
836        }
837        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
838        let return_type_id = self.get_handle_type_id(r#type);
839
840        let mut function = Function::default();
841        let function_type_id = self.get_function_type(LookupFunctionType {
842            parameter_type_ids: vec![param_type_id],
843            return_type_id,
844        });
845        function.signature = Some(Instruction::function(
846            return_type_id,
847            function_id,
848            spirv::FunctionControl::empty(),
849            function_type_id,
850        ));
851        let param_id = self.id_gen.next();
852        function.parameters.push(FunctionArgument {
853            instruction: Instruction::function_parameter(param_type_id, param_id),
854            handle_id: 0,
855        });
856
857        let label_id = self.id_gen.next();
858        let mut block = Block::new(label_id);
859
860        let result_id = match ir_module.types[r#type].inner {
861            // Param is struct containing a vector member for each of the
862            // matrix's columns. Extract each column from the struct then
863            // composite into a matrix.
864            crate::TypeInner::Matrix {
865                columns,
866                rows: rows @ crate::VectorSize::Bi,
867                scalar,
868            } => {
869                let column_type_id =
870                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
871
872                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
873                for column in 0..columns as u32 {
874                    let column_id = self.id_gen.next();
875                    block.body.push(Instruction::composite_extract(
876                        column_type_id,
877                        column_id,
878                        param_id,
879                        &[column],
880                    ));
881                    column_ids.push(column_id);
882                }
883                let result_id = self.id_gen.next();
884                block.body.push(Instruction::composite_construct(
885                    return_type_id,
886                    result_id,
887                    &column_ids,
888                ));
889                result_id
890            }
891            // Param is an array where the base type is the std140 compatible
892            // type corresponding to `base`. Iterate through each element and
893            // call its conversion function, then composite into a new array.
894            crate::TypeInner::Array { base, size, .. } => {
895                // Ensure the conversion function for the array's base type is
896                // declared.
897                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;
898
899                let element_type_id = self.get_handle_type_id(base);
900                let std140_element_type_id = self.std140_compat_uniform_types[&base].type_id;
901                let element_conversion_function_id = self.wrapped_functions
902                    [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
903                let mut element_ids = Vec::new();
904                let size = match size.resolve(ir_module.to_ctx())? {
905                    crate::proc::IndexableLength::Known(size) => size,
906                    crate::proc::IndexableLength::Dynamic => {
907                        return Err(Error::Validation(
908                            "Uniform buffers cannot contain dynamic arrays",
909                        ))
910                    }
911                };
912                for i in 0..size {
913                    let std140_element_id = self.id_gen.next();
914                    block.body.push(Instruction::composite_extract(
915                        std140_element_type_id,
916                        std140_element_id,
917                        param_id,
918                        &[i],
919                    ));
920                    let element_id = self.id_gen.next();
921                    block.body.push(Instruction::function_call(
922                        element_type_id,
923                        element_id,
924                        element_conversion_function_id,
925                        &[std140_element_id],
926                    ));
927                    element_ids.push(element_id);
928                }
929                let result_id = self.id_gen.next();
930                block.body.push(Instruction::composite_construct(
931                    return_type_id,
932                    result_id,
933                    &element_ids,
934                ));
935                result_id
936            }
937            // Param is a struct where each two-row matrix member has been
938            // decomposed in to separate vector members for each column.
939            // Other members use their std140 compatible type if one exists, or
940            // else their regular type. Iterate through each member, converting
941            // or composing any matrices if required, then finally compose into
942            // the struct.
943            crate::TypeInner::Struct { ref members, .. } => {
944                let mut member_ids = Vec::new();
945                let mut next_index = 0;
946                for member in members {
947                    let member_id = self.id_gen.next();
948                    let member_type_id = self.get_handle_type_id(member.ty);
949                    match ir_module.types[member.ty].inner {
950                        crate::TypeInner::Matrix {
951                            columns,
952                            rows: rows @ crate::VectorSize::Bi,
953                            scalar,
954                        } => {
955                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
956                            let column_type_id = self
957                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
958                            for _ in 0..columns as u32 {
959                                let column_id = self.id_gen.next();
960                                block.body.push(Instruction::composite_extract(
961                                    column_type_id,
962                                    column_id,
963                                    param_id,
964                                    &[next_index],
965                                ));
966                                column_ids.push(column_id);
967                                next_index += 1;
968                            }
969                            block.body.push(Instruction::composite_construct(
970                                member_type_id,
971                                member_id,
972                                &column_ids,
973                            ));
974                        }
975                        _ => {
976                            // Ensure the conversion function for the member's
977                            // type is declared.
978                            self.write_wrapped_convert_from_std140_compat_type(
979                                ir_module, member.ty,
980                            )?;
981                            match self.std140_compat_uniform_types.get(&member.ty) {
982                                Some(std140_type_info) => {
983                                    let std140_member_id = self.id_gen.next();
984                                    block.body.push(Instruction::composite_extract(
985                                        std140_type_info.type_id,
986                                        std140_member_id,
987                                        param_id,
988                                        &[next_index],
989                                    ));
990                                    let function_id = self.wrapped_functions
991                                        [&WrappedFunction::ConvertFromStd140CompatType {
992                                            r#type: member.ty,
993                                        }];
994                                    block.body.push(Instruction::function_call(
995                                        member_type_id,
996                                        member_id,
997                                        function_id,
998                                        &[std140_member_id],
999                                    ));
1000                                    next_index += 1;
1001                                }
1002                                None => {
1003                                    let member_id = self.id_gen.next();
1004                                    block.body.push(Instruction::composite_extract(
1005                                        member_type_id,
1006                                        member_id,
1007                                        param_id,
1008                                        &[next_index],
1009                                    ));
1010                                    next_index += 1;
1011                                }
1012                            }
1013                        }
1014                    }
1015                    member_ids.push(member_id);
1016                }
1017                let result_id = self.id_gen.next();
1018                block.body.push(Instruction::composite_construct(
1019                    return_type_id,
1020                    result_id,
1021                    &member_ids,
1022                ));
1023                result_id
1024            }
1025            _ => unreachable!(),
1026        };
1027
1028        function.consume(block, Instruction::return_value(result_id));
1029        function.to_words(&mut self.logical_layout.function_definitions);
1030        Ok(())
1031    }
1032
1033    /// Writes a wrapper function to get an `OpTypeVector` column from an
1034    /// `OpTypeMatrix` with a dynamic index.
1035    ///
1036    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
1037    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
1038    /// been declared in SPIR-V using an alternative type where each column is a
1039    /// member of a containing struct. SPIR-V is unable to dynamically access
1040    /// struct members, so instead we load the matrix then call this function to
1041    /// access a column from the loaded value.
1042    ///
1043    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
1044    /// [`Uniform`]: crate::AddressSpace::Uniform
1045    fn write_wrapped_matcx2_get_column(
1046        &mut self,
1047        ir_module: &crate::Module,
1048        r#type: Handle<crate::Type>,
1049    ) -> Result<(), Error> {
1050        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
1051        let function_id = match self.wrapped_functions.entry(wrapped) {
1052            Entry::Occupied(_) => return Ok(()),
1053            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
1054        };
1055        if self.flags.contains(WriterFlags::DEBUG) {
1056            self.debugs.push(Instruction::name(
1057                function_id,
1058                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
1059            ));
1060        }
1061
1062        let crate::TypeInner::Matrix {
1063            columns,
1064            rows: rows @ crate::VectorSize::Bi,
1065            scalar,
1066        } = ir_module.types[r#type].inner
1067        else {
1068            unreachable!();
1069        };
1070
1071        let mut function = Function::default();
1072        let matrix_type_id = self.get_handle_type_id(r#type);
1073        let column_index_type_id = self.get_u32_type_id();
1074        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1075        let matrix_param_id = self.id_gen.next();
1076        let column_index_param_id = self.id_gen.next();
1077        function.parameters.push(FunctionArgument {
1078            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
1079            handle_id: 0,
1080        });
1081        function.parameters.push(FunctionArgument {
1082            instruction: Instruction::function_parameter(
1083                column_index_type_id,
1084                column_index_param_id,
1085            ),
1086            handle_id: 0,
1087        });
1088        let function_type_id = self.get_function_type(LookupFunctionType {
1089            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
1090            return_type_id: column_type_id,
1091        });
1092        function.signature = Some(Instruction::function(
1093            column_type_id,
1094            function_id,
1095            spirv::FunctionControl::empty(),
1096            function_type_id,
1097        ));
1098
1099        let label_id = self.id_gen.next();
1100        let mut block = Block::new(label_id);
1101
1102        // Create a switch case for each column in the matrix, where each case
1103        // extracts its column from the matrix. Finally we use OpPhi to return
1104        // the correct column.
1105        let merge_id = self.id_gen.next();
1106        block.body.push(Instruction::selection_merge(
1107            merge_id,
1108            spirv::SelectionControl::NONE,
1109        ));
1110        let cases = (0..columns as u32)
1111            .map(|i| super::instructions::Case {
1112                value: i,
1113                label_id: self.id_gen.next(),
1114            })
1115            .collect::<ArrayVec<_, 4>>();
1116
1117        // Which label we branch to in the default (column index out-of-bounds)
1118        // case depends on our bounds check policy.
1119        let default_id = match self.bounds_check_policies.index {
1120            // For `Restrict`, treat the same as the final column.
1121            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
1122            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
1123            // will be handled in the `OpPhi` below to produce a zero value.
1124            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
1125            // For `Unchecked` we create a new block containing an
1126            // `OpUnreachable`.
1127            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
1128        };
1129        function.consume(
1130            block,
1131            Instruction::switch(column_index_param_id, default_id, &cases),
1132        );
1133
1134        // Emit a block for each case, and produce a list of variable and parent
1135        // block IDs that will be used in an `OpPhi` below to select the right
1136        // value.
1137        let mut var_parent_pairs = cases
1138            .into_iter()
1139            .map(|case| {
1140                let mut block = Block::new(case.label_id);
1141                let column_id = self.id_gen.next();
1142                block.body.push(Instruction::composite_extract(
1143                    column_type_id,
1144                    column_id,
1145                    matrix_param_id,
1146                    &[case.value],
1147                ));
1148                function.consume(block, Instruction::branch(merge_id));
1149                (column_id, case.label_id)
1150            })
1151            // Need capacity for up to 4 columns plus possibly a default case.
1152            .collect::<ArrayVec<_, 5>>();
1153
1154        // Emit a block or append the variable and parent `OpPhi` pair for the
1155        // column index out-of-bounds case, if required.
1156        match self.bounds_check_policies.index {
1157            // Don't need to do anything for `Restrict` as we have branched from
1158            // the final column case's block.
1159            crate::proc::BoundsCheckPolicy::Restrict => {}
1160            // For `ReadZeroSkipWrite` we have branched directly from the block
1161            // containing the `OpSwitch`. The `OpPhi` should produce a zero
1162            // value.
1163            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1164                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
1165            }
1166            // For `Unchecked` create a new block containing `OpUnreachable`.
1167            // This does not need to be handled by the `OpPhi`.
1168            crate::proc::BoundsCheckPolicy::Unchecked => {
1169                function.consume(
1170                    Block::new(default_id),
1171                    Instruction::new(spirv::Op::Unreachable),
1172                );
1173            }
1174        }
1175
1176        let mut block = Block::new(merge_id);
1177        let result_id = self.id_gen.next();
1178        block.body.push(Instruction::phi(
1179            column_type_id,
1180            result_id,
1181            &var_parent_pairs,
1182        ));
1183
1184        function.consume(block, Instruction::return_value(result_id));
1185        function.to_words(&mut self.logical_layout.function_definitions);
1186        Ok(())
1187    }
1188
1189    fn write_function(
1190        &mut self,
1191        ir_function: &crate::Function,
1192        info: &FunctionInfo,
1193        ir_module: &crate::Module,
1194        mut interface: Option<FunctionInterface>,
1195        debug_info: &Option<DebugInfoInner>,
1196    ) -> Result<Word, Error> {
1197        self.write_wrapped_functions(ir_function, info, ir_module)?;
1198
1199        log::trace!("Generating code for {:?}", ir_function.name);
1200        let mut function = Function::default();
1201
1202        let prelude_id = self.id_gen.next();
1203        let mut prelude = Block::new(prelude_id);
1204        let mut ep_context = EntryPointContext {
1205            argument_ids: Vec::new(),
1206            results: Vec::new(),
1207            task_payload_variable_id: if let Some(ref i) = interface {
1208                i.task_payload.map(|a| self.global_variables[a].var_id)
1209            } else {
1210                None
1211            },
1212            mesh_state: None,
1213        };
1214
1215        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
1216
1217        let mut local_invocation_index_var_id = None;
1218        let mut local_invocation_index_id = None;
1219
1220        for argument in ir_function.arguments.iter() {
1221            let class = spirv::StorageClass::Input;
1222            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
1223            let argument_type_id = if handle_ty {
1224                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
1225            } else {
1226                self.get_handle_type_id(argument.ty)
1227            };
1228
1229            if let Some(ref mut iface) = interface {
1230                let id = if let Some(ref binding) = argument.binding {
1231                    let name = argument.name.as_deref();
1232
1233                    let varying_id = self.write_varying(
1234                        ir_module,
1235                        iface.stage,
1236                        class,
1237                        name,
1238                        argument.ty,
1239                        binding,
1240                    )?;
1241                    iface.varying_ids.push(varying_id);
1242                    let id = self.load_io_with_f16_polyfill(
1243                        &mut prelude.body,
1244                        varying_id,
1245                        argument_type_id,
1246                    );
1247                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
1248                        local_invocation_index_id = Some(id);
1249                        local_invocation_index_var_id = Some(varying_id);
1250                    }
1251
1252                    id
1253                } else if let crate::TypeInner::Struct { ref members, .. } =
1254                    ir_module.types[argument.ty].inner
1255                {
1256                    let struct_id = self.id_gen.next();
1257                    let mut constituent_ids = Vec::with_capacity(members.len());
1258                    for member in members {
1259                        let type_id = self.get_handle_type_id(member.ty);
1260                        let name = member.name.as_deref();
1261                        let binding = member.binding.as_ref().unwrap();
1262                        let varying_id = self.write_varying(
1263                            ir_module,
1264                            iface.stage,
1265                            class,
1266                            name,
1267                            member.ty,
1268                            binding,
1269                        )?;
1270                        iface.varying_ids.push(varying_id);
1271                        let id =
1272                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
1273                        constituent_ids.push(id);
1274                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1275                        {
1276                            local_invocation_index_id = Some(id);
1277                            local_invocation_index_var_id = Some(varying_id);
1278                        }
1279                    }
1280                    prelude.body.push(Instruction::composite_construct(
1281                        argument_type_id,
1282                        struct_id,
1283                        &constituent_ids,
1284                    ));
1285                    struct_id
1286                } else {
1287                    unreachable!("Missing argument binding on an entry point");
1288                };
1289                ep_context.argument_ids.push(id);
1290            } else {
1291                let argument_id = self.id_gen.next();
1292                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
1293                if self.flags.contains(WriterFlags::DEBUG) {
1294                    if let Some(ref name) = argument.name {
1295                        self.debugs.push(Instruction::name(argument_id, name));
1296                    }
1297                }
1298                function.parameters.push(FunctionArgument {
1299                    instruction,
1300                    handle_id: if handle_ty {
1301                        let id = self.id_gen.next();
1302                        prelude.body.push(Instruction::load(
1303                            self.get_handle_type_id(argument.ty),
1304                            id,
1305                            argument_id,
1306                            None,
1307                        ));
1308                        id
1309                    } else {
1310                        0
1311                    },
1312                });
1313                parameter_type_ids.push(argument_type_id);
1314            };
1315        }
1316
1317        let return_type_id = match ir_function.result {
1318            Some(ref result) => {
1319                if let Some(ref mut iface) = interface {
1320                    let mut has_point_size = false;
1321                    let class = spirv::StorageClass::Output;
1322                    if let Some(ref binding) = result.binding {
1323                        has_point_size |=
1324                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1325                        let type_id = self.get_handle_type_id(result.ty);
1326                        let varying_id =
1327                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
1328                                0
1329                            } else {
1330                                let varying_id = self.write_varying(
1331                                    ir_module,
1332                                    iface.stage,
1333                                    class,
1334                                    None,
1335                                    result.ty,
1336                                    binding,
1337                                )?;
1338                                iface.varying_ids.push(varying_id);
1339                                varying_id
1340                            };
1341                        ep_context.results.push(ResultMember {
1342                            id: varying_id,
1343                            type_id,
1344                            built_in: binding.to_built_in(),
1345                        });
1346                    } else if let crate::TypeInner::Struct { ref members, .. } =
1347                        ir_module.types[result.ty].inner
1348                    {
1349                        for member in members {
1350                            let type_id = self.get_handle_type_id(member.ty);
1351                            let name = member.name.as_deref();
1352                            let binding = member.binding.as_ref().unwrap();
1353                            has_point_size |=
1354                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1355                            // This isn't an actual builtin in SPIR-V. It can only appear as the
1356                            // output of a task shader and the output is used when writing the
1357                            // entry point return, in which case the id is ignored anyway.
1358                            let varying_id = if *binding
1359                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
1360                            {
1361                                0
1362                            } else {
1363                                let varying_id = self.write_varying(
1364                                    ir_module,
1365                                    iface.stage,
1366                                    class,
1367                                    name,
1368                                    member.ty,
1369                                    binding,
1370                                )?;
1371                                iface.varying_ids.push(varying_id);
1372                                varying_id
1373                            };
1374                            ep_context.results.push(ResultMember {
1375                                id: varying_id,
1376                                type_id,
1377                                built_in: binding.to_built_in(),
1378                            });
1379                        }
1380                    } else {
1381                        unreachable!("Missing result binding on an entry point");
1382                    }
1383
1384                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
1385                        && iface.stage == crate::ShaderStage::Vertex
1386                        && !has_point_size
1387                    {
1388                        // add point size artificially
1389                        let varying_id = self.id_gen.next();
1390                        let pointer_type_id = self.get_f32_pointer_type_id(class);
1391                        Instruction::variable(pointer_type_id, varying_id, class, None)
1392                            .to_words(&mut self.logical_layout.declarations);
1393                        self.decorate(
1394                            varying_id,
1395                            spirv::Decoration::BuiltIn,
1396                            &[spirv::BuiltIn::PointSize as u32],
1397                        );
1398                        iface.varying_ids.push(varying_id);
1399
1400                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
1401                        prelude
1402                            .body
1403                            .push(Instruction::store(varying_id, default_value_id, None));
1404                    }
1405                    if iface.stage == crate::ShaderStage::Task {
1406                        self.get_vec3u_type_id()
1407                    } else {
1408                        self.void_type
1409                    }
1410                } else {
1411                    self.get_handle_type_id(result.ty)
1412                }
1413            }
1414            None => self.void_type,
1415        };
1416
1417        if let Some(ref mut iface) = interface {
1418            if let Some(task_payload) = iface.task_payload {
1419                iface
1420                    .varying_ids
1421                    .push(self.global_variables[task_payload].var_id);
1422            }
1423            self.write_entry_point_mesh_shader_info(
1424                iface,
1425                local_invocation_index_var_id,
1426                ir_module,
1427                &mut ep_context,
1428            )?;
1429        }
1430
1431        let lookup_function_type = LookupFunctionType {
1432            parameter_type_ids,
1433            return_type_id,
1434        };
1435
1436        let function_id = self.id_gen.next();
1437        if self.flags.contains(WriterFlags::DEBUG) {
1438            if let Some(ref name) = ir_function.name {
1439                self.debugs.push(Instruction::name(function_id, name));
1440            }
1441        }
1442
1443        let function_type = self.get_function_type(lookup_function_type);
1444        function.signature = Some(Instruction::function(
1445            return_type_id,
1446            function_id,
1447            spirv::FunctionControl::empty(),
1448            function_type,
1449        ));
1450
1451        if interface.is_some() {
1452            function.entry_point_context = Some(ep_context);
1453        }
1454
1455        // fill up the `GlobalVariable::access_id`
1456        for gv in self.global_variables.iter_mut() {
1457            gv.reset_for_function();
1458        }
1459        for (handle, var) in ir_module.global_variables.iter() {
1460            if info[handle].is_empty() {
1461                continue;
1462            }
1463
1464            let mut gv = self.global_variables[handle].clone();
1465            if let Some(ref mut iface) = interface {
1466                // Have to include global variables in the interface
1467                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
1468                    iface.varying_ids.push(gv.var_id);
1469                }
1470            }
1471
1472            match ir_module.types[var.ty].inner {
1473                // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
1474                crate::TypeInner::BindingArray { .. } => {
1475                    gv.access_id = gv.var_id;
1476                }
1477                _ => {
1478                    // Handle globals are pre-emitted and should be loaded automatically.
1479                    if var.space == crate::AddressSpace::Handle {
1480                        let var_type_id = self.get_handle_type_id(var.ty);
1481                        let id = self.id_gen.next();
1482                        prelude
1483                            .body
1484                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
1485                        gv.access_id = gv.var_id;
1486                        gv.handle_id = id;
1487                    } else if global_needs_wrapper(ir_module, var) {
1488                        let class = map_storage_class(var.space);
1489                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
1490                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
1491                                self.get_pointer_type_id(std140_type_info.type_id, class)
1492                            }
1493                            _ => self.get_handle_pointer_type_id(var.ty, class),
1494                        };
1495                        let index_id = self.get_index_constant(0);
1496                        let id = self.id_gen.next();
1497                        prelude.body.push(Instruction::access_chain(
1498                            pointer_type_id,
1499                            id,
1500                            gv.var_id,
1501                            &[index_id],
1502                        ));
1503                        gv.access_id = id;
1504                    } else {
1505                        // by default, the variable ID is accessed as is
1506                        gv.access_id = gv.var_id;
1507                    };
1508                }
1509            }
1510
1511            // work around borrow checking in the presence of `self.xxx()` calls
1512            self.global_variables[handle] = gv;
1513        }
1514
1515        // Create a `BlockContext` for generating SPIR-V for the function's
1516        // body.
1517        let mut context = BlockContext {
1518            ir_module,
1519            ir_function,
1520            fun_info: info,
1521            function: &mut function,
1522            // Re-use the cached expression table from prior functions.
1523            cached: core::mem::take(&mut self.saved_cached),
1524
1525            // Steal the Writer's temp list for a bit.
1526            temp_list: core::mem::take(&mut self.temp_list),
1527            force_loop_bounding: self.force_loop_bounding,
1528            writer: self,
1529            expression_constness: super::ExpressionConstnessTracker::from_arena(
1530                &ir_function.expressions,
1531            ),
1532            ray_query_tracker_expr: crate::FastHashMap::default(),
1533        };
1534
1535        // fill up the pre-emitted and const expressions
1536        context.cached.reset(ir_function.expressions.len());
1537        for (handle, expr) in ir_function.expressions.iter() {
1538            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
1539                || context.expression_constness.is_const(handle)
1540            {
1541                context.cache_expression_value(handle, &mut prelude)?;
1542            }
1543        }
1544
1545        for (handle, variable) in ir_function.local_variables.iter() {
1546            let id = context.gen_id();
1547
1548            if context.writer.flags.contains(WriterFlags::DEBUG) {
1549                if let Some(ref name) = variable.name {
1550                    context.writer.debugs.push(Instruction::name(id, name));
1551                }
1552            }
1553
1554            let init_word = variable.init.map(|constant| context.cached[constant]);
1555            let pointer_type_id = context
1556                .writer
1557                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1558            let instruction = Instruction::variable(
1559                pointer_type_id,
1560                id,
1561                spirv::StorageClass::Function,
1562                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1563                    crate::TypeInner::RayQuery { .. } => None,
1564                    _ => {
1565                        let type_id = context.get_handle_type_id(variable.ty);
1566                        Some(context.writer.write_constant_null(type_id))
1567                    }
1568                }),
1569            );
1570
1571            context
1572                .function
1573                .variables
1574                .insert(handle, LocalVariable { id, instruction });
1575
1576            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
1577                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
1578                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
1579                // confusing bugs.
1580                let u32_type_id = context.writer.get_u32_type_id();
1581                let ptr_u32_type_id = context
1582                    .writer
1583                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
1584                let tracker_id = context.gen_id();
1585                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
1586                    crate::back::RayQueryPoint::empty().bits(),
1587                ));
1588                let tracker_instruction = Instruction::variable(
1589                    ptr_u32_type_id,
1590                    tracker_id,
1591                    spirv::StorageClass::Function,
1592                    Some(tracker_init_id),
1593                );
1594
1595                context
1596                    .function
1597                    .ray_query_initialization_tracker_variables
1598                    .insert(
1599                        handle,
1600                        LocalVariable {
1601                            id: tracker_id,
1602                            instruction: tracker_instruction,
1603                        },
1604                    );
1605                let f32_type_id = context.writer.get_f32_type_id();
1606                let ptr_f32_type_id = context
1607                    .writer
1608                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1609                let t_max_tracker_id = context.gen_id();
1610                let t_max_tracker_init_id =
1611                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
1612                let t_max_tracker_instruction = Instruction::variable(
1613                    ptr_f32_type_id,
1614                    t_max_tracker_id,
1615                    spirv::StorageClass::Function,
1616                    Some(t_max_tracker_init_id),
1617                );
1618
1619                context.function.ray_query_t_max_tracker_variables.insert(
1620                    handle,
1621                    LocalVariable {
1622                        id: t_max_tracker_id,
1623                        instruction: t_max_tracker_instruction,
1624                    },
1625                );
1626            }
1627        }
1628
1629        for (handle, expr) in ir_function.expressions.iter() {
1630            match *expr {
1631                crate::Expression::LocalVariable(_) => {
1632                    // Cache the `OpVariable` instruction we generated above as
1633                    // the value of this expression.
1634                    context.cache_expression_value(handle, &mut prelude)?;
1635                }
1636                crate::Expression::Access { base, .. }
1637                | crate::Expression::AccessIndex { base, .. } => {
1638                    // Count references to `base` by `Access` and `AccessIndex`
1639                    // instructions. See `access_uses` for details.
1640                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1641                }
1642                _ => {}
1643            }
1644        }
1645
1646        let next_id = context.gen_id();
1647
1648        context
1649            .function
1650            .consume(prelude, Instruction::branch(next_id));
1651
1652        let workgroup_vars_init_exit_block_id =
1653            match (context.writer.zero_initialize_workgroup_memory, interface) {
1654                (
1655                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1656                    Some(
1657                        ref mut interface @ FunctionInterface {
1658                            stage:
1659                                crate::ShaderStage::Compute
1660                                | crate::ShaderStage::Mesh
1661                                | crate::ShaderStage::Task,
1662                            ..
1663                        },
1664                    ),
1665                ) => context.writer.generate_workgroup_vars_init_block(
1666                    next_id,
1667                    ir_module,
1668                    info,
1669                    local_invocation_index_id,
1670                    interface,
1671                    context.function,
1672                ),
1673                _ => None,
1674            };
1675
1676        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1677            exit_id
1678        } else {
1679            next_id
1680        };
1681
1682        context.write_function_body(main_id, debug_info.as_ref())?;
1683
1684        // Consume the `BlockContext`, ending its borrows and letting the
1685        // `Writer` steal back its cached expression table and temp_list.
1686        let BlockContext {
1687            cached, temp_list, ..
1688        } = context;
1689        self.saved_cached = cached;
1690        self.temp_list = temp_list;
1691
1692        function.to_words(&mut self.logical_layout.function_definitions);
1693
1694        if let Some(EntryPointContext {
1695            mesh_state: Some(ref mesh_state),
1696            ..
1697        }) = function.entry_point_context
1698        {
1699            self.write_mesh_shader_wrapper(mesh_state, function_id)
1700        } else if let Some(EntryPointContext {
1701            task_payload_variable_id: Some(tp),
1702            ..
1703        }) = function.entry_point_context
1704        {
1705            self.write_task_shader_wrapper(tp, function_id)
1706        } else {
1707            Ok(function_id)
1708        }
1709    }
1710
1711    fn write_execution_mode(
1712        &mut self,
1713        function_id: Word,
1714        mode: spirv::ExecutionMode,
1715    ) -> Result<(), Error> {
1716        //self.check(mode.required_capabilities())?;
1717        Instruction::execution_mode(function_id, mode, &[])
1718            .to_words(&mut self.logical_layout.execution_modes);
1719        Ok(())
1720    }
1721
1722    // TODO Move to instructions module
1723    fn write_entry_point(
1724        &mut self,
1725        entry_point: &crate::EntryPoint,
1726        info: &FunctionInfo,
1727        ir_module: &crate::Module,
1728        debug_info: &Option<DebugInfoInner>,
1729    ) -> Result<Instruction, Error> {
1730        let mut interface_ids = Vec::new();
1731
1732        let function_id = self.write_function(
1733            &entry_point.function,
1734            info,
1735            ir_module,
1736            Some(FunctionInterface {
1737                varying_ids: &mut interface_ids,
1738                stage: entry_point.stage,
1739                task_payload: entry_point.task_payload,
1740                mesh_info: entry_point.mesh_info.clone(),
1741                workgroup_size: entry_point.workgroup_size,
1742            }),
1743            debug_info,
1744        )?;
1745
1746        let exec_model = match entry_point.stage {
1747            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1748            crate::ShaderStage::Fragment => {
1749                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1750                match entry_point.early_depth_test {
1751                    Some(crate::EarlyDepthTest::Force) => {
1752                        self.write_execution_mode(
1753                            function_id,
1754                            spirv::ExecutionMode::EarlyFragmentTests,
1755                        )?;
1756                    }
1757                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1758                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1759                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1760                        // This permits early depth tests even if the shader writes to a storage
1761                        // binding
1762                        match conservative {
1763                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1764                                function_id,
1765                                spirv::ExecutionMode::DepthGreater,
1766                            )?,
1767                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1768                                function_id,
1769                                spirv::ExecutionMode::DepthLess,
1770                            )?,
1771                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1772                                function_id,
1773                                spirv::ExecutionMode::DepthUnchanged,
1774                            )?,
1775                        }
1776                    }
1777                    None => {}
1778                }
1779                if let Some(ref result) = entry_point.function.result {
1780                    if contains_builtin(
1781                        result.binding.as_ref(),
1782                        result.ty,
1783                        &ir_module.types,
1784                        crate::BuiltIn::FragDepth,
1785                    ) {
1786                        self.write_execution_mode(
1787                            function_id,
1788                            spirv::ExecutionMode::DepthReplacing,
1789                        )?;
1790                    }
1791                }
1792                spirv::ExecutionModel::Fragment
1793            }
1794            crate::ShaderStage::Compute => {
1795                let execution_mode = spirv::ExecutionMode::LocalSize;
1796                Instruction::execution_mode(
1797                    function_id,
1798                    execution_mode,
1799                    &entry_point.workgroup_size,
1800                )
1801                .to_words(&mut self.logical_layout.execution_modes);
1802                spirv::ExecutionModel::GLCompute
1803            }
1804            crate::ShaderStage::Task => {
1805                let execution_mode = spirv::ExecutionMode::LocalSize;
1806                Instruction::execution_mode(
1807                    function_id,
1808                    execution_mode,
1809                    &entry_point.workgroup_size,
1810                )
1811                .to_words(&mut self.logical_layout.execution_modes);
1812                spirv::ExecutionModel::TaskEXT
1813            }
1814            crate::ShaderStage::Mesh => {
1815                let execution_mode = spirv::ExecutionMode::LocalSize;
1816                Instruction::execution_mode(
1817                    function_id,
1818                    execution_mode,
1819                    &entry_point.workgroup_size,
1820                )
1821                .to_words(&mut self.logical_layout.execution_modes);
1822                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
1823                Instruction::execution_mode(
1824                    function_id,
1825                    match mesh_info.topology {
1826                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
1827                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
1828                        crate::MeshOutputTopology::Triangles => {
1829                            spirv::ExecutionMode::OutputTrianglesEXT
1830                        }
1831                    },
1832                    &[],
1833                )
1834                .to_words(&mut self.logical_layout.execution_modes);
1835                Instruction::execution_mode(
1836                    function_id,
1837                    spirv::ExecutionMode::OutputVertices,
1838                    core::slice::from_ref(&mesh_info.max_vertices),
1839                )
1840                .to_words(&mut self.logical_layout.execution_modes);
1841                Instruction::execution_mode(
1842                    function_id,
1843                    spirv::ExecutionMode::OutputPrimitivesEXT,
1844                    core::slice::from_ref(&mesh_info.max_primitives),
1845                )
1846                .to_words(&mut self.logical_layout.execution_modes);
1847                spirv::ExecutionModel::MeshEXT
1848            }
1849            crate::ShaderStage::RayGeneration
1850            | crate::ShaderStage::AnyHit
1851            | crate::ShaderStage::ClosestHit
1852            | crate::ShaderStage::Miss => unreachable!(),
1853        };
1854        //self.check(exec_model.required_capabilities())?;
1855
1856        Ok(Instruction::entry_point(
1857            exec_model,
1858            function_id,
1859            &entry_point.name,
1860            interface_ids.as_slice(),
1861        ))
1862    }
1863
1864    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1865        use crate::ScalarKind as Sk;
1866
1867        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1868        match scalar.kind {
1869            Sk::Sint | Sk::Uint => {
1870                let signedness = if scalar.kind == Sk::Sint {
1871                    super::instructions::Signedness::Signed
1872                } else {
1873                    super::instructions::Signedness::Unsigned
1874                };
1875                let cap = match bits {
1876                    8 => Some(spirv::Capability::Int8),
1877                    16 => Some(spirv::Capability::Int16),
1878                    64 => Some(spirv::Capability::Int64),
1879                    _ => None,
1880                };
1881                if let Some(cap) = cap {
1882                    self.capabilities_used.insert(cap);
1883                }
1884                Instruction::type_int(id, bits, signedness)
1885            }
1886            Sk::Float => {
1887                if bits == 64 {
1888                    self.capabilities_used.insert(spirv::Capability::Float64);
1889                }
1890                if bits == 16 {
1891                    self.capabilities_used.insert(spirv::Capability::Float16);
1892                    self.capabilities_used
1893                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1894                    self.capabilities_used
1895                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1896                    if self.use_storage_input_output_16 {
1897                        self.capabilities_used
1898                            .insert(spirv::Capability::StorageInputOutput16);
1899                    }
1900                }
1901                Instruction::type_float(id, bits)
1902            }
1903            Sk::Bool => Instruction::type_bool(id),
1904            Sk::AbstractInt | Sk::AbstractFloat => {
1905                unreachable!("abstract types should never reach the backend");
1906            }
1907        }
1908    }
1909
1910    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1911        match *inner {
1912            crate::TypeInner::Image {
1913                dim,
1914                arrayed,
1915                class,
1916            } => {
1917                let sampled = match class {
1918                    crate::ImageClass::Sampled { .. } => true,
1919                    crate::ImageClass::Depth { .. } => true,
1920                    crate::ImageClass::Storage { format, .. } => {
1921                        self.request_image_format_capabilities(format.into())?;
1922                        false
1923                    }
1924                    crate::ImageClass::External => unimplemented!(),
1925                };
1926
1927                match dim {
1928                    crate::ImageDimension::D1 => {
1929                        if sampled {
1930                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1931                        } else {
1932                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1933                        }
1934                    }
1935                    crate::ImageDimension::Cube if arrayed => {
1936                        if sampled {
1937                            self.require_any(
1938                                "sampled cube array images",
1939                                &[spirv::Capability::SampledCubeArray],
1940                            )?;
1941                        } else {
1942                            self.require_any(
1943                                "cube array storage images",
1944                                &[spirv::Capability::ImageCubeArray],
1945                            )?;
1946                        }
1947                    }
1948                    _ => {}
1949                }
1950            }
1951            crate::TypeInner::AccelerationStructure { .. } => {
1952                self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
1953            }
1954            crate::TypeInner::RayQuery { .. } => {
1955                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
1956            }
1957            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
1958                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
1959            }
1960            crate::TypeInner::Atomic(crate::Scalar {
1961                width: 4,
1962                kind: crate::ScalarKind::Float,
1963            }) => {
1964                self.require_any(
1965                    "32 bit floating-point atomics",
1966                    &[spirv::Capability::AtomicFloat32AddEXT],
1967                )?;
1968                self.use_extension("SPV_EXT_shader_atomic_float_add");
1969            }
1970            // 16 bit floating-point support requires Float16 capability
1971            crate::TypeInner::Matrix {
1972                scalar: crate::Scalar::F16,
1973                ..
1974            }
1975            | crate::TypeInner::Vector {
1976                scalar: crate::Scalar::F16,
1977                ..
1978            }
1979            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
1980                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
1981                self.use_extension("SPV_KHR_16bit_storage");
1982            }
1983            // Cooperative types and ops
1984            crate::TypeInner::CooperativeMatrix { .. } => {
1985                self.require_any(
1986                    "cooperative matrix",
1987                    &[spirv::Capability::CooperativeMatrixKHR],
1988                )?;
1989                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
1990                self.use_extension("SPV_KHR_cooperative_matrix");
1991                self.use_extension("SPV_KHR_vulkan_memory_model");
1992            }
1993            _ => {}
1994        }
1995        Ok(())
1996    }
1997
1998    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
1999        let instruction = match numeric {
2000            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
2001            NumericType::Vector { size, scalar } => {
2002                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
2003                Instruction::type_vector(id, scalar_id, size)
2004            }
2005            NumericType::Matrix {
2006                columns,
2007                rows,
2008                scalar,
2009            } => {
2010                let column_id =
2011                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2012                Instruction::type_matrix(id, column_id, columns)
2013            }
2014        };
2015
2016        instruction.to_words(&mut self.logical_layout.declarations);
2017    }
2018
2019    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
2020        let instruction = match coop {
2021            CooperativeType::Matrix {
2022                columns,
2023                rows,
2024                scalar,
2025                role,
2026            } => {
2027                let scalar_id =
2028                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
2029                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
2030                let columns_id = self.get_index_constant(columns as u32);
2031                let rows_id = self.get_index_constant(rows as u32);
2032                let role_id =
2033                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
2034                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
2035            }
2036        };
2037
2038        instruction.to_words(&mut self.logical_layout.declarations);
2039    }
2040
2041    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
2042        let instruction = match local_ty {
2043            LocalType::Numeric(numeric) => {
2044                self.write_numeric_type_declaration_local(id, numeric);
2045                return;
2046            }
2047            LocalType::Cooperative(coop) => {
2048                self.write_cooperative_type_declaration_local(id, coop);
2049                return;
2050            }
2051            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
2052            LocalType::Image(image) => {
2053                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
2054                let type_id = self.get_localtype_id(local_type);
2055                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
2056            }
2057            LocalType::Sampler => Instruction::type_sampler(id),
2058            LocalType::SampledImage { image_type_id } => {
2059                Instruction::type_sampled_image(id, image_type_id)
2060            }
2061            LocalType::BindingArray { base, size } => {
2062                let inner_ty = self.get_handle_type_id(base);
2063                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
2064                Instruction::type_array(id, inner_ty, scalar_id)
2065            }
2066            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
2067            LocalType::RayQuery => Instruction::type_ray_query(id),
2068        };
2069
2070        instruction.to_words(&mut self.logical_layout.declarations);
2071    }
2072
2073    fn write_type_declaration_arena(
2074        &mut self,
2075        module: &crate::Module,
2076        handle: Handle<crate::Type>,
2077    ) -> Result<Word, Error> {
2078        let ty = &module.types[handle];
2079        // If it's a type that needs SPIR-V capabilities, request them now.
2080        // This needs to happen regardless of the LocalType lookup succeeding,
2081        // because some types which map to the same LocalType have different
2082        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
2083        self.request_type_capabilities(&ty.inner)?;
2084        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
2085            // This type can be represented as a `LocalType`, so check if we've
2086            // already written an instruction for it. If not, do so now, with
2087            // `write_type_declaration_local`.
2088            match self.lookup_type.entry(LookupType::Local(local)) {
2089                // We already have an id for this `LocalType`.
2090                Entry::Occupied(e) => *e.get(),
2091
2092                // It's a type we haven't seen before.
2093                Entry::Vacant(e) => {
2094                    let id = self.id_gen.next();
2095                    e.insert(id);
2096
2097                    self.write_type_declaration_local(id, local);
2098
2099                    id
2100                }
2101            }
2102        } else {
2103            use spirv::Decoration;
2104
2105            let id = self.id_gen.next();
2106            let instruction = match ty.inner {
2107                crate::TypeInner::Array { base, size, stride } => {
2108                    self.decorate(id, Decoration::ArrayStride, &[stride]);
2109
2110                    let type_id = self.get_handle_type_id(base);
2111                    match size.resolve(module.to_ctx())? {
2112                        crate::proc::IndexableLength::Known(length) => {
2113                            let length_id = self.get_index_constant(length);
2114                            Instruction::type_array(id, type_id, length_id)
2115                        }
2116                        crate::proc::IndexableLength::Dynamic => {
2117                            Instruction::type_runtime_array(id, type_id)
2118                        }
2119                    }
2120                }
2121                crate::TypeInner::BindingArray { base, size } => {
2122                    let type_id = self.get_handle_type_id(base);
2123                    match size.resolve(module.to_ctx())? {
2124                        crate::proc::IndexableLength::Known(length) => {
2125                            let length_id = self.get_index_constant(length);
2126                            Instruction::type_array(id, type_id, length_id)
2127                        }
2128                        crate::proc::IndexableLength::Dynamic => {
2129                            Instruction::type_runtime_array(id, type_id)
2130                        }
2131                    }
2132                }
2133                crate::TypeInner::Struct {
2134                    ref members,
2135                    span: _,
2136                } => {
2137                    let mut has_runtime_array = false;
2138                    let mut member_ids = Vec::with_capacity(members.len());
2139                    for (index, member) in members.iter().enumerate() {
2140                        let member_ty = &module.types[member.ty];
2141                        match member_ty.inner {
2142                            crate::TypeInner::Array {
2143                                base: _,
2144                                size: crate::ArraySize::Dynamic,
2145                                stride: _,
2146                            } => {
2147                                has_runtime_array = true;
2148                            }
2149                            _ => (),
2150                        }
2151                        self.decorate_struct_member(id, index, member, &module.types)?;
2152                        let member_id = self.get_handle_type_id(member.ty);
2153                        member_ids.push(member_id);
2154                    }
2155                    if has_runtime_array {
2156                        self.decorate(id, Decoration::Block, &[]);
2157                    }
2158                    Instruction::type_struct(id, member_ids.as_slice())
2159                }
2160
2161                // These all have TypeLocal representations, so they should have been
2162                // handled by `write_type_declaration_local` above.
2163                crate::TypeInner::Scalar(_)
2164                | crate::TypeInner::Atomic(_)
2165                | crate::TypeInner::Vector { .. }
2166                | crate::TypeInner::Matrix { .. }
2167                | crate::TypeInner::CooperativeMatrix { .. }
2168                | crate::TypeInner::Pointer { .. }
2169                | crate::TypeInner::ValuePointer { .. }
2170                | crate::TypeInner::Image { .. }
2171                | crate::TypeInner::Sampler { .. }
2172                | crate::TypeInner::AccelerationStructure { .. }
2173                | crate::TypeInner::RayQuery { .. } => unreachable!(),
2174            };
2175
2176            instruction.to_words(&mut self.logical_layout.declarations);
2177            id
2178        };
2179
2180        // Add this handle as a new alias for that type.
2181        self.lookup_type.insert(LookupType::Handle(handle), id);
2182
2183        if self.flags.contains(WriterFlags::DEBUG) {
2184            if let Some(ref name) = ty.name {
2185                self.debugs.push(Instruction::name(id, name));
2186            }
2187        }
2188
2189        Ok(id)
2190    }
2191
2192    /// Writes a std140 layout compatible type declaration for a type. Returns
2193    /// the ID of the declared type, or None if no declaration is required.
2194    ///
2195    /// This should be called for any type for which there exists a
2196    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
2197    /// adheres to std140 layout rules it will return without declaring any
2198    /// types. If the type contains another type which requires a std140
2199    /// compatible type declaration, it will recursively call itself.
2200    ///
2201    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
2202    /// declared type will be an `OpTypeStruct` containing an `OpVector` for
2203    /// each of the matrix's columns.
2204    ///
2205    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
2206    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
2207    /// type is the matrix's corresponding std140 compatible type.
2208    ///
2209    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
2210    /// require a std140 compatible type declaration, this will declare a new
2211    /// struct with the following rules:
2212    /// * Struct or array members will be declared with their std140 compatible
2213    ///   type declaration, if one is required.
2214    /// * Two-row matrix members will have each of their columns hoisted
2215    ///   directly into the struct as 2-component vector members.
2216    /// * All other members will be declared with their normal type.
2217    ///
2218    /// Note that this means the Naga IR index of a struct member may not match
2219    /// the index in the generated SPIR-V. The mapping can be obtained via
2220    /// `Std140TypeInfo::member_indices`.
2221    ///
2222    /// [`GlobalVariable`]: crate::GlobalVariable
2223    /// [`Uniform`]: crate::AddressSpace::Uniform
2224    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
2225    /// [`TypeInner::Array`]: crate::TypeInner::Array
2226    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
2227    fn write_std140_compat_type_declaration(
2228        &mut self,
2229        module: &crate::Module,
2230        handle: Handle<crate::Type>,
2231    ) -> Result<Option<Word>, Error> {
2232        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
2233            return Ok(Some(std140_type_info.type_id));
2234        }
2235
2236        let type_inner = &module.types[handle].inner;
2237        let std140_type_id = match *type_inner {
2238            crate::TypeInner::Matrix {
2239                columns,
2240                rows: rows @ crate::VectorSize::Bi,
2241                scalar,
2242            } => {
2243                let std140_type_id = self.id_gen.next();
2244                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
2245                let column_type_id =
2246                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2247                for column in 0..columns as u32 {
2248                    member_type_ids.push(column_type_id);
2249                    self.annotations.push(Instruction::member_decorate(
2250                        std140_type_id,
2251                        column,
2252                        spirv::Decoration::Offset,
2253                        &[column * rows as u32 * scalar.width as u32],
2254                    ));
2255                    if self.flags.contains(WriterFlags::DEBUG) {
2256                        self.debugs.push(Instruction::member_name(
2257                            std140_type_id,
2258                            column,
2259                            &format!("col{column}"),
2260                        ));
2261                    }
2262                }
2263                Instruction::type_struct(std140_type_id, &member_type_ids)
2264                    .to_words(&mut self.logical_layout.declarations);
2265                self.std140_compat_uniform_types.insert(
2266                    handle,
2267                    Std140CompatTypeInfo {
2268                        type_id: std140_type_id,
2269                        member_indices: Vec::new(),
2270                    },
2271                );
2272                Some(std140_type_id)
2273            }
2274            crate::TypeInner::Array { base, size, stride } => {
2275                match self.write_std140_compat_type_declaration(module, base)? {
2276                    Some(std140_base_type_id) => {
2277                        let std140_type_id = self.id_gen.next();
2278                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
2279                        let instruction = match size.resolve(module.to_ctx())? {
2280                            crate::proc::IndexableLength::Known(length) => {
2281                                let length_id = self.get_index_constant(length);
2282                                Instruction::type_array(
2283                                    std140_type_id,
2284                                    std140_base_type_id,
2285                                    length_id,
2286                                )
2287                            }
2288                            crate::proc::IndexableLength::Dynamic => {
2289                                unreachable!()
2290                            }
2291                        };
2292                        instruction.to_words(&mut self.logical_layout.declarations);
2293                        self.std140_compat_uniform_types.insert(
2294                            handle,
2295                            Std140CompatTypeInfo {
2296                                type_id: std140_type_id,
2297                                member_indices: Vec::new(),
2298                            },
2299                        );
2300                        Some(std140_type_id)
2301                    }
2302                    None => None,
2303                }
2304            }
2305            crate::TypeInner::Struct { ref members, .. } => {
2306                let mut needs_std140_type = false;
2307                for member in members {
2308                    match module.types[member.ty].inner {
2309                        // We don't need to write a std140 type for the matrix itself as
2310                        // it will be decomposed into the parent struct. As a result, the
2311                        // struct does need a std140 type, however.
2312                        crate::TypeInner::Matrix {
2313                            rows: crate::VectorSize::Bi,
2314                            ..
2315                        } => needs_std140_type = true,
2316                        // If an array member needs a std140 type, because it is an array
2317                        // (of an array, etc) of `matCx2`s, then the struct also needs
2318                        // a std140 type which uses the std140 type for this member.
2319                        crate::TypeInner::Array { .. }
2320                            if self
2321                                .write_std140_compat_type_declaration(module, member.ty)?
2322                                .is_some() =>
2323                        {
2324                            needs_std140_type = true;
2325                        }
2326                        _ => {}
2327                    }
2328                }
2329
2330                if needs_std140_type {
2331                    let std140_type_id = self.id_gen.next();
2332                    let mut member_ids = Vec::new();
2333                    let mut member_indices = Vec::new();
2334                    let mut next_index = 0;
2335
2336                    for member in members {
2337                        member_indices.push(next_index);
2338                        match module.types[member.ty].inner {
2339                            crate::TypeInner::Matrix {
2340                                columns,
2341                                rows: rows @ crate::VectorSize::Bi,
2342                                scalar,
2343                            } => {
2344                                let vector_type_id =
2345                                    self.get_numeric_type_id(NumericType::Vector {
2346                                        size: rows,
2347                                        scalar,
2348                                    });
2349                                for column in 0..columns as u32 {
2350                                    self.annotations.push(Instruction::member_decorate(
2351                                        std140_type_id,
2352                                        next_index,
2353                                        spirv::Decoration::Offset,
2354                                        &[member.offset
2355                                            + column * rows as u32 * scalar.width as u32],
2356                                    ));
2357                                    if self.flags.contains(WriterFlags::DEBUG) {
2358                                        if let Some(ref name) = member.name {
2359                                            self.debugs.push(Instruction::member_name(
2360                                                std140_type_id,
2361                                                next_index,
2362                                                &format!("{name}_col{column}"),
2363                                            ));
2364                                        }
2365                                    }
2366                                    member_ids.push(vector_type_id);
2367                                    next_index += 1;
2368                                }
2369                            }
2370                            _ => {
2371                                let member_id =
2372                                    match self.std140_compat_uniform_types.get(&member.ty) {
2373                                        Some(std140_member_type_info) => {
2374                                            self.annotations.push(Instruction::member_decorate(
2375                                                std140_type_id,
2376                                                next_index,
2377                                                spirv::Decoration::Offset,
2378                                                &[member.offset],
2379                                            ));
2380                                            if self.flags.contains(WriterFlags::DEBUG) {
2381                                                if let Some(ref name) = member.name {
2382                                                    self.debugs.push(Instruction::member_name(
2383                                                        std140_type_id,
2384                                                        next_index,
2385                                                        name,
2386                                                    ));
2387                                                }
2388                                            }
2389                                            std140_member_type_info.type_id
2390                                        }
2391                                        None => {
2392                                            self.decorate_struct_member(
2393                                                std140_type_id,
2394                                                next_index as usize,
2395                                                member,
2396                                                &module.types,
2397                                            )?;
2398                                            self.get_handle_type_id(member.ty)
2399                                        }
2400                                    };
2401                                member_ids.push(member_id);
2402                                next_index += 1;
2403                            }
2404                        }
2405                    }
2406
2407                    Instruction::type_struct(std140_type_id, &member_ids)
2408                        .to_words(&mut self.logical_layout.declarations);
2409                    self.std140_compat_uniform_types.insert(
2410                        handle,
2411                        Std140CompatTypeInfo {
2412                            type_id: std140_type_id,
2413                            member_indices,
2414                        },
2415                    );
2416                    Some(std140_type_id)
2417                } else {
2418                    None
2419                }
2420            }
2421            _ => None,
2422        };
2423
2424        if let Some(std140_type_id) = std140_type_id {
2425            if self.flags.contains(WriterFlags::DEBUG) {
2426                let name = format!("std140_{:?}", handle.for_debug(&module.types));
2427                self.debugs.push(Instruction::name(std140_type_id, &name));
2428            }
2429        }
2430        Ok(std140_type_id)
2431    }
2432
2433    fn request_image_format_capabilities(
2434        &mut self,
2435        format: spirv::ImageFormat,
2436    ) -> Result<(), Error> {
2437        use spirv::ImageFormat as If;
2438        match format {
2439            If::Rg32f
2440            | If::Rg16f
2441            | If::R11fG11fB10f
2442            | If::R16f
2443            | If::Rgba16
2444            | If::Rgb10A2
2445            | If::Rg16
2446            | If::Rg8
2447            | If::R16
2448            | If::R8
2449            | If::Rgba16Snorm
2450            | If::Rg16Snorm
2451            | If::Rg8Snorm
2452            | If::R16Snorm
2453            | If::R8Snorm
2454            | If::Rg32i
2455            | If::Rg16i
2456            | If::Rg8i
2457            | If::R16i
2458            | If::R8i
2459            | If::Rgb10a2ui
2460            | If::Rg32ui
2461            | If::Rg16ui
2462            | If::Rg8ui
2463            | If::R16ui
2464            | If::R8ui => self.require_any(
2465                "storage image format",
2466                &[spirv::Capability::StorageImageExtendedFormats],
2467            ),
2468            If::R64ui | If::R64i => {
2469                self.use_extension("SPV_EXT_shader_image_int64");
2470                self.require_any(
2471                    "64-bit integer storage image format",
2472                    &[spirv::Capability::Int64ImageEXT],
2473                )
2474            }
2475            If::Unknown
2476            | If::Rgba32f
2477            | If::Rgba16f
2478            | If::R32f
2479            | If::Rgba8
2480            | If::Rgba8Snorm
2481            | If::Rgba32i
2482            | If::Rgba16i
2483            | If::Rgba8i
2484            | If::R32i
2485            | If::Rgba32ui
2486            | If::Rgba16ui
2487            | If::Rgba8ui
2488            | If::R32ui => Ok(()),
2489        }
2490    }
2491
2492    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
2493        self.get_constant_scalar(crate::Literal::U32(index))
2494    }
2495
2496    pub(super) fn get_constant_scalar_with(
2497        &mut self,
2498        value: u8,
2499        scalar: crate::Scalar,
2500    ) -> Result<Word, Error> {
2501        Ok(
2502            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
2503                Error::Validation("Unexpected kind and/or width for Literal"),
2504            )?),
2505        )
2506    }
2507
2508    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
2509        let scalar = CachedConstant::Literal(value.into());
2510        if let Some(&id) = self.cached_constants.get(&scalar) {
2511            return id;
2512        }
2513        let id = self.id_gen.next();
2514        self.write_constant_scalar(id, &value, None);
2515        self.cached_constants.insert(scalar, id);
2516        id
2517    }
2518
2519    fn write_constant_scalar(
2520        &mut self,
2521        id: Word,
2522        value: &crate::Literal,
2523        debug_name: Option<&String>,
2524    ) {
2525        if self.flags.contains(WriterFlags::DEBUG) {
2526            if let Some(name) = debug_name {
2527                self.debugs.push(Instruction::name(id, name));
2528            }
2529        }
2530        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
2531        let instruction = match *value {
2532            crate::Literal::F64(value) => {
2533                let bits = value.to_bits();
2534                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
2535            }
2536            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
2537            crate::Literal::F16(value) => {
2538                let low = value.to_bits();
2539                Instruction::constant_16bit(type_id, id, low as u32)
2540            }
2541            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
2542            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
2543            crate::Literal::U64(value) => {
2544                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2545            }
2546            crate::Literal::I64(value) => {
2547                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2548            }
2549            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
2550            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
2551            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2552                unreachable!("Abstract types should not appear in IR presented to backends");
2553            }
2554        };
2555
2556        instruction.to_words(&mut self.logical_layout.declarations);
2557    }
2558
2559    pub(super) fn get_constant_composite(
2560        &mut self,
2561        ty: LookupType,
2562        constituent_ids: &[Word],
2563    ) -> Word {
2564        let composite = CachedConstant::Composite {
2565            ty,
2566            constituent_ids: constituent_ids.to_vec(),
2567        };
2568        if let Some(&id) = self.cached_constants.get(&composite) {
2569            return id;
2570        }
2571        let id = self.id_gen.next();
2572        self.write_constant_composite(id, ty, constituent_ids, None);
2573        self.cached_constants.insert(composite, id);
2574        id
2575    }
2576
2577    fn write_constant_composite(
2578        &mut self,
2579        id: Word,
2580        ty: LookupType,
2581        constituent_ids: &[Word],
2582        debug_name: Option<&String>,
2583    ) {
2584        if self.flags.contains(WriterFlags::DEBUG) {
2585            if let Some(name) = debug_name {
2586                self.debugs.push(Instruction::name(id, name));
2587            }
2588        }
2589        let type_id = self.get_type_id(ty);
2590        Instruction::constant_composite(type_id, id, constituent_ids)
2591            .to_words(&mut self.logical_layout.declarations);
2592    }
2593
2594    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
2595        let null = CachedConstant::ZeroValue(type_id);
2596        if let Some(&id) = self.cached_constants.get(&null) {
2597            return id;
2598        }
2599        let id = self.write_constant_null(type_id);
2600        self.cached_constants.insert(null, id);
2601        id
2602    }
2603
2604    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
2605        let null_id = self.id_gen.next();
2606        Instruction::constant_null(type_id, null_id)
2607            .to_words(&mut self.logical_layout.declarations);
2608        null_id
2609    }
2610
2611    fn write_constant_expr(
2612        &mut self,
2613        handle: Handle<crate::Expression>,
2614        ir_module: &crate::Module,
2615        mod_info: &ModuleInfo,
2616    ) -> Result<Word, Error> {
2617        let id = match ir_module.global_expressions[handle] {
2618            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
2619            crate::Expression::Constant(constant) => {
2620                let constant = &ir_module.constants[constant];
2621                self.constant_ids[constant.init]
2622            }
2623            crate::Expression::ZeroValue(ty) => {
2624                let type_id = self.get_handle_type_id(ty);
2625                self.get_constant_null(type_id)
2626            }
2627            crate::Expression::Compose { ty, ref components } => {
2628                let component_ids: Vec<_> = crate::proc::flatten_compose(
2629                    ty,
2630                    components,
2631                    &ir_module.global_expressions,
2632                    &ir_module.types,
2633                )
2634                .map(|component| self.constant_ids[component])
2635                .collect();
2636                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
2637            }
2638            crate::Expression::Splat { size, value } => {
2639                let value_id = self.constant_ids[value];
2640                let component_ids = &[value_id; 4][..size as usize];
2641
2642                let ty = self.get_expression_lookup_type(&mod_info[handle]);
2643
2644                self.get_constant_composite(ty, component_ids)
2645            }
2646            _ => {
2647                return Err(Error::Override);
2648            }
2649        };
2650
2651        self.constant_ids[handle] = id;
2652
2653        Ok(id)
2654    }
2655
2656    pub(super) fn write_control_barrier(
2657        &mut self,
2658        flags: crate::Barrier,
2659        body: &mut Vec<Instruction>,
2660    ) {
2661        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
2662            spirv::Scope::Device
2663        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2664            spirv::Scope::Subgroup
2665        } else {
2666            spirv::Scope::Workgroup
2667        };
2668        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2669        semantics.set(
2670            spirv::MemorySemantics::UNIFORM_MEMORY,
2671            flags.contains(crate::Barrier::STORAGE),
2672        );
2673        semantics.set(
2674            spirv::MemorySemantics::WORKGROUP_MEMORY,
2675            flags.contains(crate::Barrier::WORK_GROUP),
2676        );
2677        semantics.set(
2678            spirv::MemorySemantics::SUBGROUP_MEMORY,
2679            flags.contains(crate::Barrier::SUB_GROUP),
2680        );
2681        semantics.set(
2682            spirv::MemorySemantics::IMAGE_MEMORY,
2683            flags.contains(crate::Barrier::TEXTURE),
2684        );
2685        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
2686            self.get_index_constant(spirv::Scope::Subgroup as u32)
2687        } else {
2688            self.get_index_constant(spirv::Scope::Workgroup as u32)
2689        };
2690        let mem_scope_id = self.get_index_constant(memory_scope as u32);
2691        let semantics_id = self.get_index_constant(semantics.bits());
2692        body.push(Instruction::control_barrier(
2693            exec_scope_id,
2694            mem_scope_id,
2695            semantics_id,
2696        ));
2697    }
2698
2699    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
2700        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2701        semantics.set(
2702            spirv::MemorySemantics::UNIFORM_MEMORY,
2703            flags.contains(crate::Barrier::STORAGE),
2704        );
2705        semantics.set(
2706            spirv::MemorySemantics::WORKGROUP_MEMORY,
2707            flags.contains(crate::Barrier::WORK_GROUP),
2708        );
2709        semantics.set(
2710            spirv::MemorySemantics::SUBGROUP_MEMORY,
2711            flags.contains(crate::Barrier::SUB_GROUP),
2712        );
2713        semantics.set(
2714            spirv::MemorySemantics::IMAGE_MEMORY,
2715            flags.contains(crate::Barrier::TEXTURE),
2716        );
2717        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
2718            self.get_index_constant(spirv::Scope::Device as u32)
2719        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2720            self.get_index_constant(spirv::Scope::Subgroup as u32)
2721        } else {
2722            self.get_index_constant(spirv::Scope::Workgroup as u32)
2723        };
2724        let semantics_id = self.get_index_constant(semantics.bits());
2725        block
2726            .body
2727            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
2728    }
2729
    /// Emit code that zero-initializes the entry point's shared globals.
    ///
    /// Selects every `WorkGroup` global (and, in task shaders, every
    /// `TaskPayload` global) whose usage in `info` is non-empty, and stores a
    /// null constant of the global's type into each one. The stores are
    /// guarded so that only the invocation whose local invocation index is 0
    /// performs them. They are followed by a workgroup control barrier.
    ///
    /// `entry_id` is the label of the first block emitted here.
    /// `local_invocation_index` is the id of an already-loaded local invocation
    /// index value, if the caller has one. When it is `None`, a new `Input`
    /// variable decorated `BuiltIn LocalInvocationIndex` is declared, pushed
    /// onto `interface.varying_ids`, and loaded in the first block.
    ///
    /// Returns `None`, emitting nothing, when no global needs initializing.
    /// Otherwise returns a freshly allocated label id. The last emitted block
    /// branches to it, so the caller is expected to use it as the label of its
    /// next block.
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_index: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // One `OpStore <null>` per used workgroup (or task payload) global.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                // Task payload memory is also initialized here, but only when
                // emitting a task shader.
                let task_exception = (var.space == crate::AddressSpace::TaskPayload)
                    && interface.stage == crate::ShaderStage::Task;
                !info[handle].is_empty()
                    && (var.space == crate::AddressSpace::WorkGroup || task_exception)
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let mut pre_if_block = Block::new(entry_id);

        // Reuse the caller's index value if provided; otherwise declare and load
        // a `LocalInvocationIndex` built-in input ourselves.
        let local_invocation_index = if let Some(local_invocation_index) = local_invocation_index {
            local_invocation_index
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let u32_ty_id = self.get_u32_type_id();
            let pointer_type_id = self.get_pointer_type_id(u32_ty_id, class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationIndex as u32],
            );

            // The new input must be listed in the entry point's interface.
            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(u32_ty_id, id, varying_id, None));

            id
        };

        let zero_id = self.get_constant_scalar(crate::Literal::U32(0));

        // `if (local_invocation_index == 0) { <stores> }`
        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            self.get_bool_type_id(),
            eq_id,
            local_invocation_index,
            zero_id,
        ));

        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(eq_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        // Every invocation waits here, so the rest of the entry point runs only
        // after the initializing stores.
        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);

        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
2828
2829    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
2830    ///
2831    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
2832    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
2833    /// interface is represented by global variables in the `Input` and `Output`
2834    /// storage classes, with decorations indicating which builtin or location
2835    /// each variable corresponds to.
2836    ///
2837    /// This function emits a single global `OpVariable` for a single value from
2838    /// the interface, and adds appropriate decorations to indicate which
2839    /// builtin or location it represents, how it should be interpolated, and so
2840    /// on. The `class` argument gives the variable's SPIR-V storage class,
2841    /// which should be either [`Input`] or [`Output`].
2842    ///
2843    /// [`Binding`]: crate::Binding
2844    /// [`Function`]: crate::Function
2845    /// [`EntryPoint`]: crate::EntryPoint
2846    /// [`Input`]: spirv::StorageClass::Input
2847    /// [`Output`]: spirv::StorageClass::Output
2848    fn write_varying(
2849        &mut self,
2850        ir_module: &crate::Module,
2851        stage: crate::ShaderStage,
2852        class: spirv::StorageClass,
2853        debug_name: Option<&str>,
2854        ty: Handle<crate::Type>,
2855        binding: &crate::Binding,
2856    ) -> Result<Word, Error> {
2857        let id = self.id_gen.next();
2858        let ty_inner = &ir_module.types[ty].inner;
2859        let needs_polyfill = self.needs_f16_polyfill(ty_inner);
2860
2861        let pointer_type_id = if needs_polyfill {
2862            let f32_value_local =
2863                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
2864                    .expect("needs_polyfill returned true but create_polyfill_type returned None");
2865
2866            let f32_type_id = self.get_localtype_id(f32_value_local);
2867            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
2868            self.io_f16_polyfills.register_io_var(id, f32_type_id);
2869
2870            ptr_id
2871        } else {
2872            self.get_handle_pointer_type_id(ty, class)
2873        };
2874
2875        Instruction::variable(pointer_type_id, id, class, None)
2876            .to_words(&mut self.logical_layout.declarations);
2877
2878        if self
2879            .flags
2880            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
2881        {
2882            if let Some(name) = debug_name {
2883                self.debugs.push(Instruction::name(id, name));
2884            }
2885        }
2886
2887        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
2888        self.write_binding(id, binding);
2889
2890        Ok(id)
2891    }
2892
2893    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
2894        match binding {
2895            BindingDecorations::None => (),
2896            BindingDecorations::BuiltIn(bi, others) => {
2897                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
2898                for other in others {
2899                    self.decorate(id, other, &[]);
2900                }
2901            }
2902            BindingDecorations::Location {
2903                location,
2904                others,
2905                blend_src,
2906            } => {
2907                self.decorate(id, spirv::Decoration::Location, &[location]);
2908                for other in others {
2909                    self.decorate(id, other, &[]);
2910                }
2911                if let Some(blend_src) = blend_src {
2912                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
2913                }
2914            }
2915        }
2916    }
2917
2918    pub fn write_binding_struct_member(
2919        &mut self,
2920        struct_id: Word,
2921        member_idx: Word,
2922        binding_info: BindingDecorations,
2923    ) {
2924        match binding_info {
2925            BindingDecorations::None => (),
2926            BindingDecorations::BuiltIn(bi, others) => {
2927                self.annotations.push(Instruction::member_decorate(
2928                    struct_id,
2929                    member_idx,
2930                    spirv::Decoration::BuiltIn,
2931                    &[bi as Word],
2932                ));
2933                for other in others {
2934                    self.annotations.push(Instruction::member_decorate(
2935                        struct_id,
2936                        member_idx,
2937                        other,
2938                        &[],
2939                    ));
2940                }
2941            }
2942            BindingDecorations::Location {
2943                location,
2944                others,
2945                blend_src,
2946            } => {
2947                self.annotations.push(Instruction::member_decorate(
2948                    struct_id,
2949                    member_idx,
2950                    spirv::Decoration::Location,
2951                    &[location],
2952                ));
2953                for other in others {
2954                    self.annotations.push(Instruction::member_decorate(
2955                        struct_id,
2956                        member_idx,
2957                        other,
2958                        &[],
2959                    ));
2960                }
2961                if let Some(blend_src) = blend_src {
2962                    self.annotations.push(Instruction::member_decorate(
2963                        struct_id,
2964                        member_idx,
2965                        spirv::Decoration::Index,
2966                        &[blend_src],
2967                    ));
2968                }
2969            }
2970        }
2971    }
2972
2973    pub fn map_binding(
2974        &mut self,
2975        ir_module: &crate::Module,
2976        stage: crate::ShaderStage,
2977        class: spirv::StorageClass,
2978        ty: Handle<crate::Type>,
2979        binding: &crate::Binding,
2980    ) -> Result<BindingDecorations, Error> {
2981        use spirv::BuiltIn;
2982        use spirv::Decoration;
2983        match *binding {
2984            crate::Binding::Location {
2985                location,
2986                interpolation,
2987                sampling,
2988                blend_src,
2989                per_primitive,
2990            } => {
2991                let mut others = ArrayVec::new();
2992
2993                let no_decorations =
2994                    // VUID-StandaloneSpirv-Flat-06202
2995                    // > The Flat, NoPerspective, Sample, and Centroid decorations
2996                    // > must not be used on variables with the Input storage class in a vertex shader
2997                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
2998                    // VUID-StandaloneSpirv-Flat-06201
2999                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3000                    // > must not be used on variables with the Output storage class in a fragment shader
3001                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
3002
3003                if !no_decorations {
3004                    match interpolation {
3005                        // Perspective-correct interpolation is the default in SPIR-V.
3006                        None | Some(crate::Interpolation::Perspective) => (),
3007                        Some(crate::Interpolation::Flat) => {
3008                            others.push(Decoration::Flat);
3009                        }
3010                        Some(crate::Interpolation::Linear) => {
3011                            others.push(Decoration::NoPerspective);
3012                        }
3013                        Some(crate::Interpolation::PerVertex) => {
3014                            others.push(Decoration::PerVertexKHR);
3015                            self.require_any(
3016                                "`per_vertex` interpolation",
3017                                &[spirv::Capability::FragmentBarycentricKHR],
3018                            )?;
3019                            self.use_extension("SPV_KHR_fragment_shader_barycentric");
3020                        }
3021                    }
3022                    match sampling {
3023                        // Center sampling is the default in SPIR-V.
3024                        None
3025                        | Some(
3026                            crate::Sampling::Center
3027                            | crate::Sampling::First
3028                            | crate::Sampling::Either,
3029                        ) => (),
3030                        Some(crate::Sampling::Centroid) => {
3031                            others.push(Decoration::Centroid);
3032                        }
3033                        Some(crate::Sampling::Sample) => {
3034                            self.require_any(
3035                                "per-sample interpolation",
3036                                &[spirv::Capability::SampleRateShading],
3037                            )?;
3038                            others.push(Decoration::Sample);
3039                        }
3040                    }
3041                }
3042                if per_primitive && stage == crate::ShaderStage::Fragment {
3043                    others.push(Decoration::PerPrimitiveEXT);
3044                }
3045                Ok(BindingDecorations::Location {
3046                    location,
3047                    others,
3048                    blend_src,
3049                })
3050            }
3051            crate::Binding::BuiltIn(built_in) => {
3052                use crate::BuiltIn as Bi;
3053                let mut others = ArrayVec::new();
3054
3055                let built_in = match built_in {
3056                    Bi::Position { invariant } => {
3057                        if invariant {
3058                            others.push(Decoration::Invariant);
3059                        }
3060
3061                        if class == spirv::StorageClass::Output {
3062                            BuiltIn::Position
3063                        } else {
3064                            BuiltIn::FragCoord
3065                        }
3066                    }
3067                    Bi::ViewIndex => {
3068                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
3069                        BuiltIn::ViewIndex
3070                    }
3071                    // vertex
3072                    Bi::BaseInstance => BuiltIn::BaseInstance,
3073                    Bi::BaseVertex => BuiltIn::BaseVertex,
3074                    Bi::ClipDistance => {
3075                        self.require_any(
3076                            "`clip_distance` built-in",
3077                            &[spirv::Capability::ClipDistance],
3078                        )?;
3079                        BuiltIn::ClipDistance
3080                    }
3081                    Bi::CullDistance => {
3082                        self.require_any(
3083                            "`cull_distance` built-in",
3084                            &[spirv::Capability::CullDistance],
3085                        )?;
3086                        BuiltIn::CullDistance
3087                    }
3088                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
3089                    Bi::PointSize => BuiltIn::PointSize,
3090                    Bi::VertexIndex => BuiltIn::VertexIndex,
3091                    Bi::DrawIndex => {
3092                        self.use_extension("SPV_KHR_shader_draw_parameters");
3093                        self.require_any(
3094                            "`draw_index built-in",
3095                            &[spirv::Capability::DrawParameters],
3096                        )?;
3097                        BuiltIn::DrawIndex
3098                    }
3099                    // fragment
3100                    Bi::FragDepth => BuiltIn::FragDepth,
3101                    Bi::PointCoord => BuiltIn::PointCoord,
3102                    Bi::FrontFacing => BuiltIn::FrontFacing,
3103                    Bi::PrimitiveIndex => {
3104                        // Geometry shader capability is required for primitive index
3105                        self.require_any(
3106                            "`primitive_index` built-in",
3107                            &[spirv::Capability::Geometry],
3108                        )?;
3109                        if stage == crate::ShaderStage::Mesh {
3110                            others.push(Decoration::PerPrimitiveEXT);
3111                        }
3112                        BuiltIn::PrimitiveId
3113                    }
3114                    Bi::Barycentric { perspective } => {
3115                        self.require_any(
3116                            "`barycentric` built-in",
3117                            &[spirv::Capability::FragmentBarycentricKHR],
3118                        )?;
3119                        self.use_extension("SPV_KHR_fragment_shader_barycentric");
3120                        if perspective {
3121                            BuiltIn::BaryCoordKHR
3122                        } else {
3123                            BuiltIn::BaryCoordNoPerspKHR
3124                        }
3125                    }
3126                    Bi::SampleIndex => {
3127                        self.require_any(
3128                            "`sample_index` built-in",
3129                            &[spirv::Capability::SampleRateShading],
3130                        )?;
3131
3132                        BuiltIn::SampleId
3133                    }
3134                    Bi::SampleMask => BuiltIn::SampleMask,
3135                    // compute
3136                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
3137                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
3138                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
3139                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
3140                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
3141                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
3142                    // Subgroup
3143                    Bi::NumSubgroups => {
3144                        self.require_any(
3145                            "`num_subgroups` built-in",
3146                            &[spirv::Capability::GroupNonUniform],
3147                        )?;
3148                        BuiltIn::NumSubgroups
3149                    }
3150                    Bi::SubgroupId => {
3151                        self.require_any(
3152                            "`subgroup_id` built-in",
3153                            &[spirv::Capability::GroupNonUniform],
3154                        )?;
3155                        BuiltIn::SubgroupId
3156                    }
3157                    Bi::SubgroupSize => {
3158                        self.require_any(
3159                            "`subgroup_size` built-in",
3160                            &[
3161                                spirv::Capability::GroupNonUniform,
3162                                spirv::Capability::SubgroupBallotKHR,
3163                            ],
3164                        )?;
3165                        BuiltIn::SubgroupSize
3166                    }
3167                    Bi::SubgroupInvocationId => {
3168                        self.require_any(
3169                            "`subgroup_invocation_id` built-in",
3170                            &[
3171                                spirv::Capability::GroupNonUniform,
3172                                spirv::Capability::SubgroupBallotKHR,
3173                            ],
3174                        )?;
3175                        BuiltIn::SubgroupLocalInvocationId
3176                    }
3177                    Bi::CullPrimitive => {
3178                        others.push(Decoration::PerPrimitiveEXT);
3179                        BuiltIn::CullPrimitiveEXT
3180                    }
3181                    Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT,
3182                    Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT,
3183                    Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT,
3184                    // No decoration, this EmitMeshTasksEXT is called at function return
3185                    Bi::MeshTaskSize => return Ok(BindingDecorations::None),
3186                    // These aren't normal builtins and don't occur in function output
3187                    Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => {
3188                        unreachable!()
3189                    }
3190                    Bi::RayInvocationId
3191                    | Bi::NumRayInvocations
3192                    | Bi::InstanceCustomData
3193                    | Bi::GeometryIndex
3194                    | Bi::WorldRayOrigin
3195                    | Bi::WorldRayDirection
3196                    | Bi::ObjectRayOrigin
3197                    | Bi::ObjectRayDirection
3198                    | Bi::RayTmin
3199                    | Bi::RayTCurrentMax
3200                    | Bi::ObjectToWorld
3201                    | Bi::WorldToObject
3202                    | Bi::HitKind => unreachable!(),
3203                };
3204
3205                use crate::ScalarKind as Sk;
3206
3207                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
3208                //
3209                // > Any variable with integer or double-precision floating-
3210                // > point type and with Input storage class in a fragment
3211                // > shader, must be decorated Flat
3212                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
3213                    let is_flat = match ir_module.types[ty].inner {
3214                        crate::TypeInner::Scalar(scalar)
3215                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
3216                            Sk::Uint | Sk::Sint | Sk::Bool => true,
3217                            Sk::Float => false,
3218                            Sk::AbstractInt | Sk::AbstractFloat => {
3219                                return Err(Error::Validation(
3220                                    "Abstract types should not appear in IR presented to backends",
3221                                ))
3222                            }
3223                        },
3224                        _ => false,
3225                    };
3226
3227                    if is_flat {
3228                        others.push(Decoration::Flat);
3229                    }
3230                }
3231                Ok(BindingDecorations::BuiltIn(built_in, others))
3232            }
3233        }
3234    }
3235
3236    /// Load an IO variable, converting from `f32` to `f16` if polyfill is active.
3237    /// Returns the id of the loaded value matching `target_type_id`.
3238    pub(super) fn load_io_with_f16_polyfill(
3239        &mut self,
3240        body: &mut Vec<Instruction>,
3241        varying_id: Word,
3242        target_type_id: Word,
3243    ) -> Word {
3244        let tmp = self.id_gen.next();
3245        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3246            body.push(Instruction::load(f32_ty, tmp, varying_id, None));
3247            let converted = self.id_gen.next();
3248            super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion(
3249                tmp,
3250                target_type_id,
3251                converted,
3252                body,
3253            );
3254            converted
3255        } else {
3256            body.push(Instruction::load(target_type_id, tmp, varying_id, None));
3257            tmp
3258        }
3259    }
3260
3261    /// Store an IO variable, converting from `f16` to `f32` if polyfill is active.
3262    pub(super) fn store_io_with_f16_polyfill(
3263        &mut self,
3264        body: &mut Vec<Instruction>,
3265        varying_id: Word,
3266        value_id: Word,
3267    ) {
3268        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3269            let converted = self.id_gen.next();
3270            super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion(
3271                value_id, f32_ty, converted, body,
3272            );
3273            body.push(Instruction::store(varying_id, converted, None));
3274        } else {
3275            body.push(Instruction::store(varying_id, value_id, None));
3276        }
3277    }
3278
3279    fn write_global_variable(
3280        &mut self,
3281        ir_module: &crate::Module,
3282        global_variable: &crate::GlobalVariable,
3283    ) -> Result<Word, Error> {
3284        use spirv::Decoration;
3285
3286        let id = self.id_gen.next();
3287        let class = map_storage_class(global_variable.space);
3288
3289        //self.check(class.required_capabilities())?;
3290
3291        if self.flags.contains(WriterFlags::DEBUG) {
3292            if let Some(ref name) = global_variable.name {
3293                self.debugs.push(Instruction::name(id, name));
3294            }
3295        }
3296
3297        let storage_access = match global_variable.space {
3298            crate::AddressSpace::Storage { access } => Some(access),
3299            _ => match ir_module.types[global_variable.ty].inner {
3300                crate::TypeInner::Image {
3301                    class: crate::ImageClass::Storage { access, .. },
3302                    ..
3303                } => Some(access),
3304                _ => None,
3305            },
3306        };
3307        if let Some(storage_access) = storage_access {
3308            if !storage_access.contains(crate::StorageAccess::LOAD) {
3309                self.decorate(id, Decoration::NonReadable, &[]);
3310            }
3311            if !storage_access.contains(crate::StorageAccess::STORE) {
3312                self.decorate(id, Decoration::NonWritable, &[]);
3313            }
3314        }
3315
3316        // Note: we should be able to substitute `binding_array<Foo, 0>`,
3317        // but there is still code that tries to register the pre-substituted type,
3318        // and it is failing on 0.
3319        let mut substitute_inner_type_lookup = None;
3320        if let Some(ref res_binding) = global_variable.binding {
3321            let bind_target = self.resolve_resource_binding(res_binding)?;
3322            self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]);
3323            self.decorate(id, Decoration::Binding, &[bind_target.binding]);
3324
3325            if let Some(remapped_binding_array_size) = bind_target.binding_array_size {
3326                if let crate::TypeInner::BindingArray { base, .. } =
3327                    ir_module.types[global_variable.ty].inner
3328                {
3329                    let binding_array_type_id =
3330                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
3331                            base,
3332                            size: remapped_binding_array_size,
3333                        }));
3334                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
3335                        base: binding_array_type_id,
3336                        class,
3337                    }));
3338                }
3339            }
3340        };
3341
3342        let init_word = global_variable
3343            .init
3344            .map(|constant| self.constant_ids[constant]);
3345        let inner_type_id = self.get_type_id(
3346            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
3347        );
3348
3349        // generate the wrapping structure if needed
3350        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
3351            let wrapper_type_id = self.id_gen.next();
3352
3353            self.decorate(wrapper_type_id, Decoration::Block, &[]);
3354
3355            match self.std140_compat_uniform_types.get(&global_variable.ty) {
3356                Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => {
3357                    self.annotations.push(Instruction::member_decorate(
3358                        wrapper_type_id,
3359                        0,
3360                        Decoration::Offset,
3361                        &[0],
3362                    ));
3363                    Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id])
3364                        .to_words(&mut self.logical_layout.declarations);
3365                }
3366                _ => {
3367                    let member = crate::StructMember {
3368                        name: None,
3369                        ty: global_variable.ty,
3370                        binding: None,
3371                        offset: 0,
3372                    };
3373                    self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
3374
3375                    Instruction::type_struct(wrapper_type_id, &[inner_type_id])
3376                        .to_words(&mut self.logical_layout.declarations);
3377                }
3378            }
3379
3380            let pointer_type_id = self.id_gen.next();
3381            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
3382                .to_words(&mut self.logical_layout.declarations);
3383
3384            pointer_type_id
3385        } else {
3386            // This is a global variable in the Storage address space. The only
3387            // way it could have `global_needs_wrapper() == false` is if it has
3388            // a runtime-sized or binding array.
3389            // Runtime-sized arrays were decorated when iterating through struct content.
3390            // Now binding arrays require Block decorating.
3391            if let crate::AddressSpace::Storage { .. } = global_variable.space {
3392                match ir_module.types[global_variable.ty].inner {
3393                    crate::TypeInner::BindingArray { base, .. } => {
3394                        let ty = &ir_module.types[base];
3395                        let mut should_decorate = true;
3396                        // Check if the type has a runtime array.
3397                        // A normal runtime array gets validated out,
3398                        // so only structs can be with runtime arrays
3399                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
3400                            // only the last member in a struct can be dynamically sized
3401                            if let Some(last_member) = members.last() {
3402                                if let &crate::TypeInner::Array {
3403                                    size: crate::ArraySize::Dynamic,
3404                                    ..
3405                                } = &ir_module.types[last_member.ty].inner
3406                                {
3407                                    should_decorate = false;
3408                                }
3409                            }
3410                        }
3411                        if should_decorate {
3412                            let decorated_id = self.get_handle_type_id(base);
3413                            self.decorate(decorated_id, Decoration::Block, &[]);
3414                        }
3415                    }
3416                    _ => (),
3417                };
3418            }
3419            if substitute_inner_type_lookup.is_some() {
3420                inner_type_id
3421            } else {
3422                self.get_handle_pointer_type_id(global_variable.ty, class)
3423            }
3424        };
3425
3426        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
3427            (crate::AddressSpace::Private, _)
3428            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
3429                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
3430            }
3431            _ => init_word,
3432        };
3433
3434        Instruction::variable(pointer_type_id, id, class, init_word)
3435            .to_words(&mut self.logical_layout.declarations);
3436        Ok(id)
3437    }
3438
3439    /// Write the necessary decorations for a struct member.
3440    ///
3441    /// Emit decorations for the `index`'th member of the struct type
3442    /// designated by `struct_id`, described by `member`.
3443    fn decorate_struct_member(
3444        &mut self,
3445        struct_id: Word,
3446        index: usize,
3447        member: &crate::StructMember,
3448        arena: &UniqueArena<crate::Type>,
3449    ) -> Result<(), Error> {
3450        use spirv::Decoration;
3451
3452        self.annotations.push(Instruction::member_decorate(
3453            struct_id,
3454            index as u32,
3455            Decoration::Offset,
3456            &[member.offset],
3457        ));
3458
3459        if self.flags.contains(WriterFlags::DEBUG) {
3460            if let Some(ref name) = member.name {
3461                self.debugs
3462                    .push(Instruction::member_name(struct_id, index as u32, name));
3463            }
3464        }
3465
3466        // Matrices and (potentially nested) arrays of matrices both require decorations,
3467        // so "see through" any arrays to determine if they're needed.
3468        let mut member_array_subty_inner = &arena[member.ty].inner;
3469        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
3470            member_array_subty_inner = &arena[base].inner;
3471        }
3472
3473        if let crate::TypeInner::Matrix {
3474            columns: _,
3475            rows,
3476            scalar,
3477        } = *member_array_subty_inner
3478        {
3479            let byte_stride = Alignment::from(rows) * scalar.width as u32;
3480            self.annotations.push(Instruction::member_decorate(
3481                struct_id,
3482                index as u32,
3483                Decoration::ColMajor,
3484                &[],
3485            ));
3486            self.annotations.push(Instruction::member_decorate(
3487                struct_id,
3488                index as u32,
3489                Decoration::MatrixStride,
3490                &[byte_stride],
3491            ));
3492        }
3493
3494        Ok(())
3495    }
3496
3497    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
3498        match self
3499            .lookup_function_type
3500            .entry(lookup_function_type.clone())
3501        {
3502            Entry::Occupied(e) => *e.get(),
3503            Entry::Vacant(_) => {
3504                let id = self.id_gen.next();
3505                let instruction = Instruction::type_function(
3506                    id,
3507                    lookup_function_type.return_type_id,
3508                    &lookup_function_type.parameter_type_ids,
3509                );
3510                instruction.to_words(&mut self.logical_layout.declarations);
3511                self.lookup_function_type.insert(lookup_function_type, id);
3512                id
3513            }
3514        }
3515    }
3516
3517    const fn write_physical_layout(&mut self) {
3518        self.physical_layout.bound = self.id_gen.0 + 1;
3519    }
3520
3521    fn write_logical_layout(
3522        &mut self,
3523        ir_module: &crate::Module,
3524        mod_info: &ModuleInfo,
3525        ep_index: Option<usize>,
3526        debug_info: &Option<DebugInfo>,
3527    ) -> Result<(), Error> {
3528        fn has_view_index_check(
3529            ir_module: &crate::Module,
3530            binding: Option<&crate::Binding>,
3531            ty: Handle<crate::Type>,
3532        ) -> bool {
3533            match ir_module.types[ty].inner {
3534                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
3535                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
3536                }),
3537                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
3538            }
3539        }
3540
3541        let has_storage_buffers =
3542            ir_module
3543                .global_variables
3544                .iter()
3545                .any(|(_, var)| match var.space {
3546                    crate::AddressSpace::Storage { .. } => true,
3547                    _ => false,
3548                });
3549        let has_view_index = ir_module
3550            .entry_points
3551            .iter()
3552            .flat_map(|entry| entry.function.arguments.iter())
3553            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
3554        let mut has_ray_query = ir_module.special_types.ray_desc.is_some()
3555            | ir_module.special_types.ray_intersection.is_some();
3556        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
3557
3558        for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() {
3559            // spirv does not know whether these have vertex return - that is done by us
3560            if let &crate::TypeInner::AccelerationStructure { .. }
3561            | &crate::TypeInner::RayQuery { .. } = inner
3562            {
3563                has_ray_query = true
3564            }
3565        }
3566
3567        if self.physical_layout.version < 0x10300 && has_storage_buffers {
3568            // enable the storage buffer class on < SPV-1.3
3569            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
3570                .to_words(&mut self.logical_layout.extensions);
3571        }
3572        if has_view_index {
3573            Instruction::extension("SPV_KHR_multiview")
3574                .to_words(&mut self.logical_layout.extensions)
3575        }
3576        if has_ray_query {
3577            Instruction::extension("SPV_KHR_ray_query")
3578                .to_words(&mut self.logical_layout.extensions)
3579        }
3580        if has_vertex_return {
3581            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
3582                .to_words(&mut self.logical_layout.extensions);
3583        }
3584        if ir_module.uses_mesh_shaders() {
3585            self.use_extension("SPV_EXT_mesh_shader");
3586            self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?;
3587            let lang_version = self.lang_version();
3588            if lang_version.0 <= 1 && lang_version.1 < 4 {
3589                return Err(Error::SpirvVersionTooLow(1, 4));
3590            }
3591        }
3592        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
3593        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
3594            .to_words(&mut self.logical_layout.ext_inst_imports);
3595
3596        let mut debug_info_inner = None;
3597        if self.flags.contains(WriterFlags::DEBUG) {
3598            if let Some(debug_info) = debug_info.as_ref() {
3599                let source_file_id = self.id_gen.next();
3600                self.debugs
3601                    .push(Instruction::string(debug_info.file_name, source_file_id));
3602
3603                debug_info_inner = Some(DebugInfoInner {
3604                    source_code: debug_info.source_code,
3605                    source_file_id,
3606                });
3607                self.debugs.append(&mut Instruction::source_auto_continued(
3608                    debug_info.language,
3609                    0,
3610                    &debug_info_inner,
3611                ));
3612            }
3613        }
3614
3615        // write all types
3616        for (handle, _) in ir_module.types.iter() {
3617            self.write_type_declaration_arena(ir_module, handle)?;
3618        }
3619
3620        // write std140 layout compatible types required by uniforms
3621        for (_, var) in ir_module.global_variables.iter() {
3622            if var.space == crate::AddressSpace::Uniform {
3623                self.write_std140_compat_type_declaration(ir_module, var.ty)?;
3624            }
3625        }
3626
3627        // write all const-expressions as constants
3628        self.constant_ids
3629            .resize(ir_module.global_expressions.len(), 0);
3630        for (handle, _) in ir_module.global_expressions.iter() {
3631            self.write_constant_expr(handle, ir_module, mod_info)?;
3632        }
3633        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
3634
3635        // write the name of constants on their respective const-expression initializer
3636        if self.flags.contains(WriterFlags::DEBUG) {
3637            for (_, constant) in ir_module.constants.iter() {
3638                if let Some(ref name) = constant.name {
3639                    let id = self.constant_ids[constant.init];
3640                    self.debugs.push(Instruction::name(id, name));
3641                }
3642            }
3643        }
3644
3645        // write all global variables
3646        for (handle, var) in ir_module.global_variables.iter() {
3647            // If a single entry point was specified, only write `OpVariable` instructions
3648            // for the globals it actually uses. Emit dummies for the others,
3649            // to preserve the indices in `global_variables`.
3650            let gvar = match ep_index {
3651                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
3652                    GlobalVariable::dummy()
3653                }
3654                _ => {
3655                    let id = self.write_global_variable(ir_module, var)?;
3656                    GlobalVariable::new(id)
3657                }
3658            };
3659            self.global_variables.insert(handle, gvar);
3660        }
3661
3662        // write all functions
3663        for (handle, ir_function) in ir_module.functions.iter() {
3664            let info = &mod_info[handle];
3665            if let Some(index) = ep_index {
3666                let ep_info = mod_info.get_entry_point(index);
3667                // If this function uses globals that we omitted from the SPIR-V
3668                // because the entry point and its callees didn't use them,
3669                // then we must skip it.
3670                if !ep_info.dominates_global_use(info) {
3671                    log::debug!("Skip function {:?}", ir_function.name);
3672                    continue;
3673                }
3674
3675                // Skip functions that that are not compatible with this entry point's stage.
3676                //
3677                // When validation is enabled, it rejects modules whose entry points try to call
3678                // incompatible functions, so if we got this far, then any functions incompatible
3679                // with our selected entry point must not be used.
3680                //
3681                // When validation is disabled, `fun_info.available_stages` is always just
3682                // `ShaderStages::all()`, so this will write all functions in the module, and
3683                // the downstream GLSL compiler will catch any problems.
3684                if !info.available_stages.contains(ep_info.available_stages) {
3685                    continue;
3686                }
3687            }
3688            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
3689            self.lookup_function.insert(handle, id);
3690        }
3691
3692        // write all or one entry points
3693        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
3694            if ep_index.is_some() && ep_index != Some(index) {
3695                continue;
3696            }
3697            let info = mod_info.get_entry_point(index);
3698            let ep_instruction =
3699                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
3700            ep_instruction.to_words(&mut self.logical_layout.entry_points);
3701        }
3702
3703        for capability in self.capabilities_used.iter() {
3704            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
3705        }
3706        for extension in self.extensions_used.iter() {
3707            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
3708        }
3709        if ir_module.entry_points.is_empty() {
3710            // SPIR-V doesn't like modules without entry points
3711            Instruction::capability(spirv::Capability::Linkage)
3712                .to_words(&mut self.logical_layout.capabilities);
3713        }
3714
3715        let addressing_model = spirv::AddressingModel::Logical;
3716        let memory_model = if self
3717            .capabilities_used
3718            .contains(&spirv::Capability::VulkanMemoryModel)
3719        {
3720            spirv::MemoryModel::Vulkan
3721        } else {
3722            spirv::MemoryModel::GLSL450
3723        };
3724        //self.check(addressing_model.required_capabilities())?;
3725        //self.check(memory_model.required_capabilities())?;
3726
3727        Instruction::memory_model(addressing_model, memory_model)
3728            .to_words(&mut self.logical_layout.memory_model);
3729
3730        for debug_string in self.debug_strings.iter() {
3731            debug_string.to_words(&mut self.logical_layout.debugs);
3732        }
3733
3734        if self.flags.contains(WriterFlags::DEBUG) {
3735            for debug in self.debugs.iter() {
3736                debug.to_words(&mut self.logical_layout.debugs);
3737            }
3738        }
3739
3740        for annotation in self.annotations.iter() {
3741            annotation.to_words(&mut self.logical_layout.annotations);
3742        }
3743
3744        Ok(())
3745    }
3746
3747    pub fn write(
3748        &mut self,
3749        ir_module: &crate::Module,
3750        info: &ModuleInfo,
3751        pipeline_options: Option<&PipelineOptions>,
3752        debug_info: &Option<DebugInfo>,
3753        words: &mut Vec<Word>,
3754    ) -> Result<(), Error> {
3755        self.reset();
3756
3757        // Try to find the entry point and corresponding index
3758        let ep_index = match pipeline_options {
3759            Some(po) => {
3760                let index = ir_module
3761                    .entry_points
3762                    .iter()
3763                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
3764                    .ok_or(Error::EntryPointNotFound)?;
3765                Some(index)
3766            }
3767            None => None,
3768        };
3769
3770        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
3771        self.write_physical_layout();
3772
3773        self.physical_layout.in_words(words);
3774        self.logical_layout.in_words(words);
3775        Ok(())
3776    }
3777
    /// Return the set of capabilities the last module written used.
    ///
    /// This is a read-only view of the writer's `capabilities_used` set. Its
    /// entries are what `write_logical_layout` emits as `OpCapability`
    /// instructions.
    ///
    /// Note: the `Linkage` capability that is added for modules without entry
    /// points is written directly to the output and is not part of this set.
    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
        &self.capabilities_used
    }
3782
3783    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
3784        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
3785        self.use_extension("SPV_EXT_descriptor_indexing");
3786        self.decorate(id, spirv::Decoration::NonUniform, &[]);
3787        Ok(())
3788    }
3789
3790    pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool {
3791        self.io_f16_polyfills.needs_polyfill(ty_inner)
3792    }
3793
3794    pub(super) fn write_debug_printf(
3795        &mut self,
3796        block: &mut Block,
3797        string: &str,
3798        format_params: &[Word],
3799    ) {
3800        if self.debug_printf.is_none() {
3801            self.use_extension("SPV_KHR_non_semantic_info");
3802            let import_id = self.id_gen.next();
3803            Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf")
3804                .to_words(&mut self.logical_layout.ext_inst_imports);
3805            self.debug_printf = Some(import_id)
3806        }
3807
3808        let import_id = self.debug_printf.unwrap();
3809
3810        let string_id = self.id_gen.next();
3811        self.debug_strings
3812            .push(Instruction::string(string, string_id));
3813
3814        let mut operands = Vec::with_capacity(1 + format_params.len());
3815        operands.push(string_id);
3816        operands.extend(format_params.iter());
3817
3818        let print_id = self.id_gen.next();
3819        block.body.push(Instruction::ext_inst(
3820            import_id,
3821            1,
3822            self.void_type,
3823            print_id,
3824            &operands,
3825        ));
3826    }
3827}
3828
/// Checks that `write_physical_layout` derives the header's id bound from the
/// writer's id generator.
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();

    // Before the header is finalized, no bound has been computed.
    assert_eq!(writer.physical_layout.bound, 0);

    // A fresh writer presumably has ids 1 and 2 pre-allocated. The
    // `write_logical_layout` code uses `void_type` and `gl450_ext_inst_id`
    // without allocating them, which would give a bound of 3.
    writer.write_physical_layout();
    assert_eq!(writer.physical_layout.bound, 3);
}