naga/back/spv/writer.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use arrayvec::ArrayVec;
4use hashbrown::hash_map::Entry;
5use spirv::Word;
6
7use super::{
8    block::DebugInfoInner,
9    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
10    Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
11    EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
12    LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
13    NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
14    BITS_PER_BYTE,
15};
16use crate::{
17    arena::{Handle, HandleVec, UniqueArena},
18    back::spv::{
19        helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations},
20        BindingInfo, Std140CompatTypeInfo, WrappedFunction,
21    },
22    common::ForDebugWithTypes as _,
23    proc::{Alignment, TypeResolution},
24    valid::{FunctionInfo, ModuleInfo},
25};
26
27pub struct FunctionInterface<'a> {
28    pub varying_ids: &'a mut Vec<Word>,
29    pub stage: crate::ShaderStage,
30    pub task_payload: Option<Handle<crate::GlobalVariable>>,
31    pub mesh_info: Option<crate::MeshStageInfo>,
32    pub workgroup_size: [u32; 3],
33}
34
35impl Function {
36    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
37        self.signature.as_ref().unwrap().to_words(sink);
38        for argument in self.parameters.iter() {
39            argument.instruction.to_words(sink);
40        }
41        for (index, block) in self.blocks.iter().enumerate() {
42            Instruction::label(block.label_id).to_words(sink);
43            if index == 0 {
44                for local_var in self.variables.values() {
45                    local_var.instruction.to_words(sink);
46                }
47                for local_var in self.ray_query_initialization_tracker_variables.values() {
48                    local_var.instruction.to_words(sink);
49                }
50                for local_var in self.ray_query_t_max_tracker_variables.values() {
51                    local_var.instruction.to_words(sink);
52                }
53                for local_var in self.force_loop_bounding_vars.iter() {
54                    local_var.instruction.to_words(sink);
55                }
56                for internal_var in self.spilled_composites.values() {
57                    internal_var.instruction.to_words(sink);
58                }
59            }
60            for instruction in block.body.iter() {
61                instruction.to_words(sink);
62            }
63        }
64        Instruction::function_end().to_words(sink);
65    }
66}
67
68impl Writer {
69    pub fn new(options: &Options) -> Result<Self, Error> {
70        let (major, minor) = options.lang_version;
71        if major != 1 {
72            return Err(Error::UnsupportedVersion(major, minor));
73        }
74
75        let mut capabilities_used = crate::FastIndexSet::default();
76        capabilities_used.insert(spirv::Capability::Shader);
77
78        let mut id_gen = IdGenerator::default();
79        let gl450_ext_inst_id = id_gen.next();
80        let void_type = id_gen.next();
81
82        Ok(Writer {
83            physical_layout: PhysicalLayout::new(major, minor),
84            logical_layout: LogicalLayout::default(),
85            id_gen,
86            capabilities_available: options.capabilities.clone(),
87            capabilities_used,
88            extensions_used: crate::FastIndexSet::default(),
89            debug_strings: vec![],
90            debugs: vec![],
91            annotations: vec![],
92            flags: options.flags,
93            bounds_check_policies: options.bounds_check_policies,
94            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
95            force_loop_bounding: options.force_loop_bounding,
96            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
97            trace_ray_argument_validation: options.trace_ray_argument_validation,
98            use_storage_input_output_16: options.use_storage_input_output_16,
99            void_type,
100            tuple_of_u32s_ty_id: None,
101            lookup_type: crate::FastHashMap::default(),
102            lookup_function: crate::FastHashMap::default(),
103            lookup_function_type: crate::FastHashMap::default(),
104            wrapped_functions: crate::FastHashMap::default(),
105            constant_ids: HandleVec::new(),
106            cached_constants: crate::FastHashMap::default(),
107            global_variables: HandleVec::new(),
108            std140_compat_uniform_types: crate::FastHashMap::default(),
109            fake_missing_bindings: options.fake_missing_bindings,
110            binding_map: options.binding_map.clone(),
111            saved_cached: CachedExpressions::default(),
112            gl450_ext_inst_id,
113            temp_list: Vec::new(),
114            ray_query_functions: crate::FastHashMap::default(),
115            ray_tracing_functions: crate::FastHashMap::default(),
116            has_ray_tracing_pipeline: false,
117            io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new(
118                options.use_storage_input_output_16,
119            ),
120            debug_printf: None,
121            task_dispatch_limits: options.task_dispatch_limits,
122            mesh_shader_primitive_indices_clamp: options.mesh_shader_primitive_indices_clamp,
123        })
124    }
125
126    pub fn set_options(&mut self, options: &Options) -> Result<(), Error> {
127        let (major, minor) = options.lang_version;
128        if major != 1 {
129            return Err(Error::UnsupportedVersion(major, minor));
130        }
131        self.physical_layout = PhysicalLayout::new(major, minor);
132        self.capabilities_available = options.capabilities.clone();
133        self.flags = options.flags;
134        self.bounds_check_policies = options.bounds_check_policies;
135        self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory;
136        self.force_loop_bounding = options.force_loop_bounding;
137        self.use_storage_input_output_16 = options.use_storage_input_output_16;
138        self.binding_map = options.binding_map.clone();
139        self.io_f16_polyfills =
140            super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16);
141        self.task_dispatch_limits = options.task_dispatch_limits;
142        self.mesh_shader_primitive_indices_clamp = options.mesh_shader_primitive_indices_clamp;
143        Ok(())
144    }
145
146    /// Returns `(major, minor)` of the SPIR-V language version.
147    pub const fn lang_version(&self) -> (u8, u8) {
148        self.physical_layout.lang_version()
149    }
150
151    /// Reset `Writer` to its initial state, retaining any allocations.
152    ///
153    /// Why not just implement `Reclaimable` for `Writer`? By design,
154    /// `Reclaimable::reclaim` requires ownership of the value, not just
155    /// `&mut`; see the trait documentation. But we need to use this method
156    /// from functions like `Writer::write`, which only have `&mut Writer`.
157    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
158    /// or something like a `Default` impl that returns an oddly-initialized
159    /// `Writer`, which is worse.
160    fn reset(&mut self) {
161        use super::reclaimable::Reclaimable;
162        use core::mem::take;
163
164        let mut id_gen = IdGenerator::default();
165        let gl450_ext_inst_id = id_gen.next();
166        let void_type = id_gen.next();
167
168        // Every field of the old writer that is not determined by the `Options`
169        // passed to `Writer::new` should be reset somehow.
170        let fresh = Writer {
171            // Copied from the old Writer:
172            flags: self.flags,
173            bounds_check_policies: self.bounds_check_policies,
174            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
175            force_loop_bounding: self.force_loop_bounding,
176            ray_query_initialization_tracking: self.ray_query_initialization_tracking,
177            trace_ray_argument_validation: self.trace_ray_argument_validation,
178            use_storage_input_output_16: self.use_storage_input_output_16,
179            capabilities_available: take(&mut self.capabilities_available),
180            fake_missing_bindings: self.fake_missing_bindings,
181            binding_map: take(&mut self.binding_map),
182            task_dispatch_limits: self.task_dispatch_limits,
183            mesh_shader_primitive_indices_clamp: self.mesh_shader_primitive_indices_clamp,
184
185            // Initialized afresh:
186            id_gen,
187            void_type,
188            tuple_of_u32s_ty_id: None,
189            gl450_ext_inst_id,
190
191            // Reclaimed:
192            capabilities_used: take(&mut self.capabilities_used).reclaim(),
193            extensions_used: take(&mut self.extensions_used).reclaim(),
194            physical_layout: self.physical_layout.clone().reclaim(),
195            logical_layout: take(&mut self.logical_layout).reclaim(),
196            debug_strings: take(&mut self.debug_strings).reclaim(),
197            debugs: take(&mut self.debugs).reclaim(),
198            annotations: take(&mut self.annotations).reclaim(),
199            lookup_type: take(&mut self.lookup_type).reclaim(),
200            lookup_function: take(&mut self.lookup_function).reclaim(),
201            lookup_function_type: take(&mut self.lookup_function_type).reclaim(),
202            wrapped_functions: take(&mut self.wrapped_functions).reclaim(),
203            constant_ids: take(&mut self.constant_ids).reclaim(),
204            cached_constants: take(&mut self.cached_constants).reclaim(),
205            global_variables: take(&mut self.global_variables).reclaim(),
206            std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).reclaim(),
207            saved_cached: take(&mut self.saved_cached).reclaim(),
208            temp_list: take(&mut self.temp_list).reclaim(),
209            ray_query_functions: take(&mut self.ray_query_functions).reclaim(),
210            ray_tracing_functions: take(&mut self.ray_tracing_functions).reclaim(),
211            has_ray_tracing_pipeline: false,
212            io_f16_polyfills: take(&mut self.io_f16_polyfills).reclaim(),
213            debug_printf: None,
214        };
215
216        *self = fresh;
217
218        self.capabilities_used.insert(spirv::Capability::Shader);
219    }
220
221    /// Indicate that the code requires any one of the listed capabilities.
222    ///
223    /// If nothing in `capabilities` appears in the available capabilities
224    /// specified in the [`Options`] from which this `Writer` was created,
225    /// return an error. The `what` string is used in the error message to
226    /// explain what provoked the requirement. (If no available capabilities were
227    /// given, assume everything is available.)
228    ///
229    /// The first acceptable capability will be added to this `Writer`'s
230    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
231    /// result. For this reason, more specific capabilities should be listed
232    /// before more general.
233    ///
234    /// [`capabilities_used`]: Writer::capabilities_used
235    pub(super) fn require_any(
236        &mut self,
237        what: &'static str,
238        capabilities: &[spirv::Capability],
239    ) -> Result<(), Error> {
240        match *capabilities {
241            [] => Ok(()),
242            [first, ..] => {
243                // Find the first acceptable capability, or return an error if
244                // there is none.
245                let selected = match self.capabilities_available {
246                    None => first,
247                    Some(ref available) => {
248                        match capabilities
249                            .iter()
250                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
251                            .find(|cap| available.contains::<spirv::Capability>(cap))
252                        {
253                            Some(&cap) => cap,
254                            None => {
255                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
256                            }
257                        }
258                    }
259                };
260                self.capabilities_used.insert(selected);
261                Ok(())
262            }
263        }
264    }
265
266    /// Indicate that the code requires all of the listed capabilities.
267    ///
268    /// If all entries of `capabilities` appear in the available capabilities
269    /// specified in the [`Options`] from which this `Writer` was created
270    /// (including the case where [`Options::capabilities`] is `None`), add
271    /// them all to this `Writer`'s [`capabilities_used`] table, and return
272    /// `Ok(())`. If at least one of the listed capabilities is not available,
273    /// do not add anything to the `capabilities_used` table, and return the
274    /// first unavailable requested capability, wrapped in `Err()`.
275    ///
276    /// This method is does not return an [`enum@Error`] in case of failure
277    /// because it may be used in cases where the caller can recover (e.g.,
278    /// with a polyfill) if the requested capabilities are not available. In
279    /// this case, it would be unnecessary work to find *all* the unavailable
280    /// requested capabilities, and to allocate a `Vec` for them, just so we
281    /// could return an [`Error::MissingCapabilities`]).
282    ///
283    /// [`capabilities_used`]: Writer::capabilities_used
284    pub(super) fn require_all(
285        &mut self,
286        capabilities: &[spirv::Capability],
287    ) -> Result<(), spirv::Capability> {
288        if let Some(ref available) = self.capabilities_available {
289            for requested in capabilities {
290                if !available.contains(requested) {
291                    return Err(*requested);
292                }
293            }
294        }
295
296        for requested in capabilities {
297            self.capabilities_used.insert(*requested);
298        }
299
300        Ok(())
301    }
302
303    /// Indicate that the code uses the given extension.
304    pub(super) fn use_extension(&mut self, extension: &'static str) {
305        self.extensions_used.insert(extension);
306    }
307
308    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
309        match self.lookup_type.entry(lookup_ty) {
310            Entry::Occupied(e) => *e.get(),
311            Entry::Vacant(e) => {
312                let local = match lookup_ty {
313                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
314                    LookupType::Local(local) => local,
315                };
316
317                let id = self.id_gen.next();
318                e.insert(id);
319                self.write_type_declaration_local(id, local);
320                id
321            }
322        }
323    }
324
325    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
326        self.get_type_id(LookupType::Handle(handle))
327    }
328
329    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
330        match *tr {
331            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
332            TypeResolution::Value(ref inner) => {
333                let inner_local_type = self.localtype_from_inner(inner).unwrap();
334                LookupType::Local(inner_local_type)
335            }
336        }
337    }
338
339    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
340        let lookup_ty = self.get_expression_lookup_type(tr);
341        self.get_type_id(lookup_ty)
342    }
343
344    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
345        self.get_type_id(LookupType::Local(local))
346    }
347
348    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
349        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
350    }
351
352    pub(super) fn get_handle_pointer_type_id(
353        &mut self,
354        base: Handle<crate::Type>,
355        class: spirv::StorageClass,
356    ) -> Word {
357        let base_id = self.get_handle_type_id(base);
358        self.get_pointer_type_id(base_id, class)
359    }
360
361    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
362        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
363        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
364    }
365
366    /// Return a SPIR-V type for a pointer to `resolution`.
367    ///
368    /// The given `resolution` must be one that we can represent
369    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
370    pub(super) fn get_resolution_pointer_id(
371        &mut self,
372        resolution: &TypeResolution,
373        class: spirv::StorageClass,
374    ) -> Word {
375        let resolution_type_id = self.get_expression_type_id(resolution);
376        self.get_pointer_type_id(resolution_type_id, class)
377    }
378
379    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
380        self.get_type_id(LocalType::Numeric(numeric).into())
381    }
382
383    pub(super) fn get_u32_type_id(&mut self) -> Word {
384        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
385    }
386
387    pub(super) fn get_f32_type_id(&mut self) -> Word {
388        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
389    }
390
391    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
392        self.get_numeric_type_id(NumericType::Vector {
393            size: crate::VectorSize::Bi,
394            scalar: crate::Scalar::U32,
395        })
396    }
397
398    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
399        self.get_numeric_type_id(NumericType::Vector {
400            size: crate::VectorSize::Bi,
401            scalar: crate::Scalar::F32,
402        })
403    }
404
405    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
406        self.get_numeric_type_id(NumericType::Vector {
407            size: crate::VectorSize::Tri,
408            scalar: crate::Scalar::U32,
409        })
410    }
411
412    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
413        let f32_id = self.get_f32_type_id();
414        self.get_pointer_type_id(f32_id, class)
415    }
416
417    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
418        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
419            size: crate::VectorSize::Bi,
420            scalar: crate::Scalar::U32,
421        });
422        self.get_pointer_type_id(vec2u_id, class)
423    }
424
425    pub(super) fn get_bool_type_id(&mut self) -> Word {
426        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
427    }
428
429    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
430        self.get_numeric_type_id(NumericType::Vector {
431            size: crate::VectorSize::Bi,
432            scalar: crate::Scalar::BOOL,
433        })
434    }
435
436    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
437        self.get_numeric_type_id(NumericType::Vector {
438            size: crate::VectorSize::Tri,
439            scalar: crate::Scalar::BOOL,
440        })
441    }
442
443    /// Used for "mulhi" to get the upper bits of multiplication.
444    ///
445    /// More specifically, `OpUMulExtended` multiplies 2 numbers and returns the lower and upper bits of the result
446    /// as a user-defined struct type with 2 u32s. This defines that struct.
447    pub(super) fn get_tuple_of_u32s_ty_id(&mut self) -> Word {
448        if let Some(val) = self.tuple_of_u32s_ty_id {
449            val
450        } else {
451            let id = self.id_gen.next();
452            let u32_id = self.get_u32_type_id();
453            let ins = Instruction::type_struct(id, &[u32_id, u32_id]);
454            ins.to_words(&mut self.logical_layout.declarations);
455            self.tuple_of_u32s_ty_id = Some(id);
456            id
457        }
458    }
459
460    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
461        self.annotations
462            .push(Instruction::decorate(id, decoration, operands));
463    }
464
465    /// Return `inner` as a `LocalType`, if that's possible.
466    ///
467    /// If `inner` can be represented as a `LocalType`, return
468    /// `Some(local_type)`.
469    ///
470    /// Otherwise, return `None`. In this case, the type must always be looked
471    /// up using a `LookupType::Handle`.
472    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
473        Some(match *inner {
474            crate::TypeInner::Scalar(_)
475            | crate::TypeInner::Atomic(_)
476            | crate::TypeInner::Vector { .. }
477            | crate::TypeInner::Matrix { .. } => {
478                // We expect `NumericType::from_inner` to handle all
479                // these cases, so unwrap.
480                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
481            }
482            crate::TypeInner::CooperativeMatrix { .. } => {
483                LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
484            }
485            crate::TypeInner::Pointer { base, space } => {
486                let base_type_id = self.get_handle_type_id(base);
487                LocalType::Pointer {
488                    base: base_type_id,
489                    class: map_storage_class(space),
490                }
491            }
492            crate::TypeInner::ValuePointer {
493                size,
494                scalar,
495                space,
496            } => {
497                let base_numeric_type = match size {
498                    Some(size) => NumericType::Vector { size, scalar },
499                    None => NumericType::Scalar(scalar),
500                };
501                LocalType::Pointer {
502                    base: self.get_numeric_type_id(base_numeric_type),
503                    class: map_storage_class(space),
504                }
505            }
506            crate::TypeInner::Image {
507                dim,
508                arrayed,
509                class,
510            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
511            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
512            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
513            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
514            crate::TypeInner::Array { .. }
515            | crate::TypeInner::Struct { .. }
516            | crate::TypeInner::BindingArray { .. } => return None,
517        })
518    }
519
520    /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the
521    /// provided [`Writer::binding_map`].
522    ///
523    /// If the specified resource is not present in the binding map this will
524    /// return an error, unless [`Writer::fake_missing_bindings`] is set.
525    fn resolve_resource_binding(
526        &self,
527        res_binding: &crate::ResourceBinding,
528    ) -> Result<BindingInfo, Error> {
529        match self.binding_map.get(res_binding) {
530            Some(target) => Ok(*target),
531            None if self.fake_missing_bindings => Ok(BindingInfo {
532                descriptor_set: res_binding.group,
533                binding: res_binding.binding,
534                binding_array_size: None,
535            }),
536            None => Err(Error::MissingBinding(*res_binding)),
537        }
538    }
539
540    /// Emits code for any wrapper functions required by the expressions in ir_function.
541    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
542    fn write_wrapped_functions(
543        &mut self,
544        ir_function: &crate::Function,
545        info: &FunctionInfo,
546        ir_module: &crate::Module,
547    ) -> Result<(), Error> {
548        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
549
550        for (expr_handle, expr) in ir_function.expressions.iter() {
551            match *expr {
552                crate::Expression::Binary { op, left, right } => {
553                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
554                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
555                        match (op, expr_ty.scalar().kind) {
556                            // Division and modulo are undefined behaviour when the
557                            // dividend is the minimum representable value and the divisor
558                            // is negative one, or when the divisor is zero. These wrapped
559                            // functions override the divisor to one in these cases,
560                            // matching the WGSL spec.
561                            (
562                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
563                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
564                            ) => {
565                                self.write_wrapped_binary_op(
566                                    op,
567                                    expr_ty,
568                                    &info[left].ty,
569                                    &info[right].ty,
570                                )?;
571                            }
572                            _ => {}
573                        }
574                    }
575                }
576                crate::Expression::Load { pointer } => {
577                    if let crate::TypeInner::Pointer {
578                        base: pointer_type,
579                        space: crate::AddressSpace::Uniform,
580                    } = *info[pointer].ty.inner_with(&ir_module.types)
581                    {
582                        if self.std140_compat_uniform_types.contains_key(&pointer_type) {
583                            // Loading a std140 compat type requires the wrapper function
584                            // to convert to the regular type.
585                            self.write_wrapped_convert_from_std140_compat_type(
586                                ir_module,
587                                pointer_type,
588                            )?;
589                        }
590                    }
591                }
592                crate::Expression::Access { base, .. } => {
593                    if let crate::TypeInner::Pointer {
594                        base: base_type,
595                        space: crate::AddressSpace::Uniform,
596                    } = *info[base].ty.inner_with(&ir_module.types)
597                    {
598                        // Dynamic accesses of a two-row matrix's columns require a
599                        // wrapper function.
600                        if let crate::TypeInner::Matrix {
601                            rows: crate::VectorSize::Bi,
602                            ..
603                        } = ir_module.types[base_type].inner
604                        {
605                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
606                            // If the matrix is *not* directly a member of a struct, then
607                            // we additionally require a wrapper function to convert from
608                            // the std140 compat type to the regular type.
609                            if !is_uniform_matcx2_struct_member_access(
610                                ir_function,
611                                info,
612                                ir_module,
613                                base,
614                            ) {
615                                self.write_wrapped_convert_from_std140_compat_type(
616                                    ir_module, base_type,
617                                )?;
618                            }
619                        }
620                    }
621                }
622                _ => {}
623            }
624        }
625
626        Ok(())
627    }
628
629    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
630    ///
631    /// Define a function that performs an integer division or modulo operation,
632    /// except that using a divisor of zero or causing signed overflow with a
633    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
634    /// undefined behavior.
635    ///
636    /// Store the generated function's id in the [`wrapped_functions`] table.
637    ///
638    /// The operator `op` must be either [`Divide`] or [`Modulo`].
639    ///
640    /// # Panics
641    ///
642    /// The `return_type`, `left_type` or `right_type` arguments must all be
643    /// integer scalars or vectors. If not, this function panics.
644    ///
645    /// [`wrapped_functions`]: Writer::wrapped_functions
646    /// [`Divide`]: crate::BinaryOperator::Divide
647    /// [`Modulo`]: crate::BinaryOperator::Modulo
648    fn write_wrapped_binary_op(
649        &mut self,
650        op: crate::BinaryOperator,
651        return_type: NumericType,
652        left_type: &TypeResolution,
653        right_type: &TypeResolution,
654    ) -> Result<(), Error> {
655        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
656        let left_type_id = self.get_expression_type_id(left_type);
657        let right_type_id = self.get_expression_type_id(right_type);
658
659        // Check if we've already emitted this function.
660        let wrapped = WrappedFunction::BinaryOp {
661            op,
662            left_type_id,
663            right_type_id,
664        };
665        let function_id = match self.wrapped_functions.entry(wrapped) {
666            Entry::Occupied(_) => return Ok(()),
667            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
668        };
669
670        let scalar = return_type.scalar();
671
672        if self.flags.contains(WriterFlags::DEBUG) {
673            let function_name = match op {
674                crate::BinaryOperator::Divide => "naga_div",
675                crate::BinaryOperator::Modulo => "naga_mod",
676                _ => unreachable!(),
677            };
678            self.debugs
679                .push(Instruction::name(function_id, function_name));
680        }
681        let mut function = Function::default();
682
683        let function_type_id = self.get_function_type(LookupFunctionType {
684            parameter_type_ids: vec![left_type_id, right_type_id],
685            return_type_id,
686        });
687        function.signature = Some(Instruction::function(
688            return_type_id,
689            function_id,
690            spirv::FunctionControl::empty(),
691            function_type_id,
692        ));
693
694        let lhs_id = self.id_gen.next();
695        let rhs_id = self.id_gen.next();
696        if self.flags.contains(WriterFlags::DEBUG) {
697            self.debugs.push(Instruction::name(lhs_id, "lhs"));
698            self.debugs.push(Instruction::name(rhs_id, "rhs"));
699        }
700        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
701        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
702        for instruction in [left_par, right_par] {
703            function.parameters.push(FunctionArgument {
704                instruction,
705                handle_id: 0,
706            });
707        }
708
709        let label_id = self.id_gen.next();
710        let mut block = Block::new(label_id);
711
712        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
713        let bool_type_id = self.get_numeric_type_id(bool_type);
714
715        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
716            NumericType::Scalar(_) => const_id,
717            NumericType::Vector { size, .. } => {
718                let constituent_ids = [const_id; crate::VectorSize::MAX];
719                writer.get_constant_composite(
720                    LookupType::Local(LocalType::Numeric(return_type)),
721                    &constituent_ids[..size as usize],
722                )
723            }
724            NumericType::Matrix { .. } => unreachable!(),
725        };
726
727        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
728        let composite_zero_id = maybe_splat_const(self, const_zero_id);
729        let rhs_eq_zero_id = self.id_gen.next();
730        block.body.push(Instruction::binary(
731            spirv::Op::IEqual,
732            bool_type_id,
733            rhs_eq_zero_id,
734            rhs_id,
735            composite_zero_id,
736        ));
737        let divisor_selector_id = match scalar.kind {
738            crate::ScalarKind::Sint => {
739                let (const_min_id, const_neg_one_id) = match scalar.width {
740                    2 => Ok((
741                        self.get_constant_scalar(crate::Literal::I16(i16::MIN)),
742                        self.get_constant_scalar(crate::Literal::I16(-1i16)),
743                    )),
744                    4 => Ok((
745                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
746                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
747                    )),
748                    8 => Ok((
749                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
750                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
751                    )),
752                    _ => Err(Error::Validation("Unexpected scalar width")),
753                }?;
754                let composite_min_id = maybe_splat_const(self, const_min_id);
755                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
756
757                let lhs_eq_int_min_id = self.id_gen.next();
758                block.body.push(Instruction::binary(
759                    spirv::Op::IEqual,
760                    bool_type_id,
761                    lhs_eq_int_min_id,
762                    lhs_id,
763                    composite_min_id,
764                ));
765                let rhs_eq_neg_one_id = self.id_gen.next();
766                block.body.push(Instruction::binary(
767                    spirv::Op::IEqual,
768                    bool_type_id,
769                    rhs_eq_neg_one_id,
770                    rhs_id,
771                    composite_neg_one_id,
772                ));
773                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
774                block.body.push(Instruction::binary(
775                    spirv::Op::LogicalAnd,
776                    bool_type_id,
777                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
778                    lhs_eq_int_min_id,
779                    rhs_eq_neg_one_id,
780                ));
781                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
782                block.body.push(Instruction::binary(
783                    spirv::Op::LogicalOr,
784                    bool_type_id,
785                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
786                    rhs_eq_zero_id,
787                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
788                ));
789                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
790            }
791            crate::ScalarKind::Uint => rhs_eq_zero_id,
792            _ => unreachable!(),
793        };
794
795        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
796        let composite_one_id = maybe_splat_const(self, const_one_id);
797        let divisor_id = self.id_gen.next();
798        block.body.push(Instruction::select(
799            right_type_id,
800            divisor_id,
801            divisor_selector_id,
802            composite_one_id,
803            rhs_id,
804        ));
805        let op = match (op, scalar.kind) {
806            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
807            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
808            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
809            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
810            _ => unreachable!(),
811        };
812        let return_id = self.id_gen.next();
813        block.body.push(Instruction::binary(
814            op,
815            return_type_id,
816            return_id,
817            lhs_id,
818            divisor_id,
819        ));
820
821        function.consume(block, Instruction::return_value(return_id));
822        function.to_words(&mut self.logical_layout.function_definitions);
823        Ok(())
824    }
825
826    /// Writes a wrapper function to convert from a std140 compat type to its
827    /// corresponding regular type.
828    ///
829    /// See [`Self::write_std140_compat_type_declaration`] for more details.
830    fn write_wrapped_convert_from_std140_compat_type(
831        &mut self,
832        ir_module: &crate::Module,
833        r#type: Handle<crate::Type>,
834    ) -> Result<(), Error> {
835        if !self.std140_compat_uniform_types.contains_key(&r#type) {
836            return Ok(());
837        }
838        // Check if we've already emitted this function.
839        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
840        let function_id = match self.wrapped_functions.entry(wrapped) {
841            Entry::Occupied(_) => return Ok(()),
842            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
843        };
844        if self.flags.contains(WriterFlags::DEBUG) {
845            self.debugs.push(Instruction::name(
846                function_id,
847                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
848            ));
849        }
850        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
851        let return_type_id = self.get_handle_type_id(r#type);
852
853        let mut function = Function::default();
854        let function_type_id = self.get_function_type(LookupFunctionType {
855            parameter_type_ids: vec![param_type_id],
856            return_type_id,
857        });
858        function.signature = Some(Instruction::function(
859            return_type_id,
860            function_id,
861            spirv::FunctionControl::empty(),
862            function_type_id,
863        ));
864        let param_id = self.id_gen.next();
865        function.parameters.push(FunctionArgument {
866            instruction: Instruction::function_parameter(param_type_id, param_id),
867            handle_id: 0,
868        });
869
870        let label_id = self.id_gen.next();
871        let mut block = Block::new(label_id);
872
873        let result_id = match ir_module.types[r#type].inner {
874            // Param is struct containing a vector member for each of the
875            // matrix's columns. Extract each column from the struct then
876            // composite into a matrix.
877            crate::TypeInner::Matrix {
878                columns,
879                rows: rows @ crate::VectorSize::Bi,
880                scalar,
881            } => {
882                let column_type_id =
883                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
884
885                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
886                for column in 0..columns as u32 {
887                    let column_id = self.id_gen.next();
888                    block.body.push(Instruction::composite_extract(
889                        column_type_id,
890                        column_id,
891                        param_id,
892                        &[column],
893                    ));
894                    column_ids.push(column_id);
895                }
896                let result_id = self.id_gen.next();
897                block.body.push(Instruction::composite_construct(
898                    return_type_id,
899                    result_id,
900                    &column_ids,
901                ));
902                result_id
903            }
904            // Param is an array where the base type is the std140 compatible
905            // type corresponding to `base`. Iterate through each element and
906            // call its conversion function, then composite into a new array.
907            crate::TypeInner::Array { base, size, .. } => {
908                // Ensure the conversion function for the array's base type is
909                // declared.
910                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;
911
912                let element_type_id = self.get_handle_type_id(base);
913                let std140_info = self.std140_compat_uniform_types.get(&base);
914                let mut element_ids = Vec::new();
915                let size = match size.resolve(ir_module.to_ctx())? {
916                    crate::proc::IndexableLength::Known(size) => size,
917                    crate::proc::IndexableLength::Dynamic => {
918                        return Err(Error::Validation(
919                            "Uniform buffers cannot contain dynamic arrays",
920                        ))
921                    }
922                };
923                for i in 0..size {
924                    let std140_element_id = self.id_gen.next();
925                    let std140_element_type_id =
926                        std140_info.map_or(element_type_id, |info| info.type_id);
927                    block.body.push(Instruction::composite_extract(
928                        std140_element_type_id,
929                        std140_element_id,
930                        param_id,
931                        &[i],
932                    ));
933
934                    // Only call the conversion function if a compatibility mapping actually exists.
935                    let final_element_id = if std140_info.is_some() {
936                        let conversion_fn_id = self.wrapped_functions
937                            [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
938                        let id = self.id_gen.next();
939                        block.body.push(Instruction::function_call(
940                            element_type_id,
941                            id,
942                            conversion_fn_id,
943                            &[std140_element_id],
944                        ));
945                        id
946                    } else {
947                        std140_element_type_id
948                    };
949                    element_ids.push(final_element_id);
950                }
951                let result_id = self.id_gen.next();
952                block.body.push(Instruction::composite_construct(
953                    return_type_id,
954                    result_id,
955                    &element_ids,
956                ));
957                result_id
958            }
959            // Param is a struct where each two-row matrix member has been
960            // decomposed in to separate vector members for each column.
961            // Other members use their std140 compatible type if one exists, or
962            // else their regular type. Iterate through each member, converting
963            // or composing any matrices if required, then finally compose into
964            // the struct.
965            crate::TypeInner::Struct { ref members, .. } => {
966                let mut member_ids = Vec::new();
967                let mut next_index = 0;
968                for member in members {
969                    let member_id = self.id_gen.next();
970                    let member_type_id = self.get_handle_type_id(member.ty);
971                    match ir_module.types[member.ty].inner {
972                        crate::TypeInner::Matrix {
973                            columns,
974                            rows: rows @ crate::VectorSize::Bi,
975                            scalar,
976                        } => {
977                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
978                            let column_type_id = self
979                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
980                            for _ in 0..columns as u32 {
981                                let column_id = self.id_gen.next();
982                                block.body.push(Instruction::composite_extract(
983                                    column_type_id,
984                                    column_id,
985                                    param_id,
986                                    &[next_index],
987                                ));
988                                column_ids.push(column_id);
989                                next_index += 1;
990                            }
991                            block.body.push(Instruction::composite_construct(
992                                member_type_id,
993                                member_id,
994                                &column_ids,
995                            ));
996                        }
997                        _ => {
998                            // Ensure the conversion function for the member's
999                            // type is declared.
1000                            self.write_wrapped_convert_from_std140_compat_type(
1001                                ir_module, member.ty,
1002                            )?;
1003                            match self.std140_compat_uniform_types.get(&member.ty) {
1004                                Some(std140_type_info) => {
1005                                    let std140_member_id = self.id_gen.next();
1006                                    block.body.push(Instruction::composite_extract(
1007                                        std140_type_info.type_id,
1008                                        std140_member_id,
1009                                        param_id,
1010                                        &[next_index],
1011                                    ));
1012                                    let function_id = self.wrapped_functions
1013                                        [&WrappedFunction::ConvertFromStd140CompatType {
1014                                            r#type: member.ty,
1015                                        }];
1016                                    block.body.push(Instruction::function_call(
1017                                        member_type_id,
1018                                        member_id,
1019                                        function_id,
1020                                        &[std140_member_id],
1021                                    ));
1022                                    next_index += 1;
1023                                }
1024                                None => {
1025                                    block.body.push(Instruction::composite_extract(
1026                                        member_type_id,
1027                                        member_id,
1028                                        param_id,
1029                                        &[next_index],
1030                                    ));
1031                                    next_index += 1;
1032                                }
1033                            }
1034                        }
1035                    }
1036                    member_ids.push(member_id);
1037                }
1038                let result_id = self.id_gen.next();
1039                block.body.push(Instruction::composite_construct(
1040                    return_type_id,
1041                    result_id,
1042                    &member_ids,
1043                ));
1044                result_id
1045            }
1046            _ => unreachable!(),
1047        };
1048
1049        function.consume(block, Instruction::return_value(result_id));
1050        function.to_words(&mut self.logical_layout.function_definitions);
1051        Ok(())
1052    }
1053
    /// Writes a wrapper function to get an `OpTypeVector` column from an
    /// `OpTypeMatrix` with a dynamic index.
    ///
    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
    /// been declared in SPIR-V using an alternative type where each column is a
    /// member of a containing struct. SPIR-V is unable to dynamically access
    /// struct members, so instead we load the matrix then call this function to
    /// access a column from the loaded value.
    ///
    /// The generated function has the shape `(matrix, u32 column_index) ->
    /// column`. It is built from an `OpSwitch` over the column index, with one
    /// case block per column, joined by an `OpPhi` in the merge block. An
    /// out-of-range index is handled according to
    /// `self.bounds_check_policies.index`.
    ///
    /// Does nothing if the wrapper for `r#type` has already been emitted.
    ///
    /// # Panics
    ///
    /// Panics if `r#type` is not a matrix with two rows.
    ///
    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
    /// [`Uniform`]: crate::AddressSpace::Uniform
    fn write_wrapped_matcx2_get_column(
        &mut self,
        ir_module: &crate::Module,
        r#type: Handle<crate::Type>,
    ) -> Result<(), Error> {
        // Check if we've already emitted this function.
        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
        let function_id = match self.wrapped_functions.entry(wrapped) {
            Entry::Occupied(_) => return Ok(()),
            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
        };
        if self.flags.contains(WriterFlags::DEBUG) {
            self.debugs.push(Instruction::name(
                function_id,
                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
            ));
        }

        // Only two-row (`matCx2`) matrices are supported; any other type
        // is a caller bug.
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = ir_module.types[r#type].inner
        else {
            unreachable!();
        };

        let mut function = Function::default();
        let matrix_type_id = self.get_handle_type_id(r#type);
        let column_index_type_id = self.get_u32_type_id();
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let matrix_param_id = self.id_gen.next();
        let column_index_param_id = self.id_gen.next();
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
            handle_id: 0,
        });
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(
                column_index_type_id,
                column_index_param_id,
            ),
            handle_id: 0,
        });
        let function_type_id = self.get_function_type(LookupFunctionType {
            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
            return_type_id: column_type_id,
        });
        function.signature = Some(Instruction::function(
            column_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type_id,
        ));

        // Entry block of the function; it ends with the `OpSwitch`. Its label
        // is also the `OpPhi` parent for the `ReadZeroSkipWrite` default case
        // below.
        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        // Create a switch case for each column in the matrix, where each case
        // extracts its column from the matrix. Finally we use OpPhi to return
        // the correct column.
        let merge_id = self.id_gen.next();
        block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));
        // One case per column: case value `i` branches to its own block.
        // A matrix has at most 4 columns, hence the `ArrayVec` capacity.
        let cases = (0..columns as u32)
            .map(|i| super::instructions::Case {
                value: i,
                label_id: self.id_gen.next(),
            })
            .collect::<ArrayVec<_, 4>>();

        // Which label we branch to in the default (column index out-of-bounds)
        // case depends on our bounds check policy.
        let default_id = match self.bounds_check_policies.index {
            // For `Restrict`, treat the same as the final column.
            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
            // will be handled in the `OpPhi` below to produce a zero value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
            // For `Unchecked` we create a new block containing an
            // `OpUnreachable`.
            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
        };
        function.consume(
            block,
            Instruction::switch(column_index_param_id, default_id, &cases),
        );

        // Emit a block for each case, and produce a list of variable and parent
        // block IDs that will be used in an `OpPhi` below to select the right
        // value.
        let mut var_parent_pairs = cases
            .into_iter()
            .map(|case| {
                let mut block = Block::new(case.label_id);
                let column_id = self.id_gen.next();
                block.body.push(Instruction::composite_extract(
                    column_type_id,
                    column_id,
                    matrix_param_id,
                    &[case.value],
                ));
                function.consume(block, Instruction::branch(merge_id));
                (column_id, case.label_id)
            })
            // Need capacity for up to 4 columns plus possibly a default case.
            .collect::<ArrayVec<_, 5>>();

        // Emit a block or append the variable and parent `OpPhi` pair for the
        // column index out-of-bounds case, if required.
        match self.bounds_check_policies.index {
            // Don't need to do anything for `Restrict` as we have branched from
            // the final column case's block.
            crate::proc::BoundsCheckPolicy::Restrict => {}
            // For `ReadZeroSkipWrite` we have branched directly from the block
            // containing the `OpSwitch`. The `OpPhi` should produce a zero
            // value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
            }
            // For `Unchecked` create a new block containing `OpUnreachable`.
            // This does not need to be handled by the `OpPhi`.
            crate::proc::BoundsCheckPolicy::Unchecked => {
                function.consume(
                    Block::new(default_id),
                    Instruction::new(spirv::Op::Unreachable),
                );
            }
        }

        // Merge block: select the column produced by whichever predecessor
        // block branched here, then return it.
        let mut block = Block::new(merge_id);
        let result_id = self.id_gen.next();
        block.body.push(Instruction::phi(
            column_type_id,
            result_id,
            &var_parent_pairs,
        ));

        function.consume(block, Instruction::return_value(result_id));
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(())
    }
1209
1210    fn write_function(
1211        &mut self,
1212        ir_function: &crate::Function,
1213        info: &FunctionInfo,
1214        ir_module: &crate::Module,
1215        mut interface: Option<FunctionInterface>,
1216        debug_info: &Option<DebugInfoInner>,
1217    ) -> Result<Word, Error> {
1218        self.write_wrapped_functions(ir_function, info, ir_module)?;
1219
1220        log::trace!("Generating code for {:?}", ir_function.name);
1221        let mut function = Function::default();
1222
1223        let prelude_id = self.id_gen.next();
1224        let mut prelude = Block::new(prelude_id);
1225        let mut ep_context = EntryPointContext {
1226            argument_ids: Vec::new(),
1227            results: Vec::new(),
1228            task_payload_variable_id: if let Some(ref i) = interface {
1229                i.task_payload.map(|a| self.global_variables[a].var_id)
1230            } else {
1231                None
1232            },
1233            mesh_state: None,
1234        };
1235
1236        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
1237
1238        let mut local_invocation_index_var_id = None;
1239        let mut local_invocation_index_id = None;
1240
1241        for argument in ir_function.arguments.iter() {
1242            let class = spirv::StorageClass::Input;
1243            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
1244            let argument_type_id = if handle_ty {
1245                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
1246            } else {
1247                self.get_handle_type_id(argument.ty)
1248            };
1249
1250            if let Some(ref mut iface) = interface {
1251                let id = if let Some(ref binding) = argument.binding {
1252                    let name = argument.name.as_deref();
1253
1254                    let varying_id = self.write_varying(
1255                        ir_module,
1256                        iface.stage,
1257                        class,
1258                        name,
1259                        argument.ty,
1260                        binding,
1261                    )?;
1262                    iface.varying_ids.push(varying_id);
1263                    let id = self.load_io_with_f16_polyfill(
1264                        &mut prelude.body,
1265                        varying_id,
1266                        argument_type_id,
1267                    );
1268                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
1269                        local_invocation_index_id = Some(id);
1270                        local_invocation_index_var_id = Some(varying_id);
1271                    }
1272
1273                    id
1274                } else if let crate::TypeInner::Struct { ref members, .. } =
1275                    ir_module.types[argument.ty].inner
1276                {
1277                    let struct_id = self.id_gen.next();
1278                    let mut constituent_ids = Vec::with_capacity(members.len());
1279                    for member in members {
1280                        let type_id = self.get_handle_type_id(member.ty);
1281                        let name = member.name.as_deref();
1282                        let binding = member.binding.as_ref().unwrap();
1283                        let varying_id = self.write_varying(
1284                            ir_module,
1285                            iface.stage,
1286                            class,
1287                            name,
1288                            member.ty,
1289                            binding,
1290                        )?;
1291                        iface.varying_ids.push(varying_id);
1292                        let id =
1293                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
1294                        constituent_ids.push(id);
1295                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1296                        {
1297                            local_invocation_index_id = Some(id);
1298                            local_invocation_index_var_id = Some(varying_id);
1299                        }
1300                    }
1301                    prelude.body.push(Instruction::composite_construct(
1302                        argument_type_id,
1303                        struct_id,
1304                        &constituent_ids,
1305                    ));
1306                    struct_id
1307                } else {
1308                    unreachable!("Missing argument binding on an entry point");
1309                };
1310                ep_context.argument_ids.push(id);
1311            } else {
1312                let argument_id = self.id_gen.next();
1313                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
1314                if self.flags.contains(WriterFlags::DEBUG) {
1315                    if let Some(ref name) = argument.name {
1316                        self.debugs.push(Instruction::name(argument_id, name));
1317                    }
1318                }
1319                function.parameters.push(FunctionArgument {
1320                    instruction,
1321                    handle_id: if handle_ty {
1322                        let id = self.id_gen.next();
1323                        prelude.body.push(Instruction::load(
1324                            self.get_handle_type_id(argument.ty),
1325                            id,
1326                            argument_id,
1327                            None,
1328                        ));
1329                        id
1330                    } else {
1331                        0
1332                    },
1333                });
1334                parameter_type_ids.push(argument_type_id);
1335            };
1336        }
1337
1338        let return_type_id = match ir_function.result {
1339            Some(ref result) => {
1340                if let Some(ref mut iface) = interface {
1341                    let mut has_point_size = false;
1342                    let class = spirv::StorageClass::Output;
1343                    if let Some(ref binding) = result.binding {
1344                        has_point_size |=
1345                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1346                        let type_id = self.get_handle_type_id(result.ty);
1347                        let varying_id =
1348                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
1349                                0
1350                            } else {
1351                                let varying_id = self.write_varying(
1352                                    ir_module,
1353                                    iface.stage,
1354                                    class,
1355                                    None,
1356                                    result.ty,
1357                                    binding,
1358                                )?;
1359                                iface.varying_ids.push(varying_id);
1360                                varying_id
1361                            };
1362                        ep_context.results.push(ResultMember {
1363                            id: varying_id,
1364                            type_id,
1365                            built_in: binding.to_built_in(),
1366                        });
1367                    } else if let crate::TypeInner::Struct { ref members, .. } =
1368                        ir_module.types[result.ty].inner
1369                    {
1370                        for member in members {
1371                            let type_id = self.get_handle_type_id(member.ty);
1372                            let name = member.name.as_deref();
1373                            let binding = member.binding.as_ref().unwrap();
1374                            has_point_size |=
1375                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1376                            // This isn't an actual builtin in SPIR-V. It can only appear as the
1377                            // output of a task shader and the output is used when writing the
1378                            // entry point return, in which case the id is ignored anyway.
1379                            let varying_id = if *binding
1380                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
1381                            {
1382                                0
1383                            } else {
1384                                let varying_id = self.write_varying(
1385                                    ir_module,
1386                                    iface.stage,
1387                                    class,
1388                                    name,
1389                                    member.ty,
1390                                    binding,
1391                                )?;
1392                                iface.varying_ids.push(varying_id);
1393                                varying_id
1394                            };
1395                            ep_context.results.push(ResultMember {
1396                                id: varying_id,
1397                                type_id,
1398                                built_in: binding.to_built_in(),
1399                            });
1400                        }
1401                    } else {
1402                        unreachable!("Missing result binding on an entry point");
1403                    }
1404
1405                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
1406                        && iface.stage == crate::ShaderStage::Vertex
1407                        && !has_point_size
1408                    {
1409                        // add point size artificially
1410                        let varying_id = self.id_gen.next();
1411                        let pointer_type_id = self.get_f32_pointer_type_id(class);
1412                        Instruction::variable(pointer_type_id, varying_id, class, None)
1413                            .to_words(&mut self.logical_layout.declarations);
1414                        self.decorate(
1415                            varying_id,
1416                            spirv::Decoration::BuiltIn,
1417                            &[spirv::BuiltIn::PointSize as u32],
1418                        );
1419                        iface.varying_ids.push(varying_id);
1420
1421                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
1422                        prelude
1423                            .body
1424                            .push(Instruction::store(varying_id, default_value_id, None));
1425                    }
1426                    if iface.stage == crate::ShaderStage::Task {
1427                        self.get_vec3u_type_id()
1428                    } else {
1429                        self.void_type
1430                    }
1431                } else {
1432                    self.get_handle_type_id(result.ty)
1433                }
1434            }
1435            None => self.void_type,
1436        };
1437
1438        if let Some(ref mut iface) = interface {
1439            if let Some(task_payload) = iface.task_payload {
1440                iface
1441                    .varying_ids
1442                    .push(self.global_variables[task_payload].var_id);
1443            }
1444            self.write_entry_point_mesh_shader_info(
1445                iface,
1446                local_invocation_index_var_id,
1447                ir_module,
1448                &mut ep_context,
1449            )?;
1450        }
1451
1452        let lookup_function_type = LookupFunctionType {
1453            parameter_type_ids,
1454            return_type_id,
1455        };
1456
1457        let function_id = self.id_gen.next();
1458        if self.flags.contains(WriterFlags::DEBUG) {
1459            if let Some(ref name) = ir_function.name {
1460                self.debugs.push(Instruction::name(function_id, name));
1461            }
1462        }
1463
1464        let function_type = self.get_function_type(lookup_function_type);
1465        function.signature = Some(Instruction::function(
1466            return_type_id,
1467            function_id,
1468            spirv::FunctionControl::empty(),
1469            function_type,
1470        ));
1471
1472        if interface.is_some() {
1473            function.entry_point_context = Some(ep_context);
1474        }
1475
1476        // fill up the `GlobalVariable::access_id`
1477        for gv in self.global_variables.iter_mut() {
1478            gv.reset_for_function();
1479        }
1480        for (handle, var) in ir_module.global_variables.iter() {
1481            if info[handle].is_empty() {
1482                continue;
1483            }
1484
1485            let mut gv = self.global_variables[handle].clone();
1486            if let Some(ref mut iface) = interface {
1487                // Have to include global variables in the interface
1488                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
1489                    iface.varying_ids.push(gv.var_id);
1490                }
1491            }
1492
1493            match ir_module.types[var.ty].inner {
1494                // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
1495                crate::TypeInner::BindingArray { .. } => {
1496                    gv.access_id = gv.var_id;
1497                }
1498                _ => {
1499                    // Handle globals are pre-emitted and should be loaded automatically.
1500                    if var.space == crate::AddressSpace::Handle {
1501                        let var_type_id = self.get_handle_type_id(var.ty);
1502                        let id = self.id_gen.next();
1503                        prelude
1504                            .body
1505                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
1506                        gv.access_id = gv.var_id;
1507                        gv.handle_id = id;
1508                    } else if global_needs_wrapper(ir_module, var) {
1509                        let class = map_storage_class(var.space);
1510                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
1511                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
1512                                self.get_pointer_type_id(std140_type_info.type_id, class)
1513                            }
1514                            _ => self.get_handle_pointer_type_id(var.ty, class),
1515                        };
1516                        let index_id = self.get_index_constant(0);
1517                        let id = self.id_gen.next();
1518                        prelude.body.push(Instruction::access_chain(
1519                            pointer_type_id,
1520                            id,
1521                            gv.var_id,
1522                            &[index_id],
1523                        ));
1524                        gv.access_id = id;
1525                    } else {
1526                        // by default, the variable ID is accessed as is
1527                        gv.access_id = gv.var_id;
1528                    };
1529                }
1530            }
1531
1532            // work around borrow checking in the presence of `self.xxx()` calls
1533            self.global_variables[handle] = gv;
1534        }
1535
1536        // Create a `BlockContext` for generating SPIR-V for the function's
1537        // body.
1538        let mut context = BlockContext {
1539            ir_module,
1540            ir_function,
1541            fun_info: info,
1542            function: &mut function,
1543            // Re-use the cached expression table from prior functions.
1544            cached: core::mem::take(&mut self.saved_cached),
1545
1546            // Steal the Writer's temp list for a bit.
1547            temp_list: core::mem::take(&mut self.temp_list),
1548            force_loop_bounding: self.force_loop_bounding,
1549            writer: self,
1550            expression_constness: super::ExpressionConstnessTracker::from_arena(
1551                &ir_function.expressions,
1552            ),
1553            ray_query_tracker_expr: crate::FastHashMap::default(),
1554        };
1555
1556        // fill up the pre-emitted and const expressions
1557        context.cached.reset(ir_function.expressions.len());
1558        for (handle, expr) in ir_function.expressions.iter() {
1559            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
1560                || context.expression_constness.is_const(handle)
1561            {
1562                context.cache_expression_value(handle, &mut prelude)?;
1563            }
1564        }
1565
1566        for (handle, variable) in ir_function.local_variables.iter() {
1567            let id = context.gen_id();
1568
1569            if context.writer.flags.contains(WriterFlags::DEBUG) {
1570                if let Some(ref name) = variable.name {
1571                    context.writer.debugs.push(Instruction::name(id, name));
1572                }
1573            }
1574
1575            let init_word = variable.init.map(|constant| context.cached[constant]);
1576            let pointer_type_id = context
1577                .writer
1578                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1579            let instruction = Instruction::variable(
1580                pointer_type_id,
1581                id,
1582                spirv::StorageClass::Function,
1583                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1584                    crate::TypeInner::RayQuery { .. } => None,
1585                    _ => {
1586                        let type_id = context.get_handle_type_id(variable.ty);
1587                        Some(context.writer.write_constant_null(type_id))
1588                    }
1589                }),
1590            );
1591
1592            context
1593                .function
1594                .variables
1595                .insert(handle, LocalVariable { id, instruction });
1596
1597            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
1598                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
1599                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
1600                // confusing bugs.
1601                let u32_type_id = context.writer.get_u32_type_id();
1602                let ptr_u32_type_id = context
1603                    .writer
1604                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
1605                let tracker_id = context.gen_id();
1606                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
1607                    crate::back::RayQueryPoint::empty().bits(),
1608                ));
1609                let tracker_instruction = Instruction::variable(
1610                    ptr_u32_type_id,
1611                    tracker_id,
1612                    spirv::StorageClass::Function,
1613                    Some(tracker_init_id),
1614                );
1615
1616                context
1617                    .function
1618                    .ray_query_initialization_tracker_variables
1619                    .insert(
1620                        handle,
1621                        LocalVariable {
1622                            id: tracker_id,
1623                            instruction: tracker_instruction,
1624                        },
1625                    );
1626                let f32_type_id = context.writer.get_f32_type_id();
1627                let ptr_f32_type_id = context
1628                    .writer
1629                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1630                let t_max_tracker_id = context.gen_id();
1631                let t_max_tracker_init_id =
1632                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
1633                let t_max_tracker_instruction = Instruction::variable(
1634                    ptr_f32_type_id,
1635                    t_max_tracker_id,
1636                    spirv::StorageClass::Function,
1637                    Some(t_max_tracker_init_id),
1638                );
1639
1640                context.function.ray_query_t_max_tracker_variables.insert(
1641                    handle,
1642                    LocalVariable {
1643                        id: t_max_tracker_id,
1644                        instruction: t_max_tracker_instruction,
1645                    },
1646                );
1647            }
1648        }
1649
1650        for (handle, expr) in ir_function.expressions.iter() {
1651            match *expr {
1652                crate::Expression::LocalVariable(_) => {
1653                    // Cache the `OpVariable` instruction we generated above as
1654                    // the value of this expression.
1655                    context.cache_expression_value(handle, &mut prelude)?;
1656                }
1657                crate::Expression::Access { base, .. }
1658                | crate::Expression::AccessIndex { base, .. } => {
1659                    // Count references to `base` by `Access` and `AccessIndex`
1660                    // instructions. See `access_uses` for details.
1661                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1662                }
1663                _ => {}
1664            }
1665        }
1666
1667        let next_id = context.gen_id();
1668
1669        context
1670            .function
1671            .consume(prelude, Instruction::branch(next_id));
1672
1673        let workgroup_vars_init_exit_block_id =
1674            match (context.writer.zero_initialize_workgroup_memory, interface) {
1675                (
1676                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1677                    Some(
1678                        ref mut interface @ FunctionInterface {
1679                            stage:
1680                                crate::ShaderStage::Compute
1681                                | crate::ShaderStage::Mesh
1682                                | crate::ShaderStage::Task,
1683                            ..
1684                        },
1685                    ),
1686                ) => context.writer.generate_workgroup_vars_init_block(
1687                    next_id,
1688                    ir_module,
1689                    info,
1690                    local_invocation_index_id,
1691                    interface,
1692                    context.function,
1693                ),
1694                _ => None,
1695            };
1696
1697        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1698            exit_id
1699        } else {
1700            next_id
1701        };
1702
1703        context.write_function_body(main_id, debug_info.as_ref())?;
1704
1705        // Consume the `BlockContext`, ending its borrows and letting the
1706        // `Writer` steal back its cached expression table and temp_list.
1707        let BlockContext {
1708            cached, temp_list, ..
1709        } = context;
1710        self.saved_cached = cached;
1711        self.temp_list = temp_list;
1712
1713        function.to_words(&mut self.logical_layout.function_definitions);
1714
1715        if let Some(EntryPointContext {
1716            mesh_state: Some(ref mesh_state),
1717            ..
1718        }) = function.entry_point_context
1719        {
1720            self.write_mesh_shader_wrapper(mesh_state, function_id)
1721        } else if let Some(EntryPointContext {
1722            task_payload_variable_id: Some(tp),
1723            ..
1724        }) = function.entry_point_context
1725        {
1726            self.write_task_shader_wrapper(tp, function_id)
1727        } else {
1728            Ok(function_id)
1729        }
1730    }
1731
1732    fn write_execution_mode(
1733        &mut self,
1734        function_id: Word,
1735        mode: spirv::ExecutionMode,
1736    ) -> Result<(), Error> {
1737        //self.check(mode.required_capabilities())?;
1738        Instruction::execution_mode(function_id, mode, &[])
1739            .to_words(&mut self.logical_layout.execution_modes);
1740        Ok(())
1741    }
1742
1743    // TODO Move to instructions module
1744    fn write_entry_point(
1745        &mut self,
1746        entry_point: &crate::EntryPoint,
1747        info: &FunctionInfo,
1748        ir_module: &crate::Module,
1749        debug_info: &Option<DebugInfoInner>,
1750    ) -> Result<Instruction, Error> {
1751        let mut interface_ids = Vec::new();
1752
1753        let function_id = self.write_function(
1754            &entry_point.function,
1755            info,
1756            ir_module,
1757            Some(FunctionInterface {
1758                varying_ids: &mut interface_ids,
1759                stage: entry_point.stage,
1760                task_payload: entry_point.task_payload,
1761                mesh_info: entry_point.mesh_info.clone(),
1762                workgroup_size: entry_point.workgroup_size,
1763            }),
1764            debug_info,
1765        )?;
1766
1767        let exec_model = match entry_point.stage {
1768            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1769            crate::ShaderStage::Fragment => {
1770                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1771                match entry_point.early_depth_test {
1772                    Some(crate::EarlyDepthTest::Force) => {
1773                        self.write_execution_mode(
1774                            function_id,
1775                            spirv::ExecutionMode::EarlyFragmentTests,
1776                        )?;
1777                    }
1778                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1779                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1780                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1781                        // This permits early depth tests even if the shader writes to a storage
1782                        // binding
1783                        match conservative {
1784                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1785                                function_id,
1786                                spirv::ExecutionMode::DepthGreater,
1787                            )?,
1788                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1789                                function_id,
1790                                spirv::ExecutionMode::DepthLess,
1791                            )?,
1792                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1793                                function_id,
1794                                spirv::ExecutionMode::DepthUnchanged,
1795                            )?,
1796                        }
1797                    }
1798                    None => {}
1799                }
1800                if let Some(ref result) = entry_point.function.result {
1801                    if contains_builtin(
1802                        result.binding.as_ref(),
1803                        result.ty,
1804                        &ir_module.types,
1805                        crate::BuiltIn::FragDepth,
1806                    ) {
1807                        self.write_execution_mode(
1808                            function_id,
1809                            spirv::ExecutionMode::DepthReplacing,
1810                        )?;
1811                    }
1812                }
1813                spirv::ExecutionModel::Fragment
1814            }
1815            crate::ShaderStage::Compute => {
1816                let execution_mode = spirv::ExecutionMode::LocalSize;
1817                Instruction::execution_mode(
1818                    function_id,
1819                    execution_mode,
1820                    &entry_point.workgroup_size,
1821                )
1822                .to_words(&mut self.logical_layout.execution_modes);
1823                spirv::ExecutionModel::GLCompute
1824            }
1825            crate::ShaderStage::Task => {
1826                let execution_mode = spirv::ExecutionMode::LocalSize;
1827                Instruction::execution_mode(
1828                    function_id,
1829                    execution_mode,
1830                    &entry_point.workgroup_size,
1831                )
1832                .to_words(&mut self.logical_layout.execution_modes);
1833                spirv::ExecutionModel::TaskEXT
1834            }
1835            crate::ShaderStage::Mesh => {
1836                let execution_mode = spirv::ExecutionMode::LocalSize;
1837                Instruction::execution_mode(
1838                    function_id,
1839                    execution_mode,
1840                    &entry_point.workgroup_size,
1841                )
1842                .to_words(&mut self.logical_layout.execution_modes);
1843                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
1844                Instruction::execution_mode(
1845                    function_id,
1846                    match mesh_info.topology {
1847                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
1848                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
1849                        crate::MeshOutputTopology::Triangles => {
1850                            spirv::ExecutionMode::OutputTrianglesEXT
1851                        }
1852                    },
1853                    &[],
1854                )
1855                .to_words(&mut self.logical_layout.execution_modes);
1856                Instruction::execution_mode(
1857                    function_id,
1858                    spirv::ExecutionMode::OutputVertices,
1859                    core::slice::from_ref(&mesh_info.max_vertices),
1860                )
1861                .to_words(&mut self.logical_layout.execution_modes);
1862                Instruction::execution_mode(
1863                    function_id,
1864                    spirv::ExecutionMode::OutputPrimitivesEXT,
1865                    core::slice::from_ref(&mesh_info.max_primitives),
1866                )
1867                .to_words(&mut self.logical_layout.execution_modes);
1868                spirv::ExecutionModel::MeshEXT
1869            }
1870            crate::ShaderStage::RayGeneration => {
1871                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1872                spirv::ExecutionModel::RayGenerationKHR
1873            }
1874            crate::ShaderStage::AnyHit => {
1875                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1876                spirv::ExecutionModel::AnyHitKHR
1877            }
1878            crate::ShaderStage::ClosestHit => {
1879                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1880                spirv::ExecutionModel::ClosestHitKHR
1881            }
1882            crate::ShaderStage::Miss => {
1883                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1884                spirv::ExecutionModel::MissKHR
1885            }
1886        };
1887        //self.check(exec_model.required_capabilities())?;
1888
1889        Ok(Instruction::entry_point(
1890            exec_model,
1891            function_id,
1892            &entry_point.name,
1893            interface_ids.as_slice(),
1894        ))
1895    }
1896
1897    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1898        use crate::ScalarKind as Sk;
1899
1900        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1901        match scalar.kind {
1902            Sk::Sint | Sk::Uint => {
1903                let signedness = if scalar.kind == Sk::Sint {
1904                    super::instructions::Signedness::Signed
1905                } else {
1906                    super::instructions::Signedness::Unsigned
1907                };
1908                let cap = match bits {
1909                    8 => Some(spirv::Capability::Int8),
1910                    16 => Some(spirv::Capability::Int16),
1911                    64 => Some(spirv::Capability::Int64),
1912                    _ => None,
1913                };
1914                if let Some(cap) = cap {
1915                    self.capabilities_used.insert(cap);
1916                }
1917                if bits == 16 {
1918                    self.capabilities_used
1919                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1920                    self.capabilities_used
1921                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1922                    if self.use_storage_input_output_16 {
1923                        self.capabilities_used
1924                            .insert(spirv::Capability::StorageInputOutput16);
1925                    }
1926                }
1927                Instruction::type_int(id, bits, signedness)
1928            }
1929            Sk::Float => {
1930                if bits == 64 {
1931                    self.capabilities_used.insert(spirv::Capability::Float64);
1932                }
1933                if bits == 16 {
1934                    self.capabilities_used.insert(spirv::Capability::Float16);
1935                    self.capabilities_used
1936                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1937                    self.capabilities_used
1938                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1939                    if self.use_storage_input_output_16 {
1940                        self.capabilities_used
1941                            .insert(spirv::Capability::StorageInputOutput16);
1942                    }
1943                }
1944                Instruction::type_float(id, bits)
1945            }
1946            Sk::Bool => Instruction::type_bool(id),
1947            Sk::AbstractInt | Sk::AbstractFloat => {
1948                unreachable!("abstract types should never reach the backend");
1949            }
1950        }
1951    }
1952
1953    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1954        match *inner {
1955            crate::TypeInner::Image {
1956                dim,
1957                arrayed,
1958                class,
1959            } => {
1960                let sampled = match class {
1961                    crate::ImageClass::Sampled { .. } => true,
1962                    crate::ImageClass::Depth { .. } => true,
1963                    crate::ImageClass::Storage { format, .. } => {
1964                        self.request_image_format_capabilities(format.into())?;
1965                        false
1966                    }
1967                    crate::ImageClass::External => unimplemented!(),
1968                };
1969
1970                match dim {
1971                    crate::ImageDimension::D1 => {
1972                        if sampled {
1973                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1974                        } else {
1975                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1976                        }
1977                    }
1978                    crate::ImageDimension::Cube if arrayed => {
1979                        if sampled {
1980                            self.require_any(
1981                                "sampled cube array images",
1982                                &[spirv::Capability::SampledCubeArray],
1983                            )?;
1984                        } else {
1985                            self.require_any(
1986                                "cube array storage images",
1987                                &[spirv::Capability::ImageCubeArray],
1988                            )?;
1989                        }
1990                    }
1991                    _ => {}
1992                }
1993            }
1994            crate::TypeInner::AccelerationStructure { .. } => {
1995                self.require_any(
1996                    "Acceleration Structure",
1997                    // unless we use this conditional, the ray query snapshot
1998                    // tests pick the wrong capability
1999                    &[if self.has_ray_tracing_pipeline {
2000                        spirv::Capability::RayTracingKHR
2001                    } else {
2002                        spirv::Capability::RayQueryKHR
2003                    }],
2004                )?;
2005            }
2006            crate::TypeInner::RayQuery { .. } => {
2007                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
2008            }
2009            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
2010                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
2011            }
2012            crate::TypeInner::Atomic(crate::Scalar {
2013                width: 4,
2014                kind: crate::ScalarKind::Float,
2015            }) => {
2016                self.require_any(
2017                    "32 bit floating-point atomics",
2018                    &[spirv::Capability::AtomicFloat32AddEXT],
2019                )?;
2020                self.use_extension("SPV_EXT_shader_atomic_float_add");
2021            }
2022            // 16 bit floating-point support requires Float16 capability
2023            crate::TypeInner::Matrix {
2024                scalar: crate::Scalar::F16,
2025                ..
2026            }
2027            | crate::TypeInner::Vector {
2028                scalar: crate::Scalar::F16,
2029                ..
2030            }
2031            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
2032                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
2033                self.use_extension("SPV_KHR_16bit_storage");
2034            }
2035            // 16 bit integer support requires Int16 capability
2036            crate::TypeInner::Vector {
2037                scalar:
2038                    crate::Scalar {
2039                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2040                        width: 2,
2041                    },
2042                ..
2043            }
2044            | crate::TypeInner::Scalar(crate::Scalar {
2045                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2046                width: 2,
2047            }) => {
2048                self.require_any("16 bit integer", &[spirv::Capability::Int16])?;
2049                self.use_extension("SPV_KHR_16bit_storage");
2050            }
2051            // Cooperative types and ops
2052            crate::TypeInner::CooperativeMatrix { .. } => {
2053                self.require_any(
2054                    "cooperative matrix",
2055                    &[spirv::Capability::CooperativeMatrixKHR],
2056                )?;
2057                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
2058                self.use_extension("SPV_KHR_cooperative_matrix");
2059                self.use_extension("SPV_KHR_vulkan_memory_model");
2060            }
2061            _ => {}
2062        }
2063        Ok(())
2064    }
2065
2066    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
2067        let instruction = match numeric {
2068            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
2069            NumericType::Vector { size, scalar } => {
2070                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
2071                Instruction::type_vector(id, scalar_id, size)
2072            }
2073            NumericType::Matrix {
2074                columns,
2075                rows,
2076                scalar,
2077            } => {
2078                let column_id =
2079                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2080                Instruction::type_matrix(id, column_id, columns)
2081            }
2082        };
2083
2084        instruction.to_words(&mut self.logical_layout.declarations);
2085    }
2086
2087    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
2088        let instruction = match coop {
2089            CooperativeType::Matrix {
2090                columns,
2091                rows,
2092                scalar,
2093                role,
2094            } => {
2095                let scalar_id =
2096                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
2097                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
2098                let columns_id = self.get_index_constant(columns as u32);
2099                let rows_id = self.get_index_constant(rows as u32);
2100                let role_id =
2101                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
2102                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
2103            }
2104        };
2105
2106        instruction.to_words(&mut self.logical_layout.declarations);
2107    }
2108
2109    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
2110        let instruction = match local_ty {
2111            LocalType::Numeric(numeric) => {
2112                self.write_numeric_type_declaration_local(id, numeric);
2113                return;
2114            }
2115            LocalType::Cooperative(coop) => {
2116                self.write_cooperative_type_declaration_local(id, coop);
2117                return;
2118            }
2119            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
2120            LocalType::Image(image) => {
2121                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
2122                let type_id = self.get_localtype_id(local_type);
2123                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
2124            }
2125            LocalType::Sampler => Instruction::type_sampler(id),
2126            LocalType::SampledImage { image_type_id } => {
2127                Instruction::type_sampled_image(id, image_type_id)
2128            }
2129            LocalType::BindingArray { base, size } => {
2130                let inner_ty = self.get_handle_type_id(base);
2131                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
2132                Instruction::type_array(id, inner_ty, scalar_id)
2133            }
2134            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
2135            LocalType::RayQuery => Instruction::type_ray_query(id),
2136        };
2137
2138        instruction.to_words(&mut self.logical_layout.declarations);
2139    }
2140
    /// Declares the SPIR-V type for the IR type `handle` and returns its id.
    ///
    /// Capabilities needed by the type are requested first. Types that can be
    /// represented as a [`LocalType`] are deduplicated through `lookup_type`,
    /// so every handle mapping to the same `LocalType` shares one declaration.
    /// Arrays, binding arrays and structs take the other path, which allocates
    /// a fresh id and emits a new declaration whenever it is reached.
    ///
    /// In every case `handle` is then recorded in `lookup_type` as an alias
    /// for the resulting id. With [`WriterFlags::DEBUG`] set, the type's name
    /// (if any) is also attached with `OpName`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `request_type_capabilities`, array length
    /// resolution, and `decorate_struct_member`.
    fn write_type_declaration_arena(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Word, Error> {
        let ty = &module.types[handle];
        // If it's a type that needs SPIR-V capabilities, request them now.
        // This needs to happen regardless of the LocalType lookup succeeding,
        // because some types which map to the same LocalType have different
        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
        self.request_type_capabilities(&ty.inner)?;
        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
            // This type can be represented as a `LocalType`, so check if we've
            // already written an instruction for it. If not, do so now, with
            // `write_type_declaration_local`.
            match self.lookup_type.entry(LookupType::Local(local)) {
                // We already have an id for this `LocalType`.
                Entry::Occupied(e) => *e.get(),

                // It's a type we haven't seen before.
                Entry::Vacant(e) => {
                    // Reserve the id in the cache before emitting, so the
                    // declaration is written under that same id.
                    let id = self.id_gen.next();
                    e.insert(id);

                    self.write_type_declaration_local(id, local);

                    id
                }
            }
        } else {
            use spirv::Decoration;

            // Non-`LocalType` types always get a fresh id on this path.
            let id = self.id_gen.next();
            let instruction = match ty.inner {
                crate::TypeInner::Array { base, size, stride } => {
                    // The stride is taken directly from the IR type.
                    self.decorate(id, Decoration::ArrayStride, &[stride]);

                    // Look up (and presumably declare, if needed) the element type.
                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::BindingArray { base, size } => {
                    // Unlike `Array` above, no `ArrayStride` decoration is emitted.
                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members,
                    span: _,
                } => {
                    let mut has_runtime_array = false;
                    let mut member_ids = Vec::with_capacity(members.len());
                    for (index, member) in members.iter().enumerate() {
                        let member_ty = &module.types[member.ty];
                        match member_ty.inner {
                            crate::TypeInner::Array {
                                base: _,
                                size: crate::ArraySize::Dynamic,
                                stride: _,
                            } => {
                                has_runtime_array = true;
                            }
                            _ => (),
                        }
                        // Member decorations target this struct's id and the
                        // member's IR index.
                        self.decorate_struct_member(id, index, member, &module.types)?;
                        let member_id = self.get_handle_type_id(member.ty);
                        member_ids.push(member_id);
                    }
                    // NOTE(review): only structs with a runtime-sized array
                    // member are decorated `Block` here; presumably other
                    // interface blocks are decorated elsewhere (e.g. wrappers)
                    // — confirm against the global variable emission code.
                    if has_runtime_array {
                        self.decorate(id, Decoration::Block, &[]);
                    }
                    Instruction::type_struct(id, member_ids.as_slice())
                }

                // These all have TypeLocal representations, so they should have been
                // handled by `write_type_declaration_local` above.
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Atomic(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. }
                | crate::TypeInner::CooperativeMatrix { .. }
                | crate::TypeInner::Pointer { .. }
                | crate::TypeInner::ValuePointer { .. }
                | crate::TypeInner::Image { .. }
                | crate::TypeInner::Sampler { .. }
                | crate::TypeInner::AccelerationStructure { .. }
                | crate::TypeInner::RayQuery { .. } => unreachable!(),
            };

            instruction.to_words(&mut self.logical_layout.declarations);
            id
        };

        // Add this handle as a new alias for that type.
        self.lookup_type.insert(LookupType::Handle(handle), id);

        // Each named handle adds its own `OpName`, even when it shares an id
        // with an earlier handle through the `LocalType` path.
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = ty.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        Ok(id)
    }
2259
2260    /// Writes a std140 layout compatible type declaration for a type. Returns
2261    /// the ID of the declared type, or None if no declaration is required.
2262    ///
2263    /// This should be called for any type for which there exists a
2264    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
2265    /// adheres to std140 layout rules it will return without declaring any
2266    /// types. If the type contains another type which requires a std140
2267    /// compatible type declaration, it will recursively call itself.
2268    ///
2269    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
2270    /// declared type will be an `OpTypeStruct` containing an `OpVector` for
2271    /// each of the matrix's columns.
2272    ///
2273    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
2274    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
2275    /// type is the matrix's corresponding std140 compatible type.
2276    ///
2277    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
2278    /// require a std140 compatible type declaration, this will declare a new
2279    /// struct with the following rules:
2280    /// * Struct or array members will be declared with their std140 compatible
2281    ///   type declaration, if one is required.
2282    /// * Two-row matrix members will have each of their columns hoisted
2283    ///   directly into the struct as 2-component vector members.
2284    /// * All other members will be declared with their normal type.
2285    ///
2286    /// Note that this means the Naga IR index of a struct member may not match
2287    /// the index in the generated SPIR-V. The mapping can be obtained via
2288    /// `Std140TypeInfo::member_indices`.
2289    ///
2290    /// [`GlobalVariable`]: crate::GlobalVariable
2291    /// [`Uniform`]: crate::AddressSpace::Uniform
2292    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
2293    /// [`TypeInner::Array`]: crate::TypeInner::Array
2294    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
2295    fn write_std140_compat_type_declaration(
2296        &mut self,
2297        module: &crate::Module,
2298        handle: Handle<crate::Type>,
2299    ) -> Result<Option<Word>, Error> {
2300        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
2301            return Ok(Some(std140_type_info.type_id));
2302        }
2303
2304        let type_inner = &module.types[handle].inner;
2305        let std140_type_id = match *type_inner {
2306            crate::TypeInner::Matrix {
2307                columns,
2308                rows: rows @ crate::VectorSize::Bi,
2309                scalar,
2310            } => {
2311                let std140_type_id = self.id_gen.next();
2312                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
2313                let column_type_id =
2314                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2315                for column in 0..columns as u32 {
2316                    member_type_ids.push(column_type_id);
2317                    self.annotations.push(Instruction::member_decorate(
2318                        std140_type_id,
2319                        column,
2320                        spirv::Decoration::Offset,
2321                        &[column * rows as u32 * scalar.width as u32],
2322                    ));
2323                    if self.flags.contains(WriterFlags::DEBUG) {
2324                        self.debugs.push(Instruction::member_name(
2325                            std140_type_id,
2326                            column,
2327                            &format!("col{column}"),
2328                        ));
2329                    }
2330                }
2331                Instruction::type_struct(std140_type_id, &member_type_ids)
2332                    .to_words(&mut self.logical_layout.declarations);
2333                self.std140_compat_uniform_types.insert(
2334                    handle,
2335                    Std140CompatTypeInfo {
2336                        type_id: std140_type_id,
2337                        member_indices: Vec::new(),
2338                    },
2339                );
2340                Some(std140_type_id)
2341            }
2342            crate::TypeInner::Array { base, size, stride } => {
2343                match self.write_std140_compat_type_declaration(module, base)? {
2344                    Some(std140_base_type_id) => {
2345                        let std140_type_id = self.id_gen.next();
2346                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
2347                        let instruction = match size.resolve(module.to_ctx())? {
2348                            crate::proc::IndexableLength::Known(length) => {
2349                                let length_id = self.get_index_constant(length);
2350                                Instruction::type_array(
2351                                    std140_type_id,
2352                                    std140_base_type_id,
2353                                    length_id,
2354                                )
2355                            }
2356                            crate::proc::IndexableLength::Dynamic => {
2357                                unreachable!()
2358                            }
2359                        };
2360                        instruction.to_words(&mut self.logical_layout.declarations);
2361                        self.std140_compat_uniform_types.insert(
2362                            handle,
2363                            Std140CompatTypeInfo {
2364                                type_id: std140_type_id,
2365                                member_indices: Vec::new(),
2366                            },
2367                        );
2368                        Some(std140_type_id)
2369                    }
2370                    None => None,
2371                }
2372            }
2373            crate::TypeInner::Struct { ref members, .. } => {
2374                let mut needs_std140_type = false;
2375                for member in members {
2376                    match module.types[member.ty].inner {
2377                        // We don't need to write a std140 type for the matrix itself as
2378                        // it will be decomposed into the parent struct. As a result, the
2379                        // struct does need a std140 type, however.
2380                        crate::TypeInner::Matrix {
2381                            rows: crate::VectorSize::Bi,
2382                            ..
2383                        } => needs_std140_type = true,
2384                        // If an array member needs a std140 type, because it is an array
2385                        // (of an array, etc) of `matCx2`s, then the struct also needs
2386                        // a std140 type which uses the std140 type for this member.
2387                        crate::TypeInner::Array { .. }
2388                            if self
2389                                .write_std140_compat_type_declaration(module, member.ty)?
2390                                .is_some() =>
2391                        {
2392                            needs_std140_type = true;
2393                        }
2394                        _ => {}
2395                    }
2396                }
2397
2398                if needs_std140_type {
2399                    let std140_type_id = self.id_gen.next();
2400                    let mut member_ids = Vec::new();
2401                    let mut member_indices = Vec::new();
2402                    let mut next_index = 0;
2403
2404                    for member in members {
2405                        member_indices.push(next_index);
2406                        match module.types[member.ty].inner {
2407                            crate::TypeInner::Matrix {
2408                                columns,
2409                                rows: rows @ crate::VectorSize::Bi,
2410                                scalar,
2411                            } => {
2412                                let vector_type_id =
2413                                    self.get_numeric_type_id(NumericType::Vector {
2414                                        size: rows,
2415                                        scalar,
2416                                    });
2417                                for column in 0..columns as u32 {
2418                                    self.annotations.push(Instruction::member_decorate(
2419                                        std140_type_id,
2420                                        next_index,
2421                                        spirv::Decoration::Offset,
2422                                        &[member.offset
2423                                            + column * rows as u32 * scalar.width as u32],
2424                                    ));
2425                                    if self.flags.contains(WriterFlags::DEBUG) {
2426                                        if let Some(ref name) = member.name {
2427                                            self.debugs.push(Instruction::member_name(
2428                                                std140_type_id,
2429                                                next_index,
2430                                                &format!("{name}_col{column}"),
2431                                            ));
2432                                        }
2433                                    }
2434                                    member_ids.push(vector_type_id);
2435                                    next_index += 1;
2436                                }
2437                            }
2438                            _ => {
2439                                let member_id =
2440                                    match self.std140_compat_uniform_types.get(&member.ty) {
2441                                        Some(std140_member_type_info) => {
2442                                            self.annotations.push(Instruction::member_decorate(
2443                                                std140_type_id,
2444                                                next_index,
2445                                                spirv::Decoration::Offset,
2446                                                &[member.offset],
2447                                            ));
2448                                            if self.flags.contains(WriterFlags::DEBUG) {
2449                                                if let Some(ref name) = member.name {
2450                                                    self.debugs.push(Instruction::member_name(
2451                                                        std140_type_id,
2452                                                        next_index,
2453                                                        name,
2454                                                    ));
2455                                                }
2456                                            }
2457                                            std140_member_type_info.type_id
2458                                        }
2459                                        None => {
2460                                            self.decorate_struct_member(
2461                                                std140_type_id,
2462                                                next_index as usize,
2463                                                member,
2464                                                &module.types,
2465                                            )?;
2466                                            self.get_handle_type_id(member.ty)
2467                                        }
2468                                    };
2469                                member_ids.push(member_id);
2470                                next_index += 1;
2471                            }
2472                        }
2473                    }
2474
2475                    Instruction::type_struct(std140_type_id, &member_ids)
2476                        .to_words(&mut self.logical_layout.declarations);
2477                    self.std140_compat_uniform_types.insert(
2478                        handle,
2479                        Std140CompatTypeInfo {
2480                            type_id: std140_type_id,
2481                            member_indices,
2482                        },
2483                    );
2484                    Some(std140_type_id)
2485                } else {
2486                    None
2487                }
2488            }
2489            _ => None,
2490        };
2491
2492        if let Some(std140_type_id) = std140_type_id {
2493            if self.flags.contains(WriterFlags::DEBUG) {
2494                let name = format!("std140_{:?}", handle.for_debug(&module.types));
2495                self.debugs.push(Instruction::name(std140_type_id, &name));
2496            }
2497        }
2498        Ok(std140_type_id)
2499    }
2500
2501    fn request_image_format_capabilities(
2502        &mut self,
2503        format: spirv::ImageFormat,
2504    ) -> Result<(), Error> {
2505        use spirv::ImageFormat as If;
2506        match format {
2507            If::Rg32f
2508            | If::Rg16f
2509            | If::R11fG11fB10f
2510            | If::R16f
2511            | If::Rgba16
2512            | If::Rgb10A2
2513            | If::Rg16
2514            | If::Rg8
2515            | If::R16
2516            | If::R8
2517            | If::Rgba16Snorm
2518            | If::Rg16Snorm
2519            | If::Rg8Snorm
2520            | If::R16Snorm
2521            | If::R8Snorm
2522            | If::Rg32i
2523            | If::Rg16i
2524            | If::Rg8i
2525            | If::R16i
2526            | If::R8i
2527            | If::Rgb10a2ui
2528            | If::Rg32ui
2529            | If::Rg16ui
2530            | If::Rg8ui
2531            | If::R16ui
2532            | If::R8ui => self.require_any(
2533                "storage image format",
2534                &[spirv::Capability::StorageImageExtendedFormats],
2535            ),
2536            If::R64ui | If::R64i => {
2537                self.use_extension("SPV_EXT_shader_image_int64");
2538                self.require_any(
2539                    "64-bit integer storage image format",
2540                    &[spirv::Capability::Int64ImageEXT],
2541                )
2542            }
2543            If::Unknown
2544            | If::Rgba32f
2545            | If::Rgba16f
2546            | If::R32f
2547            | If::Rgba8
2548            | If::Rgba8Snorm
2549            | If::Rgba32i
2550            | If::Rgba16i
2551            | If::Rgba8i
2552            | If::R32i
2553            | If::Rgba32ui
2554            | If::Rgba16ui
2555            | If::Rgba8ui
2556            | If::R32ui => Ok(()),
2557        }
2558    }
2559
2560    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
2561        self.get_constant_scalar(crate::Literal::U32(index))
2562    }
2563
2564    pub(super) fn get_constant_scalar_with(
2565        &mut self,
2566        value: u8,
2567        scalar: crate::Scalar,
2568    ) -> Result<Word, Error> {
2569        Ok(
2570            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
2571                Error::Validation("Unexpected kind and/or width for Literal"),
2572            )?),
2573        )
2574    }
2575
2576    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
2577        let scalar = CachedConstant::Literal(value.into());
2578        if let Some(&id) = self.cached_constants.get(&scalar) {
2579            return id;
2580        }
2581        let id = self.id_gen.next();
2582        self.write_constant_scalar(id, &value, None);
2583        self.cached_constants.insert(scalar, id);
2584        id
2585    }
2586
2587    fn write_constant_scalar(
2588        &mut self,
2589        id: Word,
2590        value: &crate::Literal,
2591        debug_name: Option<&String>,
2592    ) {
2593        if self.flags.contains(WriterFlags::DEBUG) {
2594            if let Some(name) = debug_name {
2595                self.debugs.push(Instruction::name(id, name));
2596            }
2597        }
2598        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
2599        let instruction = match *value {
2600            crate::Literal::F64(value) => {
2601                let bits = value.to_bits();
2602                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
2603            }
2604            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
2605            crate::Literal::F16(value) => {
2606                let low = value.to_bits();
2607                Instruction::constant_16bit(type_id, id, low as u32)
2608            }
2609            crate::Literal::U16(value) => Instruction::constant_16bit(type_id, id, value as u32),
2610            crate::Literal::I16(value) => {
2611                // Sign-extend into the 32-bit word so that `spirv-as` can
2612                // round-trip the disassembly (it expects signed values for
2613                // signed types).
2614                Instruction::constant_16bit(type_id, id, value as i32 as u32)
2615            }
2616            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
2617            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
2618            crate::Literal::U64(value) => {
2619                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2620            }
2621            crate::Literal::I64(value) => {
2622                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2623            }
2624            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
2625            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
2626            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2627                unreachable!("Abstract types should not appear in IR presented to backends");
2628            }
2629        };
2630
2631        instruction.to_words(&mut self.logical_layout.declarations);
2632    }
2633
2634    pub(super) fn get_constant_composite(
2635        &mut self,
2636        ty: LookupType,
2637        constituent_ids: &[Word],
2638    ) -> Word {
2639        let composite = CachedConstant::Composite {
2640            ty,
2641            constituent_ids: constituent_ids.to_vec(),
2642        };
2643        if let Some(&id) = self.cached_constants.get(&composite) {
2644            return id;
2645        }
2646        let id = self.id_gen.next();
2647        self.write_constant_composite(id, ty, constituent_ids, None);
2648        self.cached_constants.insert(composite, id);
2649        id
2650    }
2651
2652    fn write_constant_composite(
2653        &mut self,
2654        id: Word,
2655        ty: LookupType,
2656        constituent_ids: &[Word],
2657        debug_name: Option<&String>,
2658    ) {
2659        if self.flags.contains(WriterFlags::DEBUG) {
2660            if let Some(name) = debug_name {
2661                self.debugs.push(Instruction::name(id, name));
2662            }
2663        }
2664        let type_id = self.get_type_id(ty);
2665        Instruction::constant_composite(type_id, id, constituent_ids)
2666            .to_words(&mut self.logical_layout.declarations);
2667    }
2668
2669    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
2670        let null = CachedConstant::ZeroValue(type_id);
2671        if let Some(&id) = self.cached_constants.get(&null) {
2672            return id;
2673        }
2674        let id = self.write_constant_null(type_id);
2675        self.cached_constants.insert(null, id);
2676        id
2677    }
2678
2679    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
2680        let null_id = self.id_gen.next();
2681        Instruction::constant_null(type_id, null_id)
2682            .to_words(&mut self.logical_layout.declarations);
2683        null_id
2684    }
2685
2686    fn write_constant_expr(
2687        &mut self,
2688        handle: Handle<crate::Expression>,
2689        ir_module: &crate::Module,
2690        mod_info: &ModuleInfo,
2691    ) -> Result<Word, Error> {
2692        let id = match ir_module.global_expressions[handle] {
2693            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
2694            crate::Expression::Constant(constant) => {
2695                let constant = &ir_module.constants[constant];
2696                self.constant_ids[constant.init]
2697            }
2698            crate::Expression::ZeroValue(ty) => {
2699                let type_id = self.get_handle_type_id(ty);
2700                self.get_constant_null(type_id)
2701            }
2702            crate::Expression::Compose { ty, ref components } => {
2703                let component_ids: Vec<_> = crate::proc::flatten_compose(
2704                    ty,
2705                    components,
2706                    &ir_module.global_expressions,
2707                    &ir_module.types,
2708                )
2709                .map(|component| self.constant_ids[component])
2710                .collect();
2711                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
2712            }
2713            crate::Expression::Splat { size, value } => {
2714                let value_id = self.constant_ids[value];
2715                let component_ids = &[value_id; 4][..size as usize];
2716
2717                let ty = self.get_expression_lookup_type(&mod_info[handle]);
2718
2719                self.get_constant_composite(ty, component_ids)
2720            }
2721            _ => {
2722                return Err(Error::Override);
2723            }
2724        };
2725
2726        self.constant_ids[handle] = id;
2727
2728        Ok(id)
2729    }
2730
2731    pub(super) fn write_control_barrier(
2732        &mut self,
2733        flags: crate::Barrier,
2734        body: &mut Vec<Instruction>,
2735    ) {
2736        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
2737            spirv::Scope::Device
2738        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2739            spirv::Scope::Subgroup
2740        } else {
2741            spirv::Scope::Workgroup
2742        };
2743        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2744        semantics.set(
2745            spirv::MemorySemantics::UNIFORM_MEMORY,
2746            flags.contains(crate::Barrier::STORAGE),
2747        );
2748        semantics.set(
2749            spirv::MemorySemantics::WORKGROUP_MEMORY,
2750            flags.contains(crate::Barrier::WORK_GROUP),
2751        );
2752        semantics.set(
2753            spirv::MemorySemantics::SUBGROUP_MEMORY,
2754            flags.contains(crate::Barrier::SUB_GROUP),
2755        );
2756        semantics.set(
2757            spirv::MemorySemantics::IMAGE_MEMORY,
2758            flags.contains(crate::Barrier::TEXTURE),
2759        );
2760        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
2761            self.get_index_constant(spirv::Scope::Subgroup as u32)
2762        } else {
2763            self.get_index_constant(spirv::Scope::Workgroup as u32)
2764        };
2765        let mem_scope_id = self.get_index_constant(memory_scope as u32);
2766        let semantics_id = self.get_index_constant(semantics.bits());
2767        body.push(Instruction::control_barrier(
2768            exec_scope_id,
2769            mem_scope_id,
2770            semantics_id,
2771        ));
2772    }
2773
2774    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
2775        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2776        semantics.set(
2777            spirv::MemorySemantics::UNIFORM_MEMORY,
2778            flags.contains(crate::Barrier::STORAGE),
2779        );
2780        semantics.set(
2781            spirv::MemorySemantics::WORKGROUP_MEMORY,
2782            flags.contains(crate::Barrier::WORK_GROUP),
2783        );
2784        semantics.set(
2785            spirv::MemorySemantics::SUBGROUP_MEMORY,
2786            flags.contains(crate::Barrier::SUB_GROUP),
2787        );
2788        semantics.set(
2789            spirv::MemorySemantics::IMAGE_MEMORY,
2790            flags.contains(crate::Barrier::TEXTURE),
2791        );
2792        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
2793            self.get_index_constant(spirv::Scope::Device as u32)
2794        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2795            self.get_index_constant(spirv::Scope::Subgroup as u32)
2796        } else {
2797            self.get_index_constant(spirv::Scope::Workgroup as u32)
2798        };
2799        let semantics_id = self.get_index_constant(semantics.bits());
2800        block
2801            .body
2802            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
2803    }
2804
    /// Emit an entry-point prologue that zero-initializes the workgroup-shared
    /// globals this function uses.
    ///
    /// A global is selected when its `info` entry is non-empty and it lives in
    /// either of these address spaces:
    ///
    /// - the `WorkGroup` address space, or
    /// - the `TaskPayload` address space, when `interface.stage` is `Task`.
    ///
    /// If no global qualifies, nothing is emitted and `None` is returned.
    ///
    /// Otherwise three blocks are appended to `function`:
    ///
    /// 1. The block labelled `entry_id` compares the local invocation index
    ///    with `0` and branches conditionally.
    /// 2. The accepting block stores an `OpConstantNull` of the matching type
    ///    into every selected global.
    /// 3. The merge block issues a `WORK_GROUP` control barrier, so every
    ///    invocation sees the zeroed memory before continuing.
    ///
    /// When `local_invocation_index` is `None`, a new `Input` variable
    /// decorated `BuiltIn LocalInvocationIndex` is declared. It is pushed onto
    /// `interface.varying_ids` and loaded in the first block.
    ///
    /// Returns `Some(next_id)`, a freshly allocated label that the barrier
    /// block branches to. NOTE(review): the caller presumably starts its next
    /// block with this label; confirm at the call site.
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_index: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // One `OpStore null` per used workgroup (or task-payload) global.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                // Task shaders also zero their task payload.
                let task_exception = (var.space == crate::AddressSpace::TaskPayload)
                    && interface.stage == crate::ShaderStage::Task;
                !info[handle].is_empty()
                    && (var.space == crate::AddressSpace::WorkGroup || task_exception)
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let mut pre_if_block = Block::new(entry_id);

        // Reuse the caller's local invocation index if it already has one;
        // otherwise declare and load the `LocalInvocationIndex` builtin here.
        let local_invocation_index = if let Some(local_invocation_index) = local_invocation_index {
            local_invocation_index
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let u32_ty_id = self.get_u32_type_id();
            let pointer_type_id = self.get_pointer_type_id(u32_ty_id, class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationIndex as u32],
            );

            // The new input must appear in the entry point's interface list.
            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(u32_ty_id, id, varying_id, None));

            id
        };

        let zero_id = self.get_constant_scalar(crate::Literal::U32(0));

        // Only invocation 0 performs the stores.
        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            self.get_bool_type_id(),
            eq_id,
            local_invocation_index,
            zero_id,
        ));

        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(eq_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        // Make the zeroed memory visible to the whole workgroup before any
        // invocation proceeds into the shader body.
        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);

        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
2903
2904    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
2905    ///
2906    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
2907    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
2908    /// interface is represented by global variables in the `Input` and `Output`
2909    /// storage classes, with decorations indicating which builtin or location
2910    /// each variable corresponds to.
2911    ///
2912    /// This function emits a single global `OpVariable` for a single value from
2913    /// the interface, and adds appropriate decorations to indicate which
2914    /// builtin or location it represents, how it should be interpolated, and so
2915    /// on. The `class` argument gives the variable's SPIR-V storage class,
2916    /// which should be either [`Input`] or [`Output`].
2917    ///
2918    /// [`Binding`]: crate::Binding
2919    /// [`Function`]: crate::Function
2920    /// [`EntryPoint`]: crate::EntryPoint
2921    /// [`Input`]: spirv::StorageClass::Input
2922    /// [`Output`]: spirv::StorageClass::Output
2923    fn write_varying(
2924        &mut self,
2925        ir_module: &crate::Module,
2926        stage: crate::ShaderStage,
2927        class: spirv::StorageClass,
2928        debug_name: Option<&str>,
2929        ty: Handle<crate::Type>,
2930        binding: &crate::Binding,
2931    ) -> Result<Word, Error> {
2932        let id = self.id_gen.next();
2933        let ty_inner = &ir_module.types[ty].inner;
2934        let needs_polyfill = self.needs_f16_polyfill(ty_inner);
2935
2936        let pointer_type_id = if needs_polyfill {
2937            let f32_value_local =
2938                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
2939                    .expect("needs_polyfill returned true but create_polyfill_type returned None");
2940
2941            let f32_type_id = self.get_localtype_id(f32_value_local);
2942            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
2943            self.io_f16_polyfills.register_io_var(id, f32_type_id);
2944
2945            ptr_id
2946        } else {
2947            self.get_handle_pointer_type_id(ty, class)
2948        };
2949
2950        Instruction::variable(pointer_type_id, id, class, None)
2951            .to_words(&mut self.logical_layout.declarations);
2952
2953        if self
2954            .flags
2955            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
2956        {
2957            if let Some(name) = debug_name {
2958                self.debugs.push(Instruction::name(id, name));
2959            }
2960        }
2961
2962        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
2963        self.write_binding(id, binding);
2964
2965        Ok(id)
2966    }
2967
2968    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
2969        match binding {
2970            BindingDecorations::None => (),
2971            BindingDecorations::BuiltIn(bi, others) => {
2972                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
2973                for other in others {
2974                    self.decorate(id, other, &[]);
2975                }
2976            }
2977            BindingDecorations::Location {
2978                location,
2979                others,
2980                blend_src,
2981            } => {
2982                self.decorate(id, spirv::Decoration::Location, &[location]);
2983                for other in others {
2984                    self.decorate(id, other, &[]);
2985                }
2986                if let Some(blend_src) = blend_src {
2987                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
2988                }
2989            }
2990        }
2991    }
2992
2993    pub fn write_binding_struct_member(
2994        &mut self,
2995        struct_id: Word,
2996        member_idx: Word,
2997        binding_info: BindingDecorations,
2998    ) {
2999        match binding_info {
3000            BindingDecorations::None => (),
3001            BindingDecorations::BuiltIn(bi, others) => {
3002                self.annotations.push(Instruction::member_decorate(
3003                    struct_id,
3004                    member_idx,
3005                    spirv::Decoration::BuiltIn,
3006                    &[bi as Word],
3007                ));
3008                for other in others {
3009                    self.annotations.push(Instruction::member_decorate(
3010                        struct_id,
3011                        member_idx,
3012                        other,
3013                        &[],
3014                    ));
3015                }
3016            }
3017            BindingDecorations::Location {
3018                location,
3019                others,
3020                blend_src,
3021            } => {
3022                self.annotations.push(Instruction::member_decorate(
3023                    struct_id,
3024                    member_idx,
3025                    spirv::Decoration::Location,
3026                    &[location],
3027                ));
3028                for other in others {
3029                    self.annotations.push(Instruction::member_decorate(
3030                        struct_id,
3031                        member_idx,
3032                        other,
3033                        &[],
3034                    ));
3035                }
3036                if let Some(blend_src) = blend_src {
3037                    self.annotations.push(Instruction::member_decorate(
3038                        struct_id,
3039                        member_idx,
3040                        spirv::Decoration::Index,
3041                        &[blend_src],
3042                    ));
3043                }
3044            }
3045        }
3046    }
3047
3048    pub fn map_binding(
3049        &mut self,
3050        ir_module: &crate::Module,
3051        stage: crate::ShaderStage,
3052        class: spirv::StorageClass,
3053        ty: Handle<crate::Type>,
3054        binding: &crate::Binding,
3055    ) -> Result<BindingDecorations, Error> {
3056        use spirv::BuiltIn;
3057        use spirv::Decoration;
3058        match *binding {
3059            crate::Binding::Location {
3060                location,
3061                interpolation,
3062                sampling,
3063                blend_src,
3064                per_primitive,
3065            } => {
3066                let mut others = ArrayVec::new();
3067
3068                let no_decorations =
3069                    // VUID-StandaloneSpirv-Flat-06202
3070                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3071                    // > must not be used on variables with the Input storage class in a vertex shader
3072                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
3073                    // VUID-StandaloneSpirv-Flat-06201
3074                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3075                    // > must not be used on variables with the Output storage class in a fragment shader
3076                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
3077
3078                if !no_decorations {
3079                    match interpolation {
3080                        // Perspective-correct interpolation is the default in SPIR-V.
3081                        None | Some(crate::Interpolation::Perspective) => (),
3082                        Some(crate::Interpolation::Flat) => {
3083                            others.push(Decoration::Flat);
3084                        }
3085                        Some(crate::Interpolation::Linear) => {
3086                            others.push(Decoration::NoPerspective);
3087                        }
3088                        Some(crate::Interpolation::PerVertex) => {
3089                            others.push(Decoration::PerVertexKHR);
3090                            self.require_any(
3091                                "`per_vertex` interpolation",
3092                                &[spirv::Capability::FragmentBarycentricKHR],
3093                            )?;
3094                            self.use_extension("SPV_KHR_fragment_shader_barycentric");
3095                        }
3096                    }
3097                    match sampling {
3098                        // Center sampling is the default in SPIR-V.
3099                        None
3100                        | Some(
3101                            crate::Sampling::Center
3102                            | crate::Sampling::First
3103                            | crate::Sampling::Either,
3104                        ) => (),
3105                        Some(crate::Sampling::Centroid) => {
3106                            others.push(Decoration::Centroid);
3107                        }
3108                        Some(crate::Sampling::Sample) => {
3109                            self.require_any(
3110                                "per-sample interpolation",
3111                                &[spirv::Capability::SampleRateShading],
3112                            )?;
3113                            others.push(Decoration::Sample);
3114                        }
3115                    }
3116                }
3117                if per_primitive && stage == crate::ShaderStage::Fragment {
3118                    others.push(Decoration::PerPrimitiveEXT);
3119                }
3120                Ok(BindingDecorations::Location {
3121                    location,
3122                    others,
3123                    blend_src,
3124                })
3125            }
3126            crate::Binding::BuiltIn(built_in) => {
3127                use crate::BuiltIn as Bi;
3128                let mut others = ArrayVec::new();
3129
3130                let built_in = match built_in {
3131                    Bi::Position { invariant } => {
3132                        if invariant {
3133                            others.push(Decoration::Invariant);
3134                        }
3135
3136                        if class == spirv::StorageClass::Output {
3137                            BuiltIn::Position
3138                        } else {
3139                            BuiltIn::FragCoord
3140                        }
3141                    }
3142                    Bi::ViewIndex => {
3143                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
3144                        BuiltIn::ViewIndex
3145                    }
3146                    // vertex
3147                    Bi::BaseInstance => BuiltIn::BaseInstance,
3148                    Bi::BaseVertex => BuiltIn::BaseVertex,
3149                    Bi::ClipDistances => {
3150                        self.require_any(
3151                            "`clip_distances` built-in",
3152                            &[spirv::Capability::ClipDistance],
3153                        )?;
3154                        BuiltIn::ClipDistance
3155                    }
3156                    Bi::CullDistance => {
3157                        self.require_any(
3158                            "`cull_distance` built-in",
3159                            &[spirv::Capability::CullDistance],
3160                        )?;
3161                        BuiltIn::CullDistance
3162                    }
3163                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
3164                    Bi::PointSize => BuiltIn::PointSize,
3165                    Bi::VertexIndex => BuiltIn::VertexIndex,
3166                    Bi::DrawIndex => {
3167                        self.use_extension("SPV_KHR_shader_draw_parameters");
3168                        self.require_any(
3169                            "`draw_index built-in",
3170                            &[spirv::Capability::DrawParameters],
3171                        )?;
3172                        BuiltIn::DrawIndex
3173                    }
3174                    // fragment
3175                    Bi::FragDepth => BuiltIn::FragDepth,
3176                    Bi::PointCoord => BuiltIn::PointCoord,
3177                    Bi::FrontFacing => BuiltIn::FrontFacing,
3178                    Bi::PrimitiveIndex => {
3179                        // Geometry shader capability is required for primitive index
3180                        self.require_any(
3181                            "`primitive_index` built-in",
3182                            &[spirv::Capability::Geometry],
3183                        )?;
3184                        if stage == crate::ShaderStage::Mesh {
3185                            others.push(Decoration::PerPrimitiveEXT);
3186                        }
3187                        BuiltIn::PrimitiveId
3188                    }
3189                    Bi::Barycentric { perspective } => {
3190                        self.require_any(
3191                            "`barycentric` built-in",
3192                            &[spirv::Capability::FragmentBarycentricKHR],
3193                        )?;
3194                        self.use_extension("SPV_KHR_fragment_shader_barycentric");
3195                        if perspective {
3196                            BuiltIn::BaryCoordKHR
3197                        } else {
3198                            BuiltIn::BaryCoordNoPerspKHR
3199                        }
3200                    }
3201                    Bi::SampleIndex => {
3202                        self.require_any(
3203                            "`sample_index` built-in",
3204                            &[spirv::Capability::SampleRateShading],
3205                        )?;
3206
3207                        BuiltIn::SampleId
3208                    }
3209                    Bi::SampleMask => BuiltIn::SampleMask,
3210                    // compute
3211                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
3212                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
3213                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
3214                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
3215                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
3216                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
3217                    // Subgroup
3218                    Bi::NumSubgroups => {
3219                        self.require_any(
3220                            "`num_subgroups` built-in",
3221                            &[spirv::Capability::GroupNonUniform],
3222                        )?;
3223                        BuiltIn::NumSubgroups
3224                    }
3225                    Bi::SubgroupId => {
3226                        self.require_any(
3227                            "`subgroup_id` built-in",
3228                            &[spirv::Capability::GroupNonUniform],
3229                        )?;
3230                        BuiltIn::SubgroupId
3231                    }
3232                    Bi::SubgroupSize => {
3233                        self.require_any(
3234                            "`subgroup_size` built-in",
3235                            &[
3236                                spirv::Capability::GroupNonUniform,
3237                                spirv::Capability::SubgroupBallotKHR,
3238                            ],
3239                        )?;
3240                        BuiltIn::SubgroupSize
3241                    }
3242                    Bi::SubgroupInvocationId => {
3243                        self.require_any(
3244                            "`subgroup_invocation_id` built-in",
3245                            &[
3246                                spirv::Capability::GroupNonUniform,
3247                                spirv::Capability::SubgroupBallotKHR,
3248                            ],
3249                        )?;
3250                        BuiltIn::SubgroupLocalInvocationId
3251                    }
3252                    Bi::CullPrimitive => {
3253                        others.push(Decoration::PerPrimitiveEXT);
3254                        BuiltIn::CullPrimitiveEXT
3255                    }
3256                    Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT,
3257                    Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT,
3258                    Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT,
3259                    // No decoration, this EmitMeshTasksEXT is called at function return
3260                    Bi::MeshTaskSize => return Ok(BindingDecorations::None),
3261                    // These aren't normal builtins and don't occur in function output
3262                    Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => {
3263                        unreachable!()
3264                    }
3265                    // ray tracing pipeline
3266                    Bi::RayInvocationId => BuiltIn::LaunchIdKHR,
3267                    Bi::NumRayInvocations => BuiltIn::LaunchSizeKHR,
3268                    Bi::InstanceCustomData => BuiltIn::InstanceCustomIndexKHR,
3269                    Bi::GeometryIndex => BuiltIn::RayGeometryIndexKHR,
3270                    Bi::WorldRayOrigin => BuiltIn::WorldRayOriginKHR,
3271                    Bi::WorldRayDirection => BuiltIn::WorldRayDirectionKHR,
3272                    Bi::ObjectRayOrigin => BuiltIn::ObjectRayOriginKHR,
3273                    Bi::ObjectRayDirection => BuiltIn::ObjectRayDirectionKHR,
3274                    Bi::RayTmin => BuiltIn::RayTminKHR,
3275                    Bi::RayTCurrentMax => BuiltIn::RayTmaxKHR,
3276                    Bi::ObjectToWorld => BuiltIn::ObjectToWorldKHR,
3277                    Bi::WorldToObject => BuiltIn::WorldToObjectKHR,
3278                    Bi::HitKind => BuiltIn::HitKindKHR,
3279                };
3280
3281                use crate::ScalarKind as Sk;
3282
3283                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
3284                //
3285                // > Any variable with integer or double-precision floating-
3286                // > point type and with Input storage class in a fragment
3287                // > shader, must be decorated Flat
3288                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
3289                    let is_flat = match ir_module.types[ty].inner {
3290                        crate::TypeInner::Scalar(scalar)
3291                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
3292                            Sk::Uint | Sk::Sint | Sk::Bool => true,
3293                            Sk::Float => false,
3294                            Sk::AbstractInt | Sk::AbstractFloat => {
3295                                return Err(Error::Validation(
3296                                    "Abstract types should not appear in IR presented to backends",
3297                                ))
3298                            }
3299                        },
3300                        _ => false,
3301                    };
3302
3303                    if is_flat {
3304                        others.push(Decoration::Flat);
3305                    }
3306                }
3307                Ok(BindingDecorations::BuiltIn(built_in, others))
3308            }
3309        }
3310    }
3311
3312    /// Load an IO variable, converting from `f32` to `f16` if polyfill is active.
3313    /// Returns the id of the loaded value matching `target_type_id`.
3314    pub(super) fn load_io_with_f16_polyfill(
3315        &mut self,
3316        body: &mut Vec<Instruction>,
3317        varying_id: Word,
3318        target_type_id: Word,
3319    ) -> Word {
3320        let tmp = self.id_gen.next();
3321        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3322            body.push(Instruction::load(f32_ty, tmp, varying_id, None));
3323            let converted = self.id_gen.next();
3324            super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion(
3325                tmp,
3326                target_type_id,
3327                converted,
3328                body,
3329            );
3330            converted
3331        } else {
3332            body.push(Instruction::load(target_type_id, tmp, varying_id, None));
3333            tmp
3334        }
3335    }
3336
3337    /// Store an IO variable, converting from `f16` to `f32` if polyfill is active.
3338    pub(super) fn store_io_with_f16_polyfill(
3339        &mut self,
3340        body: &mut Vec<Instruction>,
3341        varying_id: Word,
3342        value_id: Word,
3343    ) {
3344        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3345            let converted = self.id_gen.next();
3346            super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion(
3347                value_id, f32_ty, converted, body,
3348            );
3349            body.push(Instruction::store(varying_id, converted, None));
3350        } else {
3351            body.push(Instruction::store(varying_id, value_id, None));
3352        }
3353    }
3354
3355    fn write_global_variable(
3356        &mut self,
3357        ir_module: &crate::Module,
3358        global_variable: &crate::GlobalVariable,
3359    ) -> Result<Word, Error> {
3360        use spirv::Decoration;
3361
3362        let id = self.id_gen.next();
3363        let class = map_storage_class(global_variable.space);
3364
3365        if let crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload =
3366            global_variable.space
3367        {
3368            self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
3369        }
3370
3371        //self.check(class.required_capabilities())?;
3372
3373        if global_variable
3374            .memory_decorations
3375            .contains(crate::MemoryDecorations::COHERENT)
3376        {
3377            self.decorate(id, Decoration::Coherent, &[]);
3378        }
3379        if global_variable
3380            .memory_decorations
3381            .contains(crate::MemoryDecorations::VOLATILE)
3382        {
3383            self.decorate(id, Decoration::Volatile, &[]);
3384        }
3385
3386        if self.flags.contains(WriterFlags::DEBUG) {
3387            if let Some(ref name) = global_variable.name {
3388                self.debugs.push(Instruction::name(id, name));
3389            }
3390        }
3391
3392        let storage_access = match global_variable.space {
3393            crate::AddressSpace::Storage { access } => Some(access),
3394            _ => match ir_module.types[global_variable.ty].inner {
3395                crate::TypeInner::Image {
3396                    class: crate::ImageClass::Storage { access, .. },
3397                    ..
3398                } => Some(access),
3399                _ => None,
3400            },
3401        };
3402        if let Some(storage_access) = storage_access {
3403            if !storage_access.contains(crate::StorageAccess::LOAD) {
3404                self.decorate(id, Decoration::NonReadable, &[]);
3405            }
3406            if !storage_access.contains(crate::StorageAccess::STORE) {
3407                self.decorate(id, Decoration::NonWritable, &[]);
3408            }
3409        }
3410
3411        // Note: we should be able to substitute `binding_array<Foo, 0>`,
3412        // but there is still code that tries to register the pre-substituted type,
3413        // and it is failing on 0.
3414        let mut substitute_inner_type_lookup = None;
3415        if let Some(ref res_binding) = global_variable.binding {
3416            let bind_target = self.resolve_resource_binding(res_binding)?;
3417            self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]);
3418            self.decorate(id, Decoration::Binding, &[bind_target.binding]);
3419
3420            if let Some(remapped_binding_array_size) = bind_target.binding_array_size {
3421                if let crate::TypeInner::BindingArray { base, .. } =
3422                    ir_module.types[global_variable.ty].inner
3423                {
3424                    let binding_array_type_id =
3425                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
3426                            base,
3427                            size: remapped_binding_array_size,
3428                        }));
3429                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
3430                        base: binding_array_type_id,
3431                        class,
3432                    }));
3433                }
3434            }
3435        };
3436
3437        let init_word = global_variable
3438            .init
3439            .map(|constant| self.constant_ids[constant]);
3440        let inner_type_id = self.get_type_id(
3441            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
3442        );
3443
3444        // generate the wrapping structure if needed
3445        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
3446            let wrapper_type_id = self.id_gen.next();
3447
3448            self.decorate(wrapper_type_id, Decoration::Block, &[]);
3449
3450            match self.std140_compat_uniform_types.get(&global_variable.ty) {
3451                Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => {
3452                    self.annotations.push(Instruction::member_decorate(
3453                        wrapper_type_id,
3454                        0,
3455                        Decoration::Offset,
3456                        &[0],
3457                    ));
3458                    Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id])
3459                        .to_words(&mut self.logical_layout.declarations);
3460                }
3461                _ => {
3462                    let member = crate::StructMember {
3463                        name: None,
3464                        ty: global_variable.ty,
3465                        binding: None,
3466                        offset: 0,
3467                    };
3468                    self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
3469
3470                    Instruction::type_struct(wrapper_type_id, &[inner_type_id])
3471                        .to_words(&mut self.logical_layout.declarations);
3472                }
3473            }
3474
3475            let pointer_type_id = self.id_gen.next();
3476            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
3477                .to_words(&mut self.logical_layout.declarations);
3478
3479            pointer_type_id
3480        } else {
3481            // This is a global variable in the Storage address space. The only
3482            // way it could have `global_needs_wrapper() == false` is if it has
3483            // a runtime-sized or binding array.
3484            // Runtime-sized arrays were decorated when iterating through struct content.
3485            // Now binding arrays require Block decorating.
3486            if let crate::AddressSpace::Storage { .. } = global_variable.space {
3487                match ir_module.types[global_variable.ty].inner {
3488                    crate::TypeInner::BindingArray { base, .. } => {
3489                        let ty = &ir_module.types[base];
3490                        let mut should_decorate = true;
3491                        // Check if the type has a runtime array.
3492                        // A normal runtime array gets validated out,
3493                        // so only structs can be with runtime arrays
3494                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
3495                            // only the last member in a struct can be dynamically sized
3496                            if let Some(last_member) = members.last() {
3497                                if let &crate::TypeInner::Array {
3498                                    size: crate::ArraySize::Dynamic,
3499                                    ..
3500                                } = &ir_module.types[last_member.ty].inner
3501                                {
3502                                    should_decorate = false;
3503                                }
3504                            }
3505                        }
3506                        if should_decorate {
3507                            let decorated_id = self.get_handle_type_id(base);
3508                            self.decorate(decorated_id, Decoration::Block, &[]);
3509                        }
3510                    }
3511                    _ => (),
3512                };
3513            }
3514            if substitute_inner_type_lookup.is_some() {
3515                inner_type_id
3516            } else {
3517                self.get_handle_pointer_type_id(global_variable.ty, class)
3518            }
3519        };
3520
3521        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
3522            (crate::AddressSpace::Private, _)
3523            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
3524                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
3525            }
3526            _ => init_word,
3527        };
3528
3529        Instruction::variable(pointer_type_id, id, class, init_word)
3530            .to_words(&mut self.logical_layout.declarations);
3531        Ok(id)
3532    }
3533
3534    /// Write the necessary decorations for a struct member.
3535    ///
3536    /// Emit decorations for the `index`'th member of the struct type
3537    /// designated by `struct_id`, described by `member`.
3538    fn decorate_struct_member(
3539        &mut self,
3540        struct_id: Word,
3541        index: usize,
3542        member: &crate::StructMember,
3543        arena: &UniqueArena<crate::Type>,
3544    ) -> Result<(), Error> {
3545        use spirv::Decoration;
3546
3547        self.annotations.push(Instruction::member_decorate(
3548            struct_id,
3549            index as u32,
3550            Decoration::Offset,
3551            &[member.offset],
3552        ));
3553
3554        if self.flags.contains(WriterFlags::DEBUG) {
3555            if let Some(ref name) = member.name {
3556                self.debugs
3557                    .push(Instruction::member_name(struct_id, index as u32, name));
3558            }
3559        }
3560
3561        // Matrices and (potentially nested) arrays of matrices both require decorations,
3562        // so "see through" any arrays to determine if they're needed.
3563        let mut member_array_subty_inner = &arena[member.ty].inner;
3564        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
3565            member_array_subty_inner = &arena[base].inner;
3566        }
3567
3568        if let crate::TypeInner::Matrix {
3569            columns: _,
3570            rows,
3571            scalar,
3572        } = *member_array_subty_inner
3573        {
3574            let byte_stride = Alignment::from(rows) * scalar.width as u32;
3575            self.annotations.push(Instruction::member_decorate(
3576                struct_id,
3577                index as u32,
3578                Decoration::ColMajor,
3579                &[],
3580            ));
3581            self.annotations.push(Instruction::member_decorate(
3582                struct_id,
3583                index as u32,
3584                Decoration::MatrixStride,
3585                &[byte_stride],
3586            ));
3587        }
3588
3589        Ok(())
3590    }
3591
3592    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
3593        match self
3594            .lookup_function_type
3595            .entry(lookup_function_type.clone())
3596        {
3597            Entry::Occupied(e) => *e.get(),
3598            Entry::Vacant(_) => {
3599                let id = self.id_gen.next();
3600                let instruction = Instruction::type_function(
3601                    id,
3602                    lookup_function_type.return_type_id,
3603                    &lookup_function_type.parameter_type_ids,
3604                );
3605                instruction.to_words(&mut self.logical_layout.declarations);
3606                self.lookup_function_type.insert(lookup_function_type, id);
3607                id
3608            }
3609        }
3610    }
3611
3612    const fn write_physical_layout(&mut self) {
3613        self.physical_layout.bound = self.id_gen.0 + 1;
3614    }
3615
    /// Emit every section of the SPIR-V logical layout for `ir_module`.
    ///
    /// The sections are written in this order:
    ///
    /// 1. Module-level extensions, the void type and the `GLSL.std.450` import.
    /// 2. Optional source debug info.
    /// 3. All types, then the std140-compatible types needed by uniforms.
    /// 4. Constants, then global variables.
    /// 5. Functions, then entry points.
    ///
    /// Capabilities, extensions, the memory model, debug strings, debug names
    /// and annotations accumulated along the way are flushed at the end.
    ///
    /// When `ep_index` is `Some`, only that entry point is written:
    ///
    /// - Globals it does not use get `GlobalVariable::dummy()` placeholders, so
    ///   handle indices stay aligned.
    /// - Functions that rely on omitted globals, or that are not available in the
    ///   entry point's stages, are skipped.
    ///
    /// # Errors
    /// Returns [`Error::SpirvVersionTooLow`] when mesh shaders are used with a
    /// SPIR-V version below 1.4. Also propagates errors from `require_any` and
    /// from the type, constant, global-variable, function and entry-point
    /// writers.
    fn write_logical_layout(
        &mut self,
        ir_module: &crate::Module,
        mod_info: &ModuleInfo,
        ep_index: Option<usize>,
        debug_info: &Option<DebugInfo>,
    ) -> Result<(), Error> {
        /// Whether `binding` is the `ViewIndex` builtin. For struct types, checks
        /// each member's binding instead, recursively.
        fn has_view_index_check(
            ir_module: &crate::Module,
            binding: Option<&crate::Binding>,
            ty: Handle<crate::Type>,
        ) -> bool {
            match ir_module.types[ty].inner {
                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
                }),
                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
            }
        }

        // NOTE(review): the next two scans cover every global / every entry point
        // in the module, even when `ep_index` selects a single entry point, so the
        // extensions below may be declared without being used.
        let has_storage_buffers =
            ir_module
                .global_variables
                .iter()
                .any(|(_, var)| match var.space {
                    crate::AddressSpace::Storage { .. } => true,
                    _ => false,
                });
        let has_view_index = ir_module
            .entry_points
            .iter()
            .flat_map(|entry| entry.function.arguments.iter())
            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();

        let rt_uses = ir_module.uses_ray_tracing(ep_index);
        let has_ray_query = rt_uses.queries;
        let has_ray_tracing_pipeline = rt_uses.pipelines;

        // Remembered on the writer so later stages of code generation can consult it.
        self.has_ray_tracing_pipeline = has_ray_tracing_pipeline;

        if self.physical_layout.version < 0x10300 && has_storage_buffers {
            // The StorageBuffer storage class is core only from SPIR-V 1.3;
            // earlier versions need the KHR extension.
            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
                .to_words(&mut self.logical_layout.extensions);
        }
        if has_view_index {
            Instruction::extension("SPV_KHR_multiview")
                .to_words(&mut self.logical_layout.extensions)
        }
        if has_ray_query {
            Instruction::extension("SPV_KHR_ray_query")
                .to_words(&mut self.logical_layout.extensions)
        }
        if has_vertex_return {
            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
                .to_words(&mut self.logical_layout.extensions);
        }
        if ir_module.uses_mesh_shaders() {
            self.use_extension("SPV_EXT_mesh_shader");
            self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?;
            // Mesh shading requires at least SPIR-V 1.4.
            let lang_version = self.lang_version();
            if lang_version.0 <= 1 && lang_version.1 < 4 {
                return Err(Error::SpirvVersionTooLow(1, 4));
            }
        }
        if has_ray_tracing_pipeline {
            Instruction::extension("SPV_KHR_ray_tracing")
                .to_words(&mut self.logical_layout.extensions)
        }
        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
            .to_words(&mut self.logical_layout.ext_inst_imports);

        // Source-level debug info: an OpString for the file name plus the source
        // text. It is only produced when the DEBUG flag is set and `debug_info`
        // is provided.
        let mut debug_info_inner = None;
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(debug_info) = debug_info.as_ref() {
                let source_file_id = self.id_gen.next();
                self.debugs
                    .push(Instruction::string(debug_info.file_name, source_file_id));

                debug_info_inner = Some(DebugInfoInner {
                    source_code: debug_info.source_code,
                    source_file_id,
                });
                self.debugs.append(&mut Instruction::source_auto_continued(
                    debug_info.language,
                    0,
                    &debug_info_inner,
                ));
            }
        }

        // write all types
        for (handle, _) in ir_module.types.iter() {
            self.write_type_declaration_arena(ir_module, handle)?;
        }

        // write std140 layout compatible types required by uniforms
        for (_, var) in ir_module.global_variables.iter() {
            if var.space == crate::AddressSpace::Uniform {
                self.write_std140_compat_type_declaration(ir_module, var.ty)?;
            }
        }

        // write all const-expressions as constants
        self.constant_ids
            .resize(ir_module.global_expressions.len(), 0);
        for (handle, _) in ir_module.global_expressions.iter() {
            self.write_constant_expr(handle, ir_module, mod_info)?;
        }
        // Every global expression must now have a real (non-zero) SPIR-V ID.
        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));

        // write the name of constants on their respective const-expression initializer
        if self.flags.contains(WriterFlags::DEBUG) {
            for (_, constant) in ir_module.constants.iter() {
                if let Some(ref name) = constant.name {
                    let id = self.constant_ids[constant.init];
                    self.debugs.push(Instruction::name(id, name));
                }
            }
        }

        // write all global variables
        for (handle, var) in ir_module.global_variables.iter() {
            // If a single entry point was specified, only write `OpVariable` instructions
            // for the globals it actually uses. Emit dummies for the others,
            // to preserve the indices in `global_variables`.
            let gvar = match ep_index {
                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
                    GlobalVariable::dummy()
                }
                _ => {
                    let id = self.write_global_variable(ir_module, var)?;
                    GlobalVariable::new(id)
                }
            };
            self.global_variables.insert(handle, gvar);
        }

        // write all functions
        for (handle, ir_function) in ir_module.functions.iter() {
            let info = &mod_info[handle];
            if let Some(index) = ep_index {
                let ep_info = mod_info.get_entry_point(index);
                // If this function uses globals that we omitted from the SPIR-V
                // because the entry point and its callees didn't use them,
                // then we must skip it.
                if !ep_info.dominates_global_use(info) {
                    log::debug!("Skip function {:?}", ir_function.name);
                    continue;
                }

                // Skip functions that are not compatible with this entry point's stage.
                //
                // When validation is enabled, it rejects modules whose entry points try to call
                // incompatible functions, so if we got this far, then any functions incompatible
                // with our selected entry point must not be used.
                //
                // When validation is disabled, `info.available_stages` is always just
                // `ShaderStages::all()`, so this will write all functions in the module, and
                // any problems are left for the downstream SPIR-V consumer to catch.
                if !info.available_stages.contains(ep_info.available_stages) {
                    continue;
                }
            }
            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
            self.lookup_function.insert(handle, id);
        }

        // write all or one entry points
        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
            if ep_index.is_some() && ep_index != Some(index) {
                continue;
            }
            let info = mod_info.get_entry_point(index);
            let ep_instruction =
                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
            ep_instruction.to_words(&mut self.logical_layout.entry_points);
        }

        // Flush the capabilities and extensions requested while writing the
        // declarations and functions above.
        for capability in self.capabilities_used.iter() {
            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
        }
        for extension in self.extensions_used.iter() {
            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
        }
        if ir_module.entry_points.is_empty() {
            // SPIR-V validators reject modules that have no entry point unless the
            // `Linkage` capability is declared.
            Instruction::capability(spirv::Capability::Linkage)
                .to_words(&mut self.logical_layout.capabilities);
        }

        // Use the Vulkan memory model only if some code generation path already
        // requested its capability; otherwise fall back to GLSL450.
        let addressing_model = spirv::AddressingModel::Logical;
        let memory_model = if self
            .capabilities_used
            .contains(&spirv::Capability::VulkanMemoryModel)
        {
            spirv::MemoryModel::Vulkan
        } else {
            spirv::MemoryModel::GLSL450
        };
        // NOTE: capability checks for the chosen models are currently disabled:
        //self.check(addressing_model.required_capabilities())?;
        //self.check(memory_model.required_capabilities())?;

        Instruction::memory_model(addressing_model, memory_model)
            .to_words(&mut self.logical_layout.memory_model);

        // `debug_strings` (e.g. debugPrintf format strings) are flushed
        // unconditionally. The remaining `debugs` (names, source info) are
        // flushed only when the DEBUG flag is set.
        for debug_string in self.debug_strings.iter() {
            debug_string.to_words(&mut self.logical_layout.debugs);
        }

        if self.flags.contains(WriterFlags::DEBUG) {
            for debug in self.debugs.iter() {
                debug.to_words(&mut self.logical_layout.debugs);
            }
        }

        for annotation in self.annotations.iter() {
            annotation.to_words(&mut self.logical_layout.annotations);
        }

        Ok(())
    }
3840
3841    pub fn write(
3842        &mut self,
3843        ir_module: &crate::Module,
3844        info: &ModuleInfo,
3845        pipeline_options: Option<&PipelineOptions>,
3846        debug_info: &Option<DebugInfo>,
3847        words: &mut Vec<Word>,
3848    ) -> Result<(), Error> {
3849        self.reset();
3850
3851        // Try to find the entry point and corresponding index
3852        let ep_index = match pipeline_options {
3853            Some(po) => {
3854                let index = ir_module
3855                    .entry_points
3856                    .iter()
3857                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
3858                    .ok_or(Error::EntryPointNotFound)?;
3859                Some(index)
3860            }
3861            None => None,
3862        };
3863
3864        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
3865        self.write_physical_layout();
3866
3867        self.physical_layout.in_words(words);
3868        self.logical_layout.in_words(words);
3869        Ok(())
3870    }
3871
3872    /// Return the set of capabilities the last module written used.
3873    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
3874        &self.capabilities_used
3875    }
3876
3877    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
3878        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
3879        self.use_extension("SPV_EXT_descriptor_indexing");
3880        self.decorate(id, spirv::Decoration::NonUniform, &[]);
3881        Ok(())
3882    }
3883
3884    pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool {
3885        self.io_f16_polyfills.needs_polyfill(ty_inner)
3886    }
3887
3888    pub(super) fn write_debug_printf(
3889        &mut self,
3890        block: &mut Block,
3891        string: &str,
3892        format_params: &[Word],
3893    ) {
3894        if self.debug_printf.is_none() {
3895            self.use_extension("SPV_KHR_non_semantic_info");
3896            let import_id = self.id_gen.next();
3897            Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf")
3898                .to_words(&mut self.logical_layout.ext_inst_imports);
3899            self.debug_printf = Some(import_id)
3900        }
3901
3902        let import_id = self.debug_printf.unwrap();
3903
3904        let string_id = self.id_gen.next();
3905        self.debug_strings
3906            .push(Instruction::string(string, string_id));
3907
3908        let mut operands = Vec::with_capacity(1 + format_params.len());
3909        operands.push(string_id);
3910        operands.extend(format_params.iter());
3911
3912        let print_id = self.id_gen.next();
3913        block.body.push(Instruction::ext_inst(
3914            import_id,
3915            1,
3916            self.void_type,
3917            print_id,
3918            &operands,
3919        ));
3920    }
3921}
3922
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();
    // Nothing has been serialized yet, so the header's ID bound is still zero.
    assert_eq!(writer.physical_layout.bound, 0);

    writer.write_physical_layout();
    // The bound is `id_gen.0 + 1`. `Writer::new` is expected to leave
    // `id_gen.0 == 2` (presumably for the void type and the GLSL.std.450
    // import), so the bound becomes 3.
    assert_eq!(writer.physical_layout.bound, 3);
}