naga/back/spv/
writer.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use arrayvec::ArrayVec;
4use hashbrown::hash_map::Entry;
5use spirv::Word;
6
7use super::{
8    block::DebugInfoInner,
9    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
10    Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
11    EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
12    LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
13    NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
14    BITS_PER_BYTE,
15};
16use crate::{
17    arena::{Handle, HandleVec, UniqueArena},
18    back::spv::{
19        helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations},
20        BindingInfo, Std140CompatTypeInfo, WrappedFunction,
21    },
22    common::ForDebugWithTypes as _,
23    proc::{Alignment, TypeResolution},
24    valid::{FunctionInfo, ModuleInfo},
25};
26
/// Borrowed bundle of entry-point interface data.
///
/// NOTE(review): the code that builds and consumes this struct is not in this
/// part of the file; the field notes below are inferred from names and types
/// and should be confirmed against the callers.
pub struct FunctionInterface<'a> {
    /// Caller-owned, mutable list of SPIR-V words; presumably the ids of the
    /// entry point's interface (varying) variables — TODO confirm.
    pub varying_ids: &'a mut Vec<Word>,
    /// Shader stage associated with this interface.
    pub stage: crate::ShaderStage,
    /// Optional handle to a global variable; presumably the task payload used
    /// by task/mesh stages — TODO confirm.
    pub task_payload: Option<Handle<crate::GlobalVariable>>,
    /// Optional mesh-stage information; `None` presumably means the stage has
    /// no mesh output description — TODO confirm.
    pub mesh_info: Option<crate::MeshStageInfo>,
    /// Three `u32` values; presumably the x/y/z workgroup dimensions — verify
    /// against the caller.
    pub workgroup_size: [u32; 3],
}
34
35impl Function {
36    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
37        self.signature.as_ref().unwrap().to_words(sink);
38        for argument in self.parameters.iter() {
39            argument.instruction.to_words(sink);
40        }
41        for (index, block) in self.blocks.iter().enumerate() {
42            Instruction::label(block.label_id).to_words(sink);
43            if index == 0 {
44                for local_var in self.variables.values() {
45                    local_var.instruction.to_words(sink);
46                }
47                for local_var in self.ray_query_initialization_tracker_variables.values() {
48                    local_var.instruction.to_words(sink);
49                }
50                for local_var in self.ray_query_t_max_tracker_variables.values() {
51                    local_var.instruction.to_words(sink);
52                }
53                for local_var in self.force_loop_bounding_vars.iter() {
54                    local_var.instruction.to_words(sink);
55                }
56                for internal_var in self.spilled_composites.values() {
57                    internal_var.instruction.to_words(sink);
58                }
59            }
60            for instruction in block.body.iter() {
61                instruction.to_words(sink);
62            }
63        }
64        Instruction::function_end().to_words(sink);
65    }
66}
67
68impl Writer {
69    pub fn new(options: &Options) -> Result<Self, Error> {
70        let (major, minor) = options.lang_version;
71        if major != 1 {
72            return Err(Error::UnsupportedVersion(major, minor));
73        }
74
75        let mut capabilities_used = crate::FastIndexSet::default();
76        capabilities_used.insert(spirv::Capability::Shader);
77
78        let mut id_gen = IdGenerator::default();
79        let gl450_ext_inst_id = id_gen.next();
80        let void_type = id_gen.next();
81
82        Ok(Writer {
83            physical_layout: PhysicalLayout::new(major, minor),
84            logical_layout: LogicalLayout::default(),
85            id_gen,
86            capabilities_available: options.capabilities.clone(),
87            capabilities_used,
88            extensions_used: crate::FastIndexSet::default(),
89            debug_strings: vec![],
90            debugs: vec![],
91            annotations: vec![],
92            flags: options.flags,
93            bounds_check_policies: options.bounds_check_policies,
94            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
95            force_loop_bounding: options.force_loop_bounding,
96            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
97            trace_ray_argument_validation: options.trace_ray_argument_validation,
98            use_storage_input_output_16: options.use_storage_input_output_16,
99            emit_int_div_checks: options.emit_int_div_checks,
100            void_type,
101            tuple_of_u32s_ty_id: None,
102            lookup_type: crate::FastHashMap::default(),
103            lookup_function: crate::FastHashMap::default(),
104            lookup_function_type: crate::FastHashMap::default(),
105            wrapped_functions: crate::FastHashMap::default(),
106            constant_ids: HandleVec::new(),
107            cached_constants: crate::FastHashMap::default(),
108            global_variables: HandleVec::new(),
109            std140_compat_uniform_types: crate::FastHashMap::default(),
110            fake_missing_bindings: options.fake_missing_bindings,
111            binding_map: options.binding_map.clone(),
112            saved_cached: CachedExpressions::default(),
113            gl450_ext_inst_id,
114            temp_list: Vec::new(),
115            ray_query_functions: crate::FastHashMap::default(),
116            ray_tracing_functions: crate::FastHashMap::default(),
117            has_ray_tracing_pipeline: false,
118            io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new(
119                options.use_storage_input_output_16,
120            ),
121            debug_printf: None,
122            task_dispatch_limits: options.task_dispatch_limits,
123            mesh_shader_primitive_indices_clamp: options.mesh_shader_primitive_indices_clamp,
124        })
125    }
126
127    pub fn set_options(&mut self, options: &Options) -> Result<(), Error> {
128        let (major, minor) = options.lang_version;
129        if major != 1 {
130            return Err(Error::UnsupportedVersion(major, minor));
131        }
132        self.physical_layout = PhysicalLayout::new(major, minor);
133        self.capabilities_available = options.capabilities.clone();
134        self.flags = options.flags;
135        self.bounds_check_policies = options.bounds_check_policies;
136        self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory;
137        self.force_loop_bounding = options.force_loop_bounding;
138        self.use_storage_input_output_16 = options.use_storage_input_output_16;
139        self.binding_map = options.binding_map.clone();
140        self.io_f16_polyfills =
141            super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16);
142        self.task_dispatch_limits = options.task_dispatch_limits;
143        self.mesh_shader_primitive_indices_clamp = options.mesh_shader_primitive_indices_clamp;
144        Ok(())
145    }
146
147    /// Returns `(major, minor)` of the SPIR-V language version.
148    pub const fn lang_version(&self) -> (u8, u8) {
149        self.physical_layout.lang_version()
150    }
151
152    /// Reset `Writer` to its initial state, retaining any allocations.
153    ///
154    /// Why not just implement `Reclaimable` for `Writer`? By design,
155    /// `Reclaimable::reclaim` requires ownership of the value, not just
156    /// `&mut`; see the trait documentation. But we need to use this method
157    /// from functions like `Writer::write`, which only have `&mut Writer`.
158    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
159    /// or something like a `Default` impl that returns an oddly-initialized
160    /// `Writer`, which is worse.
161    fn reset(&mut self) {
162        use super::reclaimable::Reclaimable;
163        use core::mem::take;
164
165        let mut id_gen = IdGenerator::default();
166        let gl450_ext_inst_id = id_gen.next();
167        let void_type = id_gen.next();
168
169        // Every field of the old writer that is not determined by the `Options`
170        // passed to `Writer::new` should be reset somehow.
171        let fresh = Writer {
172            // Copied from the old Writer:
173            flags: self.flags,
174            bounds_check_policies: self.bounds_check_policies,
175            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
176            force_loop_bounding: self.force_loop_bounding,
177            ray_query_initialization_tracking: self.ray_query_initialization_tracking,
178            trace_ray_argument_validation: self.trace_ray_argument_validation,
179            use_storage_input_output_16: self.use_storage_input_output_16,
180            capabilities_available: take(&mut self.capabilities_available),
181            fake_missing_bindings: self.fake_missing_bindings,
182            binding_map: take(&mut self.binding_map),
183            task_dispatch_limits: self.task_dispatch_limits,
184            mesh_shader_primitive_indices_clamp: self.mesh_shader_primitive_indices_clamp,
185            emit_int_div_checks: self.emit_int_div_checks,
186
187            // Initialized afresh:
188            id_gen,
189            void_type,
190            tuple_of_u32s_ty_id: None,
191            gl450_ext_inst_id,
192
193            // Reclaimed:
194            capabilities_used: take(&mut self.capabilities_used).reclaim(),
195            extensions_used: take(&mut self.extensions_used).reclaim(),
196            physical_layout: self.physical_layout.clone().reclaim(),
197            logical_layout: take(&mut self.logical_layout).reclaim(),
198            debug_strings: take(&mut self.debug_strings).reclaim(),
199            debugs: take(&mut self.debugs).reclaim(),
200            annotations: take(&mut self.annotations).reclaim(),
201            lookup_type: take(&mut self.lookup_type).reclaim(),
202            lookup_function: take(&mut self.lookup_function).reclaim(),
203            lookup_function_type: take(&mut self.lookup_function_type).reclaim(),
204            wrapped_functions: take(&mut self.wrapped_functions).reclaim(),
205            constant_ids: take(&mut self.constant_ids).reclaim(),
206            cached_constants: take(&mut self.cached_constants).reclaim(),
207            global_variables: take(&mut self.global_variables).reclaim(),
208            std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).reclaim(),
209            saved_cached: take(&mut self.saved_cached).reclaim(),
210            temp_list: take(&mut self.temp_list).reclaim(),
211            ray_query_functions: take(&mut self.ray_query_functions).reclaim(),
212            ray_tracing_functions: take(&mut self.ray_tracing_functions).reclaim(),
213            has_ray_tracing_pipeline: false,
214            io_f16_polyfills: take(&mut self.io_f16_polyfills).reclaim(),
215            debug_printf: None,
216        };
217
218        *self = fresh;
219
220        self.capabilities_used.insert(spirv::Capability::Shader);
221    }
222
223    /// Indicate that the code requires any one of the listed capabilities.
224    ///
225    /// If nothing in `capabilities` appears in the available capabilities
226    /// specified in the [`Options`] from which this `Writer` was created,
227    /// return an error. The `what` string is used in the error message to
228    /// explain what provoked the requirement. (If no available capabilities were
229    /// given, assume everything is available.)
230    ///
231    /// The first acceptable capability will be added to this `Writer`'s
232    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
233    /// result. For this reason, more specific capabilities should be listed
234    /// before more general.
235    ///
236    /// [`capabilities_used`]: Writer::capabilities_used
237    pub(super) fn require_any(
238        &mut self,
239        what: &'static str,
240        capabilities: &[spirv::Capability],
241    ) -> Result<(), Error> {
242        match *capabilities {
243            [] => Ok(()),
244            [first, ..] => {
245                // Find the first acceptable capability, or return an error if
246                // there is none.
247                let selected = match self.capabilities_available {
248                    None => first,
249                    Some(ref available) => {
250                        match capabilities
251                            .iter()
252                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
253                            .find(|cap| available.contains::<spirv::Capability>(cap))
254                        {
255                            Some(&cap) => cap,
256                            None => {
257                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
258                            }
259                        }
260                    }
261                };
262                self.capabilities_used.insert(selected);
263                Ok(())
264            }
265        }
266    }
267
268    /// Indicate that the code requires all of the listed capabilities.
269    ///
270    /// If all entries of `capabilities` appear in the available capabilities
271    /// specified in the [`Options`] from which this `Writer` was created
272    /// (including the case where [`Options::capabilities`] is `None`), add
273    /// them all to this `Writer`'s [`capabilities_used`] table, and return
274    /// `Ok(())`. If at least one of the listed capabilities is not available,
275    /// do not add anything to the `capabilities_used` table, and return the
276    /// first unavailable requested capability, wrapped in `Err()`.
277    ///
278    /// This method is does not return an [`enum@Error`] in case of failure
279    /// because it may be used in cases where the caller can recover (e.g.,
280    /// with a polyfill) if the requested capabilities are not available. In
281    /// this case, it would be unnecessary work to find *all* the unavailable
282    /// requested capabilities, and to allocate a `Vec` for them, just so we
283    /// could return an [`Error::MissingCapabilities`]).
284    ///
285    /// [`capabilities_used`]: Writer::capabilities_used
286    pub(super) fn require_all(
287        &mut self,
288        capabilities: &[spirv::Capability],
289    ) -> Result<(), spirv::Capability> {
290        if let Some(ref available) = self.capabilities_available {
291            for requested in capabilities {
292                if !available.contains(requested) {
293                    return Err(*requested);
294                }
295            }
296        }
297
298        for requested in capabilities {
299            self.capabilities_used.insert(*requested);
300        }
301
302        Ok(())
303    }
304
305    /// Indicate that the code uses the given extension.
306    pub(super) fn use_extension(&mut self, extension: &'static str) {
307        self.extensions_used.insert(extension);
308    }
309
310    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
311        match self.lookup_type.entry(lookup_ty) {
312            Entry::Occupied(e) => *e.get(),
313            Entry::Vacant(e) => {
314                let local = match lookup_ty {
315                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
316                    LookupType::Local(local) => local,
317                };
318
319                let id = self.id_gen.next();
320                e.insert(id);
321                self.write_type_declaration_local(id, local);
322                id
323            }
324        }
325    }
326
327    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
328        self.get_type_id(LookupType::Handle(handle))
329    }
330
331    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
332        match *tr {
333            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
334            TypeResolution::Value(ref inner) => {
335                let inner_local_type = self.localtype_from_inner(inner).unwrap();
336                LookupType::Local(inner_local_type)
337            }
338        }
339    }
340
341    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
342        let lookup_ty = self.get_expression_lookup_type(tr);
343        self.get_type_id(lookup_ty)
344    }
345
346    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
347        self.get_type_id(LookupType::Local(local))
348    }
349
350    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
351        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
352    }
353
354    pub(super) fn get_handle_pointer_type_id(
355        &mut self,
356        base: Handle<crate::Type>,
357        class: spirv::StorageClass,
358    ) -> Word {
359        let base_id = self.get_handle_type_id(base);
360        self.get_pointer_type_id(base_id, class)
361    }
362
363    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
364        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
365        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
366    }
367
368    /// Return a SPIR-V type for a pointer to `resolution`.
369    ///
370    /// The given `resolution` must be one that we can represent
371    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
372    pub(super) fn get_resolution_pointer_id(
373        &mut self,
374        resolution: &TypeResolution,
375        class: spirv::StorageClass,
376    ) -> Word {
377        let resolution_type_id = self.get_expression_type_id(resolution);
378        self.get_pointer_type_id(resolution_type_id, class)
379    }
380
381    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
382        self.get_type_id(LocalType::Numeric(numeric).into())
383    }
384
385    pub(super) fn get_u32_type_id(&mut self) -> Word {
386        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
387    }
388
389    pub(super) fn get_f32_type_id(&mut self) -> Word {
390        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
391    }
392
393    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
394        self.get_numeric_type_id(NumericType::Vector {
395            size: crate::VectorSize::Bi,
396            scalar: crate::Scalar::U32,
397        })
398    }
399
400    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
401        self.get_numeric_type_id(NumericType::Vector {
402            size: crate::VectorSize::Bi,
403            scalar: crate::Scalar::F32,
404        })
405    }
406
407    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
408        self.get_numeric_type_id(NumericType::Vector {
409            size: crate::VectorSize::Tri,
410            scalar: crate::Scalar::U32,
411        })
412    }
413
414    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
415        let f32_id = self.get_f32_type_id();
416        self.get_pointer_type_id(f32_id, class)
417    }
418
419    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
420        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
421            size: crate::VectorSize::Bi,
422            scalar: crate::Scalar::U32,
423        });
424        self.get_pointer_type_id(vec2u_id, class)
425    }
426
427    pub(super) fn get_bool_type_id(&mut self) -> Word {
428        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
429    }
430
431    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
432        self.get_numeric_type_id(NumericType::Vector {
433            size: crate::VectorSize::Bi,
434            scalar: crate::Scalar::BOOL,
435        })
436    }
437
438    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
439        self.get_numeric_type_id(NumericType::Vector {
440            size: crate::VectorSize::Tri,
441            scalar: crate::Scalar::BOOL,
442        })
443    }
444
445    /// Used for "mulhi" to get the upper bits of multiplication.
446    ///
447    /// More specifically, `OpUMulExtended` multiplies 2 numbers and returns the lower and upper bits of the result
448    /// as a user-defined struct type with 2 u32s. This defines that struct.
449    pub(super) fn get_tuple_of_u32s_ty_id(&mut self) -> Word {
450        if let Some(val) = self.tuple_of_u32s_ty_id {
451            val
452        } else {
453            let id = self.id_gen.next();
454            let u32_id = self.get_u32_type_id();
455            let ins = Instruction::type_struct(id, &[u32_id, u32_id]);
456            ins.to_words(&mut self.logical_layout.declarations);
457            self.tuple_of_u32s_ty_id = Some(id);
458            id
459        }
460    }
461
462    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
463        self.annotations
464            .push(Instruction::decorate(id, decoration, operands));
465    }
466
467    /// Return `inner` as a `LocalType`, if that's possible.
468    ///
469    /// If `inner` can be represented as a `LocalType`, return
470    /// `Some(local_type)`.
471    ///
472    /// Otherwise, return `None`. In this case, the type must always be looked
473    /// up using a `LookupType::Handle`.
474    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
475        Some(match *inner {
476            crate::TypeInner::Scalar(_)
477            | crate::TypeInner::Atomic(_)
478            | crate::TypeInner::Vector { .. }
479            | crate::TypeInner::Matrix { .. } => {
480                // We expect `NumericType::from_inner` to handle all
481                // these cases, so unwrap.
482                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
483            }
484            crate::TypeInner::CooperativeMatrix { .. } => {
485                LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
486            }
487            crate::TypeInner::Pointer { base, space } => {
488                let base_type_id = self.get_handle_type_id(base);
489                LocalType::Pointer {
490                    base: base_type_id,
491                    class: map_storage_class(space),
492                }
493            }
494            crate::TypeInner::ValuePointer {
495                size,
496                scalar,
497                space,
498            } => {
499                let base_numeric_type = match size {
500                    Some(size) => NumericType::Vector { size, scalar },
501                    None => NumericType::Scalar(scalar),
502                };
503                LocalType::Pointer {
504                    base: self.get_numeric_type_id(base_numeric_type),
505                    class: map_storage_class(space),
506                }
507            }
508            crate::TypeInner::Image {
509                dim,
510                arrayed,
511                class,
512            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
513            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
514            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
515            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
516            crate::TypeInner::Array { .. }
517            | crate::TypeInner::Struct { .. }
518            | crate::TypeInner::BindingArray { .. } => return None,
519        })
520    }
521
522    /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the
523    /// provided [`Writer::binding_map`].
524    ///
525    /// If the specified resource is not present in the binding map this will
526    /// return an error, unless [`Writer::fake_missing_bindings`] is set.
527    fn resolve_resource_binding(
528        &self,
529        res_binding: &crate::ResourceBinding,
530    ) -> Result<BindingInfo, Error> {
531        match self.binding_map.get(res_binding) {
532            Some(target) => Ok(*target),
533            None if self.fake_missing_bindings => Ok(BindingInfo {
534                descriptor_set: res_binding.group,
535                binding: res_binding.binding,
536                binding_array_size: None,
537            }),
538            None => Err(Error::MissingBinding(*res_binding)),
539        }
540    }
541
542    /// Emits code for any wrapper functions required by the expressions in ir_function.
543    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
544    fn write_wrapped_functions(
545        &mut self,
546        ir_function: &crate::Function,
547        info: &FunctionInfo,
548        ir_module: &crate::Module,
549    ) -> Result<(), Error> {
550        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
551
552        for (expr_handle, expr) in ir_function.expressions.iter() {
553            match *expr {
554                crate::Expression::Binary { op, left, right } => {
555                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
556                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
557                        match (op, expr_ty.scalar().kind) {
558                            // Division and modulo are undefined behaviour when the
559                            // dividend is the minimum representable value and the divisor
560                            // is negative one, or when the divisor is zero. These wrapped
561                            // functions override the divisor to one in these cases,
562                            // matching the WGSL spec.
563                            (
564                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
565                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
566                            ) if self.emit_int_div_checks => {
567                                self.write_wrapped_binary_op(
568                                    op,
569                                    expr_ty,
570                                    &info[left].ty,
571                                    &info[right].ty,
572                                )?;
573                            }
574                            _ => {}
575                        }
576                    }
577                }
578                crate::Expression::Load { pointer } => {
579                    if let crate::TypeInner::Pointer {
580                        base: pointer_type,
581                        space: crate::AddressSpace::Uniform,
582                    } = *info[pointer].ty.inner_with(&ir_module.types)
583                    {
584                        if self.std140_compat_uniform_types.contains_key(&pointer_type) {
585                            // Loading a std140 compat type requires the wrapper function
586                            // to convert to the regular type.
587                            self.write_wrapped_convert_from_std140_compat_type(
588                                ir_module,
589                                pointer_type,
590                            )?;
591                        }
592                    }
593                }
594                crate::Expression::Access { base, .. } => {
595                    if let crate::TypeInner::Pointer {
596                        base: base_type,
597                        space: crate::AddressSpace::Uniform,
598                    } = *info[base].ty.inner_with(&ir_module.types)
599                    {
600                        // Dynamic accesses of a two-row matrix's columns require a
601                        // wrapper function.
602                        if let crate::TypeInner::Matrix {
603                            rows: crate::VectorSize::Bi,
604                            ..
605                        } = ir_module.types[base_type].inner
606                        {
607                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
608                            // If the matrix is *not* directly a member of a struct, then
609                            // we additionally require a wrapper function to convert from
610                            // the std140 compat type to the regular type.
611                            if !is_uniform_matcx2_struct_member_access(
612                                ir_function,
613                                info,
614                                ir_module,
615                                base,
616                            ) {
617                                self.write_wrapped_convert_from_std140_compat_type(
618                                    ir_module, base_type,
619                                )?;
620                            }
621                        }
622                    }
623                }
624                _ => {}
625            }
626        }
627
628        Ok(())
629    }
630
631    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
632    ///
633    /// Define a function that performs an integer division or modulo operation,
634    /// except that using a divisor of zero or causing signed overflow with a
635    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
636    /// undefined behavior.
637    ///
638    /// Store the generated function's id in the [`wrapped_functions`] table.
639    ///
640    /// The operator `op` must be either [`Divide`] or [`Modulo`].
641    ///
642    /// # Panics
643    ///
644    /// The `return_type`, `left_type` or `right_type` arguments must all be
645    /// integer scalars or vectors. If not, this function panics.
646    ///
647    /// [`wrapped_functions`]: Writer::wrapped_functions
648    /// [`Divide`]: crate::BinaryOperator::Divide
649    /// [`Modulo`]: crate::BinaryOperator::Modulo
650    fn write_wrapped_binary_op(
651        &mut self,
652        op: crate::BinaryOperator,
653        return_type: NumericType,
654        left_type: &TypeResolution,
655        right_type: &TypeResolution,
656    ) -> Result<(), Error> {
657        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
658        let left_type_id = self.get_expression_type_id(left_type);
659        let right_type_id = self.get_expression_type_id(right_type);
660
661        // Check if we've already emitted this function.
662        let wrapped = WrappedFunction::BinaryOp {
663            op,
664            left_type_id,
665            right_type_id,
666        };
667        let function_id = match self.wrapped_functions.entry(wrapped) {
668            Entry::Occupied(_) => return Ok(()),
669            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
670        };
671
672        let scalar = return_type.scalar();
673
674        if self.flags.contains(WriterFlags::DEBUG) {
675            let function_name = match op {
676                crate::BinaryOperator::Divide => "naga_div",
677                crate::BinaryOperator::Modulo => "naga_mod",
678                _ => unreachable!(),
679            };
680            self.debugs
681                .push(Instruction::name(function_id, function_name));
682        }
683        let mut function = Function::default();
684
685        let function_type_id = self.get_function_type(LookupFunctionType {
686            parameter_type_ids: vec![left_type_id, right_type_id],
687            return_type_id,
688        });
689        function.signature = Some(Instruction::function(
690            return_type_id,
691            function_id,
692            spirv::FunctionControl::empty(),
693            function_type_id,
694        ));
695
696        let lhs_id = self.id_gen.next();
697        let rhs_id = self.id_gen.next();
698        if self.flags.contains(WriterFlags::DEBUG) {
699            self.debugs.push(Instruction::name(lhs_id, "lhs"));
700            self.debugs.push(Instruction::name(rhs_id, "rhs"));
701        }
702        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
703        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
704        for instruction in [left_par, right_par] {
705            function.parameters.push(FunctionArgument {
706                instruction,
707                handle_id: 0,
708            });
709        }
710
711        let label_id = self.id_gen.next();
712        let mut block = Block::new(label_id);
713
714        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
715        let bool_type_id = self.get_numeric_type_id(bool_type);
716
717        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
718            NumericType::Scalar(_) => const_id,
719            NumericType::Vector { size, .. } => {
720                let constituent_ids = [const_id; crate::VectorSize::MAX];
721                writer.get_constant_composite(
722                    LookupType::Local(LocalType::Numeric(return_type)),
723                    &constituent_ids[..size as usize],
724                )
725            }
726            NumericType::Matrix { .. } => unreachable!(),
727        };
728
729        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
730        let composite_zero_id = maybe_splat_const(self, const_zero_id);
731        let rhs_eq_zero_id = self.id_gen.next();
732        block.body.push(Instruction::binary(
733            spirv::Op::IEqual,
734            bool_type_id,
735            rhs_eq_zero_id,
736            rhs_id,
737            composite_zero_id,
738        ));
739        let divisor_selector_id = match scalar.kind {
740            crate::ScalarKind::Sint => {
741                let (const_min_id, const_neg_one_id) = match scalar.width {
742                    2 => Ok((
743                        self.get_constant_scalar(crate::Literal::I16(i16::MIN)),
744                        self.get_constant_scalar(crate::Literal::I16(-1i16)),
745                    )),
746                    4 => Ok((
747                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
748                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
749                    )),
750                    8 => Ok((
751                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
752                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
753                    )),
754                    _ => Err(Error::Validation("Unexpected scalar width")),
755                }?;
756                let composite_min_id = maybe_splat_const(self, const_min_id);
757                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
758
759                let lhs_eq_int_min_id = self.id_gen.next();
760                block.body.push(Instruction::binary(
761                    spirv::Op::IEqual,
762                    bool_type_id,
763                    lhs_eq_int_min_id,
764                    lhs_id,
765                    composite_min_id,
766                ));
767                let rhs_eq_neg_one_id = self.id_gen.next();
768                block.body.push(Instruction::binary(
769                    spirv::Op::IEqual,
770                    bool_type_id,
771                    rhs_eq_neg_one_id,
772                    rhs_id,
773                    composite_neg_one_id,
774                ));
775                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
776                block.body.push(Instruction::binary(
777                    spirv::Op::LogicalAnd,
778                    bool_type_id,
779                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
780                    lhs_eq_int_min_id,
781                    rhs_eq_neg_one_id,
782                ));
783                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
784                block.body.push(Instruction::binary(
785                    spirv::Op::LogicalOr,
786                    bool_type_id,
787                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
788                    rhs_eq_zero_id,
789                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
790                ));
791                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
792            }
793            crate::ScalarKind::Uint => rhs_eq_zero_id,
794            _ => unreachable!(),
795        };
796
797        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
798        let composite_one_id = maybe_splat_const(self, const_one_id);
799        let divisor_id = self.id_gen.next();
800        block.body.push(Instruction::select(
801            right_type_id,
802            divisor_id,
803            divisor_selector_id,
804            composite_one_id,
805            rhs_id,
806        ));
807        let op = match (op, scalar.kind) {
808            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
809            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
810            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
811            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
812            _ => unreachable!(),
813        };
814        let return_id = self.id_gen.next();
815        block.body.push(Instruction::binary(
816            op,
817            return_type_id,
818            return_id,
819            lhs_id,
820            divisor_id,
821        ));
822
823        function.consume(block, Instruction::return_value(return_id));
824        function.to_words(&mut self.logical_layout.function_definitions);
825        Ok(())
826    }
827
828    /// Writes a wrapper function to convert from a std140 compat type to its
829    /// corresponding regular type.
830    ///
831    /// See [`Self::write_std140_compat_type_declaration`] for more details.
832    fn write_wrapped_convert_from_std140_compat_type(
833        &mut self,
834        ir_module: &crate::Module,
835        r#type: Handle<crate::Type>,
836    ) -> Result<(), Error> {
837        if !self.std140_compat_uniform_types.contains_key(&r#type) {
838            return Ok(());
839        }
840        // Check if we've already emitted this function.
841        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
842        let function_id = match self.wrapped_functions.entry(wrapped) {
843            Entry::Occupied(_) => return Ok(()),
844            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
845        };
846        if self.flags.contains(WriterFlags::DEBUG) {
847            self.debugs.push(Instruction::name(
848                function_id,
849                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
850            ));
851        }
852        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
853        let return_type_id = self.get_handle_type_id(r#type);
854
855        let mut function = Function::default();
856        let function_type_id = self.get_function_type(LookupFunctionType {
857            parameter_type_ids: vec![param_type_id],
858            return_type_id,
859        });
860        function.signature = Some(Instruction::function(
861            return_type_id,
862            function_id,
863            spirv::FunctionControl::empty(),
864            function_type_id,
865        ));
866        let param_id = self.id_gen.next();
867        function.parameters.push(FunctionArgument {
868            instruction: Instruction::function_parameter(param_type_id, param_id),
869            handle_id: 0,
870        });
871
872        let label_id = self.id_gen.next();
873        let mut block = Block::new(label_id);
874
875        let result_id = match ir_module.types[r#type].inner {
876            // Param is struct containing a vector member for each of the
877            // matrix's columns. Extract each column from the struct then
878            // composite into a matrix.
879            crate::TypeInner::Matrix {
880                columns,
881                rows: rows @ crate::VectorSize::Bi,
882                scalar,
883            } => {
884                let column_type_id =
885                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
886
887                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
888                for column in 0..columns as u32 {
889                    let column_id = self.id_gen.next();
890                    block.body.push(Instruction::composite_extract(
891                        column_type_id,
892                        column_id,
893                        param_id,
894                        &[column],
895                    ));
896                    column_ids.push(column_id);
897                }
898                let result_id = self.id_gen.next();
899                block.body.push(Instruction::composite_construct(
900                    return_type_id,
901                    result_id,
902                    &column_ids,
903                ));
904                result_id
905            }
906            // Param is an array where the base type is the std140 compatible
907            // type corresponding to `base`. Iterate through each element and
908            // call its conversion function, then composite into a new array.
909            crate::TypeInner::Array { base, size, .. } => {
910                // Ensure the conversion function for the array's base type is
911                // declared.
912                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;
913
914                let element_type_id = self.get_handle_type_id(base);
915                let std140_info = self.std140_compat_uniform_types.get(&base);
916                let mut element_ids = Vec::new();
917                let size = match size.resolve(ir_module.to_ctx())? {
918                    crate::proc::IndexableLength::Known(size) => size,
919                    crate::proc::IndexableLength::Dynamic => {
920                        return Err(Error::Validation(
921                            "Uniform buffers cannot contain dynamic arrays",
922                        ))
923                    }
924                };
925                for i in 0..size {
926                    let std140_element_id = self.id_gen.next();
927                    let std140_element_type_id =
928                        std140_info.map_or(element_type_id, |info| info.type_id);
929                    block.body.push(Instruction::composite_extract(
930                        std140_element_type_id,
931                        std140_element_id,
932                        param_id,
933                        &[i],
934                    ));
935
936                    // Only call the conversion function if a compatibility mapping actually exists.
937                    let final_element_id = if std140_info.is_some() {
938                        let conversion_fn_id = self.wrapped_functions
939                            [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
940                        let id = self.id_gen.next();
941                        block.body.push(Instruction::function_call(
942                            element_type_id,
943                            id,
944                            conversion_fn_id,
945                            &[std140_element_id],
946                        ));
947                        id
948                    } else {
949                        std140_element_type_id
950                    };
951                    element_ids.push(final_element_id);
952                }
953                let result_id = self.id_gen.next();
954                block.body.push(Instruction::composite_construct(
955                    return_type_id,
956                    result_id,
957                    &element_ids,
958                ));
959                result_id
960            }
961            // Param is a struct where each two-row matrix member has been
962            // decomposed in to separate vector members for each column.
963            // Other members use their std140 compatible type if one exists, or
964            // else their regular type. Iterate through each member, converting
965            // or composing any matrices if required, then finally compose into
966            // the struct.
967            crate::TypeInner::Struct { ref members, .. } => {
968                let mut member_ids = Vec::new();
969                let mut next_index = 0;
970                for member in members {
971                    let member_id = self.id_gen.next();
972                    let member_type_id = self.get_handle_type_id(member.ty);
973                    match ir_module.types[member.ty].inner {
974                        crate::TypeInner::Matrix {
975                            columns,
976                            rows: rows @ crate::VectorSize::Bi,
977                            scalar,
978                        } => {
979                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
980                            let column_type_id = self
981                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
982                            for _ in 0..columns as u32 {
983                                let column_id = self.id_gen.next();
984                                block.body.push(Instruction::composite_extract(
985                                    column_type_id,
986                                    column_id,
987                                    param_id,
988                                    &[next_index],
989                                ));
990                                column_ids.push(column_id);
991                                next_index += 1;
992                            }
993                            block.body.push(Instruction::composite_construct(
994                                member_type_id,
995                                member_id,
996                                &column_ids,
997                            ));
998                        }
999                        _ => {
1000                            // Ensure the conversion function for the member's
1001                            // type is declared.
1002                            self.write_wrapped_convert_from_std140_compat_type(
1003                                ir_module, member.ty,
1004                            )?;
1005                            match self.std140_compat_uniform_types.get(&member.ty) {
1006                                Some(std140_type_info) => {
1007                                    let std140_member_id = self.id_gen.next();
1008                                    block.body.push(Instruction::composite_extract(
1009                                        std140_type_info.type_id,
1010                                        std140_member_id,
1011                                        param_id,
1012                                        &[next_index],
1013                                    ));
1014                                    let function_id = self.wrapped_functions
1015                                        [&WrappedFunction::ConvertFromStd140CompatType {
1016                                            r#type: member.ty,
1017                                        }];
1018                                    block.body.push(Instruction::function_call(
1019                                        member_type_id,
1020                                        member_id,
1021                                        function_id,
1022                                        &[std140_member_id],
1023                                    ));
1024                                    next_index += 1;
1025                                }
1026                                None => {
1027                                    block.body.push(Instruction::composite_extract(
1028                                        member_type_id,
1029                                        member_id,
1030                                        param_id,
1031                                        &[next_index],
1032                                    ));
1033                                    next_index += 1;
1034                                }
1035                            }
1036                        }
1037                    }
1038                    member_ids.push(member_id);
1039                }
1040                let result_id = self.id_gen.next();
1041                block.body.push(Instruction::composite_construct(
1042                    return_type_id,
1043                    result_id,
1044                    &member_ids,
1045                ));
1046                result_id
1047            }
1048            _ => unreachable!(),
1049        };
1050
1051        function.consume(block, Instruction::return_value(result_id));
1052        function.to_words(&mut self.logical_layout.function_definitions);
1053        Ok(())
1054    }
1055
1056    /// Writes a wrapper function to get an `OpTypeVector` column from an
1057    /// `OpTypeMatrix` with a dynamic index.
1058    ///
1059    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
1060    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
1061    /// been declared in SPIR-V using an alternative type where each column is a
1062    /// member of a containing struct. SPIR-V is unable to dynamically access
1063    /// struct members, so instead we load the matrix then call this function to
1064    /// access a column from the loaded value.
1065    ///
1066    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
1067    /// [`Uniform`]: crate::AddressSpace::Uniform
1068    fn write_wrapped_matcx2_get_column(
1069        &mut self,
1070        ir_module: &crate::Module,
1071        r#type: Handle<crate::Type>,
1072    ) -> Result<(), Error> {
1073        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
1074        let function_id = match self.wrapped_functions.entry(wrapped) {
1075            Entry::Occupied(_) => return Ok(()),
1076            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
1077        };
1078        if self.flags.contains(WriterFlags::DEBUG) {
1079            self.debugs.push(Instruction::name(
1080                function_id,
1081                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
1082            ));
1083        }
1084
1085        let crate::TypeInner::Matrix {
1086            columns,
1087            rows: rows @ crate::VectorSize::Bi,
1088            scalar,
1089        } = ir_module.types[r#type].inner
1090        else {
1091            unreachable!();
1092        };
1093
1094        let mut function = Function::default();
1095        let matrix_type_id = self.get_handle_type_id(r#type);
1096        let column_index_type_id = self.get_u32_type_id();
1097        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1098        let matrix_param_id = self.id_gen.next();
1099        let column_index_param_id = self.id_gen.next();
1100        function.parameters.push(FunctionArgument {
1101            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
1102            handle_id: 0,
1103        });
1104        function.parameters.push(FunctionArgument {
1105            instruction: Instruction::function_parameter(
1106                column_index_type_id,
1107                column_index_param_id,
1108            ),
1109            handle_id: 0,
1110        });
1111        let function_type_id = self.get_function_type(LookupFunctionType {
1112            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
1113            return_type_id: column_type_id,
1114        });
1115        function.signature = Some(Instruction::function(
1116            column_type_id,
1117            function_id,
1118            spirv::FunctionControl::empty(),
1119            function_type_id,
1120        ));
1121
1122        let label_id = self.id_gen.next();
1123        let mut block = Block::new(label_id);
1124
1125        // Create a switch case for each column in the matrix, where each case
1126        // extracts its column from the matrix. Finally we use OpPhi to return
1127        // the correct column.
1128        let merge_id = self.id_gen.next();
1129        block.body.push(Instruction::selection_merge(
1130            merge_id,
1131            spirv::SelectionControl::NONE,
1132        ));
1133        let cases = (0..columns as u32)
1134            .map(|i| super::instructions::Case {
1135                value: i,
1136                label_id: self.id_gen.next(),
1137            })
1138            .collect::<ArrayVec<_, 4>>();
1139
1140        // Which label we branch to in the default (column index out-of-bounds)
1141        // case depends on our bounds check policy.
1142        let default_id = match self.bounds_check_policies.index {
1143            // For `Restrict`, treat the same as the final column.
1144            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
1145            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
1146            // will be handled in the `OpPhi` below to produce a zero value.
1147            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
1148            // For `Unchecked` we create a new block containing an
1149            // `OpUnreachable`.
1150            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
1151        };
1152        function.consume(
1153            block,
1154            Instruction::switch(column_index_param_id, default_id, &cases),
1155        );
1156
1157        // Emit a block for each case, and produce a list of variable and parent
1158        // block IDs that will be used in an `OpPhi` below to select the right
1159        // value.
1160        let mut var_parent_pairs = cases
1161            .into_iter()
1162            .map(|case| {
1163                let mut block = Block::new(case.label_id);
1164                let column_id = self.id_gen.next();
1165                block.body.push(Instruction::composite_extract(
1166                    column_type_id,
1167                    column_id,
1168                    matrix_param_id,
1169                    &[case.value],
1170                ));
1171                function.consume(block, Instruction::branch(merge_id));
1172                (column_id, case.label_id)
1173            })
1174            // Need capacity for up to 4 columns plus possibly a default case.
1175            .collect::<ArrayVec<_, 5>>();
1176
1177        // Emit a block or append the variable and parent `OpPhi` pair for the
1178        // column index out-of-bounds case, if required.
1179        match self.bounds_check_policies.index {
1180            // Don't need to do anything for `Restrict` as we have branched from
1181            // the final column case's block.
1182            crate::proc::BoundsCheckPolicy::Restrict => {}
1183            // For `ReadZeroSkipWrite` we have branched directly from the block
1184            // containing the `OpSwitch`. The `OpPhi` should produce a zero
1185            // value.
1186            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1187                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
1188            }
1189            // For `Unchecked` create a new block containing `OpUnreachable`.
1190            // This does not need to be handled by the `OpPhi`.
1191            crate::proc::BoundsCheckPolicy::Unchecked => {
1192                function.consume(
1193                    Block::new(default_id),
1194                    Instruction::new(spirv::Op::Unreachable),
1195                );
1196            }
1197        }
1198
1199        let mut block = Block::new(merge_id);
1200        let result_id = self.id_gen.next();
1201        block.body.push(Instruction::phi(
1202            column_type_id,
1203            result_id,
1204            &var_parent_pairs,
1205        ));
1206
1207        function.consume(block, Instruction::return_value(result_id));
1208        function.to_words(&mut self.logical_layout.function_definitions);
1209        Ok(())
1210    }
1211
1212    fn write_function(
1213        &mut self,
1214        ir_function: &crate::Function,
1215        info: &FunctionInfo,
1216        ir_module: &crate::Module,
1217        mut interface: Option<FunctionInterface>,
1218        debug_info: &Option<DebugInfoInner>,
1219    ) -> Result<Word, Error> {
1220        self.write_wrapped_functions(ir_function, info, ir_module)?;
1221
1222        log::trace!("Generating code for {:?}", ir_function.name);
1223        let mut function = Function::default();
1224
1225        let prelude_id = self.id_gen.next();
1226        let mut prelude = Block::new(prelude_id);
1227        let mut ep_context = EntryPointContext {
1228            argument_ids: Vec::new(),
1229            results: Vec::new(),
1230            task_payload_variable_id: if let Some(ref i) = interface {
1231                i.task_payload.map(|a| self.global_variables[a].var_id)
1232            } else {
1233                None
1234            },
1235            mesh_state: None,
1236        };
1237
1238        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
1239
1240        let mut local_invocation_index_var_id = None;
1241        let mut local_invocation_index_id = None;
1242
1243        for argument in ir_function.arguments.iter() {
1244            let class = spirv::StorageClass::Input;
1245            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
1246            let argument_type_id = if handle_ty {
1247                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
1248            } else {
1249                self.get_handle_type_id(argument.ty)
1250            };
1251
1252            if let Some(ref mut iface) = interface {
1253                let id = if let Some(ref binding) = argument.binding {
1254                    let name = argument.name.as_deref();
1255
1256                    let varying_id = self.write_varying(
1257                        ir_module,
1258                        iface.stage,
1259                        class,
1260                        name,
1261                        argument.ty,
1262                        binding,
1263                    )?;
1264                    iface.varying_ids.push(varying_id);
1265                    let id = self.load_io_with_f16_polyfill(
1266                        &mut prelude.body,
1267                        varying_id,
1268                        argument_type_id,
1269                    );
1270                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
1271                        local_invocation_index_id = Some(id);
1272                        local_invocation_index_var_id = Some(varying_id);
1273                    }
1274
1275                    id
1276                } else if let crate::TypeInner::Struct { ref members, .. } =
1277                    ir_module.types[argument.ty].inner
1278                {
1279                    let struct_id = self.id_gen.next();
1280                    let mut constituent_ids = Vec::with_capacity(members.len());
1281                    for member in members {
1282                        let type_id = self.get_handle_type_id(member.ty);
1283                        let name = member.name.as_deref();
1284                        let binding = member.binding.as_ref().unwrap();
1285                        let varying_id = self.write_varying(
1286                            ir_module,
1287                            iface.stage,
1288                            class,
1289                            name,
1290                            member.ty,
1291                            binding,
1292                        )?;
1293                        iface.varying_ids.push(varying_id);
1294                        let id =
1295                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
1296                        constituent_ids.push(id);
1297                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1298                        {
1299                            local_invocation_index_id = Some(id);
1300                            local_invocation_index_var_id = Some(varying_id);
1301                        }
1302                    }
1303                    prelude.body.push(Instruction::composite_construct(
1304                        argument_type_id,
1305                        struct_id,
1306                        &constituent_ids,
1307                    ));
1308                    struct_id
1309                } else {
1310                    unreachable!("Missing argument binding on an entry point");
1311                };
1312                ep_context.argument_ids.push(id);
1313            } else {
1314                let argument_id = self.id_gen.next();
1315                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
1316                if self.flags.contains(WriterFlags::DEBUG) {
1317                    if let Some(ref name) = argument.name {
1318                        self.debugs.push(Instruction::name(argument_id, name));
1319                    }
1320                }
1321                function.parameters.push(FunctionArgument {
1322                    instruction,
1323                    handle_id: if handle_ty {
1324                        let id = self.id_gen.next();
1325                        prelude.body.push(Instruction::load(
1326                            self.get_handle_type_id(argument.ty),
1327                            id,
1328                            argument_id,
1329                            None,
1330                        ));
1331                        id
1332                    } else {
1333                        0
1334                    },
1335                });
1336                parameter_type_ids.push(argument_type_id);
1337            };
1338        }
1339
1340        let return_type_id = match ir_function.result {
1341            Some(ref result) => {
1342                if let Some(ref mut iface) = interface {
1343                    let mut has_point_size = false;
1344                    let class = spirv::StorageClass::Output;
1345                    if let Some(ref binding) = result.binding {
1346                        has_point_size |=
1347                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1348                        let type_id = self.get_handle_type_id(result.ty);
1349                        let varying_id =
1350                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
1351                                0
1352                            } else {
1353                                let varying_id = self.write_varying(
1354                                    ir_module,
1355                                    iface.stage,
1356                                    class,
1357                                    None,
1358                                    result.ty,
1359                                    binding,
1360                                )?;
1361                                iface.varying_ids.push(varying_id);
1362                                varying_id
1363                            };
1364                        ep_context.results.push(ResultMember {
1365                            id: varying_id,
1366                            type_id,
1367                            built_in: binding.to_built_in(),
1368                        });
1369                    } else if let crate::TypeInner::Struct { ref members, .. } =
1370                        ir_module.types[result.ty].inner
1371                    {
1372                        for member in members {
1373                            let type_id = self.get_handle_type_id(member.ty);
1374                            let name = member.name.as_deref();
1375                            let binding = member.binding.as_ref().unwrap();
1376                            has_point_size |=
1377                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1378                            // This isn't an actual builtin in SPIR-V. It can only appear as the
1379                            // output of a task shader and the output is used when writing the
1380                            // entry point return, in which case the id is ignored anyway.
1381                            let varying_id = if *binding
1382                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
1383                            {
1384                                0
1385                            } else {
1386                                let varying_id = self.write_varying(
1387                                    ir_module,
1388                                    iface.stage,
1389                                    class,
1390                                    name,
1391                                    member.ty,
1392                                    binding,
1393                                )?;
1394                                iface.varying_ids.push(varying_id);
1395                                varying_id
1396                            };
1397                            ep_context.results.push(ResultMember {
1398                                id: varying_id,
1399                                type_id,
1400                                built_in: binding.to_built_in(),
1401                            });
1402                        }
1403                    } else {
1404                        unreachable!("Missing result binding on an entry point");
1405                    }
1406
1407                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
1408                        && iface.stage == crate::ShaderStage::Vertex
1409                        && !has_point_size
1410                    {
1411                        // add point size artificially
1412                        let varying_id = self.id_gen.next();
1413                        let pointer_type_id = self.get_f32_pointer_type_id(class);
1414                        Instruction::variable(pointer_type_id, varying_id, class, None)
1415                            .to_words(&mut self.logical_layout.declarations);
1416                        self.decorate(
1417                            varying_id,
1418                            spirv::Decoration::BuiltIn,
1419                            &[spirv::BuiltIn::PointSize as u32],
1420                        );
1421                        iface.varying_ids.push(varying_id);
1422
1423                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
1424                        prelude
1425                            .body
1426                            .push(Instruction::store(varying_id, default_value_id, None));
1427                    }
1428                    if iface.stage == crate::ShaderStage::Task {
1429                        self.get_vec3u_type_id()
1430                    } else {
1431                        self.void_type
1432                    }
1433                } else {
1434                    self.get_handle_type_id(result.ty)
1435                }
1436            }
1437            None => self.void_type,
1438        };
1439
1440        if let Some(ref mut iface) = interface {
1441            if let Some(task_payload) = iface.task_payload {
1442                iface
1443                    .varying_ids
1444                    .push(self.global_variables[task_payload].var_id);
1445            }
1446            self.write_entry_point_mesh_shader_info(
1447                iface,
1448                local_invocation_index_var_id,
1449                ir_module,
1450                &mut ep_context,
1451            )?;
1452        }
1453
1454        let lookup_function_type = LookupFunctionType {
1455            parameter_type_ids,
1456            return_type_id,
1457        };
1458
1459        let function_id = self.id_gen.next();
1460        if self.flags.contains(WriterFlags::DEBUG) {
1461            if let Some(ref name) = ir_function.name {
1462                self.debugs.push(Instruction::name(function_id, name));
1463            }
1464        }
1465
1466        let function_type = self.get_function_type(lookup_function_type);
1467        function.signature = Some(Instruction::function(
1468            return_type_id,
1469            function_id,
1470            spirv::FunctionControl::empty(),
1471            function_type,
1472        ));
1473
1474        if interface.is_some() {
1475            function.entry_point_context = Some(ep_context);
1476        }
1477
1478        // fill up the `GlobalVariable::access_id`
1479        for gv in self.global_variables.iter_mut() {
1480            gv.reset_for_function();
1481        }
1482        for (handle, var) in ir_module.global_variables.iter() {
1483            if info[handle].is_empty() {
1484                continue;
1485            }
1486
1487            let mut gv = self.global_variables[handle].clone();
1488            if let Some(ref mut iface) = interface {
1489                // Have to include global variables in the interface
1490                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
1491                    iface.varying_ids.push(gv.var_id);
1492                }
1493            }
1494
1495            match ir_module.types[var.ty].inner {
1496                // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
1497                crate::TypeInner::BindingArray { .. } => {
1498                    gv.access_id = gv.var_id;
1499                }
1500                _ => {
1501                    // Handle globals are pre-emitted and should be loaded automatically.
1502                    if var.space == crate::AddressSpace::Handle {
1503                        let var_type_id = self.get_handle_type_id(var.ty);
1504                        let id = self.id_gen.next();
1505                        prelude
1506                            .body
1507                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
1508                        gv.access_id = gv.var_id;
1509                        gv.handle_id = id;
1510                    } else if global_needs_wrapper(ir_module, var) {
1511                        let class = map_storage_class(var.space);
1512                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
1513                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
1514                                self.get_pointer_type_id(std140_type_info.type_id, class)
1515                            }
1516                            _ => self.get_handle_pointer_type_id(var.ty, class),
1517                        };
1518                        let index_id = self.get_index_constant(0);
1519                        let id = self.id_gen.next();
1520                        prelude.body.push(Instruction::access_chain(
1521                            pointer_type_id,
1522                            id,
1523                            gv.var_id,
1524                            &[index_id],
1525                        ));
1526                        gv.access_id = id;
1527                    } else {
1528                        // by default, the variable ID is accessed as is
1529                        gv.access_id = gv.var_id;
1530                    };
1531                }
1532            }
1533
1534            // work around borrow checking in the presence of `self.xxx()` calls
1535            self.global_variables[handle] = gv;
1536        }
1537
1538        // Create a `BlockContext` for generating SPIR-V for the function's
1539        // body.
1540        let mut context = BlockContext {
1541            ir_module,
1542            ir_function,
1543            fun_info: info,
1544            function: &mut function,
1545            // Re-use the cached expression table from prior functions.
1546            cached: core::mem::take(&mut self.saved_cached),
1547
1548            // Steal the Writer's temp list for a bit.
1549            temp_list: core::mem::take(&mut self.temp_list),
1550            force_loop_bounding: self.force_loop_bounding,
1551            writer: self,
1552            expression_constness: super::ExpressionConstnessTracker::from_arena(
1553                &ir_function.expressions,
1554            ),
1555            ray_query_tracker_expr: crate::FastHashMap::default(),
1556        };
1557
1558        // fill up the pre-emitted and const expressions
1559        context.cached.reset(ir_function.expressions.len());
1560        for (handle, expr) in ir_function.expressions.iter() {
1561            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
1562                || context.expression_constness.is_const(handle)
1563            {
1564                context.cache_expression_value(handle, &mut prelude)?;
1565            }
1566        }
1567
1568        for (handle, variable) in ir_function.local_variables.iter() {
1569            let id = context.gen_id();
1570
1571            if context.writer.flags.contains(WriterFlags::DEBUG) {
1572                if let Some(ref name) = variable.name {
1573                    context.writer.debugs.push(Instruction::name(id, name));
1574                }
1575            }
1576
1577            let init_word = variable.init.map(|constant| context.cached[constant]);
1578            let pointer_type_id = context
1579                .writer
1580                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1581            let instruction = Instruction::variable(
1582                pointer_type_id,
1583                id,
1584                spirv::StorageClass::Function,
1585                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1586                    crate::TypeInner::RayQuery { .. } => None,
1587                    _ => {
1588                        let type_id = context.get_handle_type_id(variable.ty);
1589                        Some(context.writer.write_constant_null(type_id))
1590                    }
1591                }),
1592            );
1593
1594            context
1595                .function
1596                .variables
1597                .insert(handle, LocalVariable { id, instruction });
1598
1599            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
1600                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
1601                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
1602                // confusing bugs.
1603                let u32_type_id = context.writer.get_u32_type_id();
1604                let ptr_u32_type_id = context
1605                    .writer
1606                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
1607                let tracker_id = context.gen_id();
1608                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
1609                    crate::back::RayQueryPoint::empty().bits(),
1610                ));
1611                let tracker_instruction = Instruction::variable(
1612                    ptr_u32_type_id,
1613                    tracker_id,
1614                    spirv::StorageClass::Function,
1615                    Some(tracker_init_id),
1616                );
1617
1618                context
1619                    .function
1620                    .ray_query_initialization_tracker_variables
1621                    .insert(
1622                        handle,
1623                        LocalVariable {
1624                            id: tracker_id,
1625                            instruction: tracker_instruction,
1626                        },
1627                    );
1628                let f32_type_id = context.writer.get_f32_type_id();
1629                let ptr_f32_type_id = context
1630                    .writer
1631                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1632                let t_max_tracker_id = context.gen_id();
1633                let t_max_tracker_init_id =
1634                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
1635                let t_max_tracker_instruction = Instruction::variable(
1636                    ptr_f32_type_id,
1637                    t_max_tracker_id,
1638                    spirv::StorageClass::Function,
1639                    Some(t_max_tracker_init_id),
1640                );
1641
1642                context.function.ray_query_t_max_tracker_variables.insert(
1643                    handle,
1644                    LocalVariable {
1645                        id: t_max_tracker_id,
1646                        instruction: t_max_tracker_instruction,
1647                    },
1648                );
1649            }
1650        }
1651
1652        for (handle, expr) in ir_function.expressions.iter() {
1653            match *expr {
1654                crate::Expression::LocalVariable(_) => {
1655                    // Cache the `OpVariable` instruction we generated above as
1656                    // the value of this expression.
1657                    context.cache_expression_value(handle, &mut prelude)?;
1658                }
1659                crate::Expression::Access { base, .. }
1660                | crate::Expression::AccessIndex { base, .. } => {
1661                    // Count references to `base` by `Access` and `AccessIndex`
1662                    // instructions. See `access_uses` for details.
1663                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1664                }
1665                _ => {}
1666            }
1667        }
1668
1669        let next_id = context.gen_id();
1670
1671        context
1672            .function
1673            .consume(prelude, Instruction::branch(next_id));
1674
1675        let workgroup_vars_init_exit_block_id =
1676            match (context.writer.zero_initialize_workgroup_memory, interface) {
1677                (
1678                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1679                    Some(
1680                        ref mut interface @ FunctionInterface {
1681                            stage:
1682                                crate::ShaderStage::Compute
1683                                | crate::ShaderStage::Mesh
1684                                | crate::ShaderStage::Task,
1685                            ..
1686                        },
1687                    ),
1688                ) => context.writer.generate_workgroup_vars_init_block(
1689                    next_id,
1690                    ir_module,
1691                    info,
1692                    local_invocation_index_id,
1693                    interface,
1694                    context.function,
1695                ),
1696                _ => None,
1697            };
1698
1699        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1700            exit_id
1701        } else {
1702            next_id
1703        };
1704
1705        context.write_function_body(main_id, debug_info.as_ref())?;
1706
1707        // Consume the `BlockContext`, ending its borrows and letting the
1708        // `Writer` steal back its cached expression table and temp_list.
1709        let BlockContext {
1710            cached, temp_list, ..
1711        } = context;
1712        self.saved_cached = cached;
1713        self.temp_list = temp_list;
1714
1715        function.to_words(&mut self.logical_layout.function_definitions);
1716
1717        if let Some(EntryPointContext {
1718            mesh_state: Some(ref mesh_state),
1719            ..
1720        }) = function.entry_point_context
1721        {
1722            self.write_mesh_shader_wrapper(mesh_state, function_id)
1723        } else if let Some(EntryPointContext {
1724            task_payload_variable_id: Some(tp),
1725            ..
1726        }) = function.entry_point_context
1727        {
1728            self.write_task_shader_wrapper(tp, function_id)
1729        } else {
1730            Ok(function_id)
1731        }
1732    }
1733
1734    fn write_execution_mode(
1735        &mut self,
1736        function_id: Word,
1737        mode: spirv::ExecutionMode,
1738    ) -> Result<(), Error> {
1739        //self.check(mode.required_capabilities())?;
1740        Instruction::execution_mode(function_id, mode, &[])
1741            .to_words(&mut self.logical_layout.execution_modes);
1742        Ok(())
1743    }
1744
1745    // TODO Move to instructions module
1746    fn write_entry_point(
1747        &mut self,
1748        entry_point: &crate::EntryPoint,
1749        info: &FunctionInfo,
1750        ir_module: &crate::Module,
1751        debug_info: &Option<DebugInfoInner>,
1752    ) -> Result<Instruction, Error> {
1753        let mut interface_ids = Vec::new();
1754
1755        let function_id = self.write_function(
1756            &entry_point.function,
1757            info,
1758            ir_module,
1759            Some(FunctionInterface {
1760                varying_ids: &mut interface_ids,
1761                stage: entry_point.stage,
1762                task_payload: entry_point.task_payload,
1763                mesh_info: entry_point.mesh_info.clone(),
1764                workgroup_size: entry_point.workgroup_size,
1765            }),
1766            debug_info,
1767        )?;
1768
1769        let exec_model = match entry_point.stage {
1770            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1771            crate::ShaderStage::Fragment => {
1772                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1773                match entry_point.early_depth_test {
1774                    Some(crate::EarlyDepthTest::Force) => {
1775                        self.write_execution_mode(
1776                            function_id,
1777                            spirv::ExecutionMode::EarlyFragmentTests,
1778                        )?;
1779                    }
1780                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1781                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1782                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1783                        // This permits early depth tests even if the shader writes to a storage
1784                        // binding
1785                        match conservative {
1786                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1787                                function_id,
1788                                spirv::ExecutionMode::DepthGreater,
1789                            )?,
1790                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1791                                function_id,
1792                                spirv::ExecutionMode::DepthLess,
1793                            )?,
1794                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1795                                function_id,
1796                                spirv::ExecutionMode::DepthUnchanged,
1797                            )?,
1798                        }
1799                    }
1800                    None => {}
1801                }
1802                if let Some(ref result) = entry_point.function.result {
1803                    if contains_builtin(
1804                        result.binding.as_ref(),
1805                        result.ty,
1806                        &ir_module.types,
1807                        crate::BuiltIn::FragDepth,
1808                    ) {
1809                        self.write_execution_mode(
1810                            function_id,
1811                            spirv::ExecutionMode::DepthReplacing,
1812                        )?;
1813                    }
1814                }
1815                spirv::ExecutionModel::Fragment
1816            }
1817            crate::ShaderStage::Compute => {
1818                let execution_mode = spirv::ExecutionMode::LocalSize;
1819                Instruction::execution_mode(
1820                    function_id,
1821                    execution_mode,
1822                    &entry_point.workgroup_size,
1823                )
1824                .to_words(&mut self.logical_layout.execution_modes);
1825                spirv::ExecutionModel::GLCompute
1826            }
1827            crate::ShaderStage::Task => {
1828                let execution_mode = spirv::ExecutionMode::LocalSize;
1829                Instruction::execution_mode(
1830                    function_id,
1831                    execution_mode,
1832                    &entry_point.workgroup_size,
1833                )
1834                .to_words(&mut self.logical_layout.execution_modes);
1835                spirv::ExecutionModel::TaskEXT
1836            }
1837            crate::ShaderStage::Mesh => {
1838                let execution_mode = spirv::ExecutionMode::LocalSize;
1839                Instruction::execution_mode(
1840                    function_id,
1841                    execution_mode,
1842                    &entry_point.workgroup_size,
1843                )
1844                .to_words(&mut self.logical_layout.execution_modes);
1845                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
1846                Instruction::execution_mode(
1847                    function_id,
1848                    match mesh_info.topology {
1849                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
1850                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
1851                        crate::MeshOutputTopology::Triangles => {
1852                            spirv::ExecutionMode::OutputTrianglesEXT
1853                        }
1854                    },
1855                    &[],
1856                )
1857                .to_words(&mut self.logical_layout.execution_modes);
1858                Instruction::execution_mode(
1859                    function_id,
1860                    spirv::ExecutionMode::OutputVertices,
1861                    core::slice::from_ref(&mesh_info.max_vertices),
1862                )
1863                .to_words(&mut self.logical_layout.execution_modes);
1864                Instruction::execution_mode(
1865                    function_id,
1866                    spirv::ExecutionMode::OutputPrimitivesEXT,
1867                    core::slice::from_ref(&mesh_info.max_primitives),
1868                )
1869                .to_words(&mut self.logical_layout.execution_modes);
1870                spirv::ExecutionModel::MeshEXT
1871            }
1872            crate::ShaderStage::RayGeneration => {
1873                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1874                spirv::ExecutionModel::RayGenerationKHR
1875            }
1876            crate::ShaderStage::AnyHit => {
1877                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1878                spirv::ExecutionModel::AnyHitKHR
1879            }
1880            crate::ShaderStage::ClosestHit => {
1881                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1882                spirv::ExecutionModel::ClosestHitKHR
1883            }
1884            crate::ShaderStage::Miss => {
1885                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1886                spirv::ExecutionModel::MissKHR
1887            }
1888        };
1889        //self.check(exec_model.required_capabilities())?;
1890
1891        Ok(Instruction::entry_point(
1892            exec_model,
1893            function_id,
1894            &entry_point.name,
1895            interface_ids.as_slice(),
1896        ))
1897    }
1898
1899    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1900        use crate::ScalarKind as Sk;
1901
1902        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1903        match scalar.kind {
1904            Sk::Sint | Sk::Uint => {
1905                let signedness = if scalar.kind == Sk::Sint {
1906                    super::instructions::Signedness::Signed
1907                } else {
1908                    super::instructions::Signedness::Unsigned
1909                };
1910                let cap = match bits {
1911                    8 => Some(spirv::Capability::Int8),
1912                    16 => Some(spirv::Capability::Int16),
1913                    64 => Some(spirv::Capability::Int64),
1914                    _ => None,
1915                };
1916                if let Some(cap) = cap {
1917                    self.capabilities_used.insert(cap);
1918                }
1919                if bits == 16 {
1920                    self.capabilities_used
1921                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1922                    self.capabilities_used
1923                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1924                    if self.use_storage_input_output_16 {
1925                        self.capabilities_used
1926                            .insert(spirv::Capability::StorageInputOutput16);
1927                    }
1928                }
1929                Instruction::type_int(id, bits, signedness)
1930            }
1931            Sk::Float => {
1932                if bits == 64 {
1933                    self.capabilities_used.insert(spirv::Capability::Float64);
1934                }
1935                if bits == 16 {
1936                    self.capabilities_used.insert(spirv::Capability::Float16);
1937                    self.capabilities_used
1938                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1939                    self.capabilities_used
1940                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1941                    if self.use_storage_input_output_16 {
1942                        self.capabilities_used
1943                            .insert(spirv::Capability::StorageInputOutput16);
1944                    }
1945                }
1946                Instruction::type_float(id, bits)
1947            }
1948            Sk::Bool => Instruction::type_bool(id),
1949            Sk::AbstractInt | Sk::AbstractFloat => {
1950                unreachable!("abstract types should never reach the backend");
1951            }
1952        }
1953    }
1954
1955    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1956        match *inner {
1957            crate::TypeInner::Image {
1958                dim,
1959                arrayed,
1960                class,
1961            } => {
1962                let sampled = match class {
1963                    crate::ImageClass::Sampled { .. } => true,
1964                    crate::ImageClass::Depth { .. } => true,
1965                    crate::ImageClass::Storage { format, .. } => {
1966                        self.request_image_format_capabilities(format.into())?;
1967                        false
1968                    }
1969                    crate::ImageClass::External => unimplemented!(),
1970                };
1971
1972                match dim {
1973                    crate::ImageDimension::D1 => {
1974                        if sampled {
1975                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1976                        } else {
1977                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1978                        }
1979                    }
1980                    crate::ImageDimension::Cube if arrayed => {
1981                        if sampled {
1982                            self.require_any(
1983                                "sampled cube array images",
1984                                &[spirv::Capability::SampledCubeArray],
1985                            )?;
1986                        } else {
1987                            self.require_any(
1988                                "cube array storage images",
1989                                &[spirv::Capability::ImageCubeArray],
1990                            )?;
1991                        }
1992                    }
1993                    _ => {}
1994                }
1995            }
1996            crate::TypeInner::AccelerationStructure { .. } => {
1997                self.require_any(
1998                    "Acceleration Structure",
1999                    // unless we use this conditional, the ray query snapshot
2000                    // tests pick the wrong capability
2001                    &[if self.has_ray_tracing_pipeline {
2002                        spirv::Capability::RayTracingKHR
2003                    } else {
2004                        spirv::Capability::RayQueryKHR
2005                    }],
2006                )?;
2007            }
2008            crate::TypeInner::RayQuery { .. } => {
2009                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
2010            }
2011            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
2012                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
2013            }
2014            crate::TypeInner::Atomic(crate::Scalar {
2015                width: 4,
2016                kind: crate::ScalarKind::Float,
2017            }) => {
2018                self.require_any(
2019                    "32 bit floating-point atomics",
2020                    &[spirv::Capability::AtomicFloat32AddEXT],
2021                )?;
2022                self.use_extension("SPV_EXT_shader_atomic_float_add");
2023            }
2024            // 16 bit floating-point support requires Float16 capability
2025            crate::TypeInner::Matrix {
2026                scalar: crate::Scalar::F16,
2027                ..
2028            }
2029            | crate::TypeInner::Vector {
2030                scalar: crate::Scalar::F16,
2031                ..
2032            }
2033            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
2034                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
2035                self.use_extension("SPV_KHR_16bit_storage");
2036            }
2037            // 16 bit integer support requires Int16 capability
2038            crate::TypeInner::Vector {
2039                scalar:
2040                    crate::Scalar {
2041                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2042                        width: 2,
2043                    },
2044                ..
2045            }
2046            | crate::TypeInner::Scalar(crate::Scalar {
2047                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2048                width: 2,
2049            }) => {
2050                self.require_any("16 bit integer", &[spirv::Capability::Int16])?;
2051                self.use_extension("SPV_KHR_16bit_storage");
2052            }
2053            // Cooperative types and ops
2054            crate::TypeInner::CooperativeMatrix { .. } => {
2055                self.require_any(
2056                    "cooperative matrix",
2057                    &[spirv::Capability::CooperativeMatrixKHR],
2058                )?;
2059                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
2060                self.use_extension("SPV_KHR_cooperative_matrix");
2061                self.use_extension("SPV_KHR_vulkan_memory_model");
2062            }
2063            _ => {}
2064        }
2065        Ok(())
2066    }
2067
2068    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
2069        let instruction = match numeric {
2070            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
2071            NumericType::Vector { size, scalar } => {
2072                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
2073                Instruction::type_vector(id, scalar_id, size)
2074            }
2075            NumericType::Matrix {
2076                columns,
2077                rows,
2078                scalar,
2079            } => {
2080                let column_id =
2081                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2082                Instruction::type_matrix(id, column_id, columns)
2083            }
2084        };
2085
2086        instruction.to_words(&mut self.logical_layout.declarations);
2087    }
2088
2089    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
2090        let instruction = match coop {
2091            CooperativeType::Matrix {
2092                columns,
2093                rows,
2094                scalar,
2095                role,
2096            } => {
2097                let scalar_id =
2098                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
2099                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
2100                let columns_id = self.get_index_constant(columns as u32);
2101                let rows_id = self.get_index_constant(rows as u32);
2102                let role_id =
2103                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
2104                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
2105            }
2106        };
2107
2108        instruction.to_words(&mut self.logical_layout.declarations);
2109    }
2110
2111    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
2112        let instruction = match local_ty {
2113            LocalType::Numeric(numeric) => {
2114                self.write_numeric_type_declaration_local(id, numeric);
2115                return;
2116            }
2117            LocalType::Cooperative(coop) => {
2118                self.write_cooperative_type_declaration_local(id, coop);
2119                return;
2120            }
2121            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
2122            LocalType::Image(image) => {
2123                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
2124                let type_id = self.get_localtype_id(local_type);
2125                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
2126            }
2127            LocalType::Sampler => Instruction::type_sampler(id),
2128            LocalType::SampledImage { image_type_id } => {
2129                Instruction::type_sampled_image(id, image_type_id)
2130            }
2131            LocalType::BindingArray { base, size } => {
2132                let inner_ty = self.get_handle_type_id(base);
2133                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
2134                Instruction::type_array(id, inner_ty, scalar_id)
2135            }
2136            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
2137            LocalType::RayQuery => Instruction::type_ray_query(id),
2138        };
2139
2140        instruction.to_words(&mut self.logical_layout.declarations);
2141    }
2142
    /// Writes the declaration for the IR type `handle` and returns its id.
    ///
    /// Capabilities needed by the type are requested first. If the type has a
    /// `LocalType` representation, the id already recorded for that
    /// `LocalType` is reused, or a new declaration is written via
    /// `write_type_declaration_local`. Otherwise (arrays, binding arrays and
    /// structs) a fresh id is allocated and the declaration, together with
    /// its decorations, is written here.
    ///
    /// In both cases `handle` is then recorded in `lookup_type` as an alias
    /// for the id, and, when `WriterFlags::DEBUG` is set and the type is
    /// named, a name instruction is pushed to `debugs`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `request_type_capabilities`, from resolving an
    /// array size, and from `decorate_struct_member`.
    fn write_type_declaration_arena(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Word, Error> {
        let ty = &module.types[handle];
        // If it's a type that needs SPIR-V capabilities, request them now.
        // This needs to happen regardless of the LocalType lookup succeeding,
        // because some types which map to the same LocalType have different
        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
        self.request_type_capabilities(&ty.inner)?;
        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
            // This type can be represented as a `LocalType`, so check if we've
            // already written an instruction for it. If not, do so now, with
            // `write_type_declaration_local`.
            match self.lookup_type.entry(LookupType::Local(local)) {
                // We already have an id for this `LocalType`.
                Entry::Occupied(e) => *e.get(),

                // It's a type we haven't seen before.
                Entry::Vacant(e) => {
                    let id = self.id_gen.next();
                    // Record the id before writing the declaration.
                    e.insert(id);

                    self.write_type_declaration_local(id, local);

                    id
                }
            }
        } else {
            use spirv::Decoration;

            let id = self.id_gen.next();
            let instruction = match ty.inner {
                crate::TypeInner::Array { base, size, stride } => {
                    self.decorate(id, Decoration::ArrayStride, &[stride]);

                    let type_id = self.get_handle_type_id(base);
                    // A known length yields a sized array whose length is a
                    // constant id; a dynamic length yields a runtime array.
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                // Same as `Array` above, except that no `ArrayStride`
                // decoration is emitted.
                crate::TypeInner::BindingArray { base, size } => {
                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members,
                    span: _,
                } => {
                    // Set if any member is a dynamically sized array; such a
                    // struct additionally gets the `Block` decoration below.
                    let mut has_runtime_array = false;
                    let mut member_ids = Vec::with_capacity(members.len());
                    for (index, member) in members.iter().enumerate() {
                        let member_ty = &module.types[member.ty];
                        match member_ty.inner {
                            crate::TypeInner::Array {
                                base: _,
                                size: crate::ArraySize::Dynamic,
                                stride: _,
                            } => {
                                has_runtime_array = true;
                            }
                            _ => (),
                        }
                        self.decorate_struct_member(id, index, member, &module.types)?;
                        let member_id = self.get_handle_type_id(member.ty);
                        member_ids.push(member_id);
                    }
                    if has_runtime_array {
                        self.decorate(id, Decoration::Block, &[]);
                    }
                    Instruction::type_struct(id, member_ids.as_slice())
                }

                // These all have `LocalType` representations, so they should have
                // been handled by `write_type_declaration_local` above.
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Atomic(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. }
                | crate::TypeInner::CooperativeMatrix { .. }
                | crate::TypeInner::Pointer { .. }
                | crate::TypeInner::ValuePointer { .. }
                | crate::TypeInner::Image { .. }
                | crate::TypeInner::Sampler { .. }
                | crate::TypeInner::AccelerationStructure { .. }
                | crate::TypeInner::RayQuery { .. } => unreachable!(),
            };

            instruction.to_words(&mut self.logical_layout.declarations);
            id
        };

        // Add this handle as a new alias for that type.
        self.lookup_type.insert(LookupType::Handle(handle), id);

        // Only emit a debug name when debug output is enabled and the IR type
        // actually has a name.
        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = ty.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        Ok(id)
    }
2261
    /// Writes a std140 layout compatible type declaration for a type. Returns
    /// the ID of the declared type, or `None` if no declaration is required.
    ///
    /// This should be called for any type for which there exists a
    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
    /// adheres to std140 layout rules it will return without declaring any
    /// types. If the type contains another type which requires a std140
    /// compatible type declaration, it will recursively call itself.
    ///
    /// Declared types are cached per type handle, so calling this again for a
    /// handle that already has a std140 compatible type returns the cached ID.
    ///
    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
    /// declared type will be an `OpTypeStruct` containing an `OpTypeVector`
    /// for each of the matrix's columns.
    ///
    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
    /// type is the matrix's corresponding std140 compatible type.
    ///
    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
    /// require a std140 compatible type declaration, this will declare a new
    /// struct with the following rules:
    /// * Struct or array members will be declared with their std140 compatible
    ///   type declaration, if one is required.
    /// * Two-row matrix members will have each of their columns hoisted
    ///   directly into the struct as 2-component vector members.
    /// * All other members will be declared with their normal type.
    ///
    /// Note that this means the Naga IR index of a struct member may not match
    /// the index in the generated SPIR-V. The mapping can be obtained via
    /// `Std140CompatTypeInfo::member_indices`.
    ///
    /// # Errors
    ///
    /// Propagates errors from resolving an array size, from recursive calls,
    /// and from `decorate_struct_member`.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    /// [`Uniform`]: crate::AddressSpace::Uniform
    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
    /// [`TypeInner::Array`]: crate::TypeInner::Array
    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
    fn write_std140_compat_type_declaration(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Option<Word>, Error> {
        // Reuse the declaration if one was already written for this handle.
        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
            return Ok(Some(std140_type_info.type_id));
        }

        let type_inner = &module.types[handle].inner;
        let std140_type_id = match *type_inner {
            crate::TypeInner::Matrix {
                columns,
                rows: rows @ crate::VectorSize::Bi,
                scalar,
            } => {
                let std140_type_id = self.id_gen.next();
                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
                let column_type_id =
                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
                for column in 0..columns as u32 {
                    member_type_ids.push(column_type_id);
                    // Columns are tightly packed: each one starts right after
                    // the previous one, `rows * scalar.width` bytes further.
                    self.annotations.push(Instruction::member_decorate(
                        std140_type_id,
                        column,
                        spirv::Decoration::Offset,
                        &[column * rows as u32 * scalar.width as u32],
                    ));
                    if self.flags.contains(WriterFlags::DEBUG) {
                        self.debugs.push(Instruction::member_name(
                            std140_type_id,
                            column,
                            &format!("col{column}"),
                        ));
                    }
                }
                Instruction::type_struct(std140_type_id, &member_type_ids)
                    .to_words(&mut self.logical_layout.declarations);
                // No member index mapping is recorded for a matrix.
                self.std140_compat_uniform_types.insert(
                    handle,
                    Std140CompatTypeInfo {
                        type_id: std140_type_id,
                        member_indices: Vec::new(),
                    },
                );
                Some(std140_type_id)
            }
            crate::TypeInner::Array { base, size, stride } => {
                // An array only needs a std140 compatible type if its base
                // type does.
                match self.write_std140_compat_type_declaration(module, base)? {
                    Some(std140_base_type_id) => {
                        let std140_type_id = self.id_gen.next();
                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
                        let instruction = match size.resolve(module.to_ctx())? {
                            crate::proc::IndexableLength::Known(length) => {
                                let length_id = self.get_index_constant(length);
                                Instruction::type_array(
                                    std140_type_id,
                                    std140_base_type_id,
                                    length_id,
                                )
                            }
                            // NOTE(review): presumably dynamically sized
                            // arrays cannot occur in the uniform address
                            // space - confirm against validation.
                            crate::proc::IndexableLength::Dynamic => {
                                unreachable!()
                            }
                        };
                        instruction.to_words(&mut self.logical_layout.declarations);
                        // No member index mapping is recorded for an array.
                        self.std140_compat_uniform_types.insert(
                            handle,
                            Std140CompatTypeInfo {
                                type_id: std140_type_id,
                                member_indices: Vec::new(),
                            },
                        );
                        Some(std140_type_id)
                    }
                    None => None,
                }
            }
            crate::TypeInner::Struct { ref members, .. } => {
                // First pass: decide whether the struct needs a std140
                // compatible type at all, declaring the std140 types of any
                // array members along the way.
                let mut needs_std140_type = false;
                for member in members {
                    match module.types[member.ty].inner {
                        // We don't need to write a std140 type for the matrix itself as
                        // it will be decomposed into the parent struct. As a result, the
                        // struct does need a std140 type, however.
                        crate::TypeInner::Matrix {
                            rows: crate::VectorSize::Bi,
                            ..
                        } => needs_std140_type = true,
                        // If an array member needs a std140 type, because it is an array
                        // (of an array, etc) of `matCx2`s, then the struct also needs
                        // a std140 type which uses the std140 type for this member.
                        crate::TypeInner::Array { .. }
                            if self
                                .write_std140_compat_type_declaration(module, member.ty)?
                                .is_some() =>
                        {
                            needs_std140_type = true;
                        }
                        _ => {}
                    }
                }

                if needs_std140_type {
                    // Second pass: build the member list of the new struct.
                    let std140_type_id = self.id_gen.next();
                    let mut member_ids = Vec::new();
                    // For each Naga IR member, the index of its (first)
                    // member in the generated struct.
                    let mut member_indices = Vec::new();
                    let mut next_index = 0;

                    for member in members {
                        member_indices.push(next_index);
                        match module.types[member.ty].inner {
                            crate::TypeInner::Matrix {
                                columns,
                                rows: rows @ crate::VectorSize::Bi,
                                scalar,
                            } => {
                                let vector_type_id =
                                    self.get_numeric_type_id(NumericType::Vector {
                                        size: rows,
                                        scalar,
                                    });
                                // Hoist every column into the struct as its
                                // own vector member, offset from the start of
                                // the original matrix member.
                                for column in 0..columns as u32 {
                                    self.annotations.push(Instruction::member_decorate(
                                        std140_type_id,
                                        next_index,
                                        spirv::Decoration::Offset,
                                        &[member.offset
                                            + column * rows as u32 * scalar.width as u32],
                                    ));
                                    if self.flags.contains(WriterFlags::DEBUG) {
                                        if let Some(ref name) = member.name {
                                            self.debugs.push(Instruction::member_name(
                                                std140_type_id,
                                                next_index,
                                                &format!("{name}_col{column}"),
                                            ));
                                        }
                                    }
                                    member_ids.push(vector_type_id);
                                    next_index += 1;
                                }
                            }
                            _ => {
                                // NOTE(review): this only looks up an already
                                // declared std140 type for the member; struct
                                // typed members are not recursed into here, so
                                // presumably their std140 type must have been
                                // declared by an earlier call - confirm
                                // against the caller.
                                let member_id =
                                    match self.std140_compat_uniform_types.get(&member.ty) {
                                        // The member has a std140 compatible
                                        // type: use it, and emit only the
                                        // offset decoration and debug name.
                                        Some(std140_member_type_info) => {
                                            self.annotations.push(Instruction::member_decorate(
                                                std140_type_id,
                                                next_index,
                                                spirv::Decoration::Offset,
                                                &[member.offset],
                                            ));
                                            if self.flags.contains(WriterFlags::DEBUG) {
                                                if let Some(ref name) = member.name {
                                                    self.debugs.push(Instruction::member_name(
                                                        std140_type_id,
                                                        next_index,
                                                        name,
                                                    ));
                                                }
                                            }
                                            std140_member_type_info.type_id
                                        }
                                        // Otherwise the member keeps its
                                        // normal type and decorations.
                                        None => {
                                            self.decorate_struct_member(
                                                std140_type_id,
                                                next_index as usize,
                                                member,
                                                &module.types,
                                            )?;
                                            self.get_handle_type_id(member.ty)
                                        }
                                    };
                                member_ids.push(member_id);
                                next_index += 1;
                            }
                        }
                    }

                    Instruction::type_struct(std140_type_id, &member_ids)
                        .to_words(&mut self.logical_layout.declarations);
                    self.std140_compat_uniform_types.insert(
                        handle,
                        Std140CompatTypeInfo {
                            type_id: std140_type_id,
                            member_indices,
                        },
                    );
                    Some(std140_type_id)
                } else {
                    None
                }
            }
            // Every other type is used as-is.
            _ => None,
        };

        // Give the newly declared type a debug name derived from the original
        // type, prefixed with `std140_`.
        if let Some(std140_type_id) = std140_type_id {
            if self.flags.contains(WriterFlags::DEBUG) {
                let name = format!("std140_{:?}", handle.for_debug(&module.types));
                self.debugs.push(Instruction::name(std140_type_id, &name));
            }
        }
        Ok(std140_type_id)
    }
2502
2503    fn request_image_format_capabilities(
2504        &mut self,
2505        format: spirv::ImageFormat,
2506    ) -> Result<(), Error> {
2507        use spirv::ImageFormat as If;
2508        match format {
2509            If::Rg32f
2510            | If::Rg16f
2511            | If::R11fG11fB10f
2512            | If::R16f
2513            | If::Rgba16
2514            | If::Rgb10A2
2515            | If::Rg16
2516            | If::Rg8
2517            | If::R16
2518            | If::R8
2519            | If::Rgba16Snorm
2520            | If::Rg16Snorm
2521            | If::Rg8Snorm
2522            | If::R16Snorm
2523            | If::R8Snorm
2524            | If::Rg32i
2525            | If::Rg16i
2526            | If::Rg8i
2527            | If::R16i
2528            | If::R8i
2529            | If::Rgb10a2ui
2530            | If::Rg32ui
2531            | If::Rg16ui
2532            | If::Rg8ui
2533            | If::R16ui
2534            | If::R8ui => self.require_any(
2535                "storage image format",
2536                &[spirv::Capability::StorageImageExtendedFormats],
2537            ),
2538            If::R64ui | If::R64i => {
2539                self.use_extension("SPV_EXT_shader_image_int64");
2540                self.require_any(
2541                    "64-bit integer storage image format",
2542                    &[spirv::Capability::Int64ImageEXT],
2543                )
2544            }
2545            If::Unknown
2546            | If::Rgba32f
2547            | If::Rgba16f
2548            | If::R32f
2549            | If::Rgba8
2550            | If::Rgba8Snorm
2551            | If::Rgba32i
2552            | If::Rgba16i
2553            | If::Rgba8i
2554            | If::R32i
2555            | If::Rgba32ui
2556            | If::Rgba16ui
2557            | If::Rgba8ui
2558            | If::R32ui => Ok(()),
2559        }
2560    }
2561
2562    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
2563        self.get_constant_scalar(crate::Literal::U32(index))
2564    }
2565
2566    pub(super) fn get_constant_scalar_with(
2567        &mut self,
2568        value: u8,
2569        scalar: crate::Scalar,
2570    ) -> Result<Word, Error> {
2571        Ok(
2572            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
2573                Error::Validation("Unexpected kind and/or width for Literal"),
2574            )?),
2575        )
2576    }
2577
2578    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
2579        let scalar = CachedConstant::Literal(value.into());
2580        if let Some(&id) = self.cached_constants.get(&scalar) {
2581            return id;
2582        }
2583        let id = self.id_gen.next();
2584        self.write_constant_scalar(id, &value, None);
2585        self.cached_constants.insert(scalar, id);
2586        id
2587    }
2588
2589    fn write_constant_scalar(
2590        &mut self,
2591        id: Word,
2592        value: &crate::Literal,
2593        debug_name: Option<&String>,
2594    ) {
2595        if self.flags.contains(WriterFlags::DEBUG) {
2596            if let Some(name) = debug_name {
2597                self.debugs.push(Instruction::name(id, name));
2598            }
2599        }
2600        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
2601        let instruction = match *value {
2602            crate::Literal::F64(value) => {
2603                let bits = value.to_bits();
2604                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
2605            }
2606            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
2607            crate::Literal::F16(value) => {
2608                let low = value.to_bits();
2609                Instruction::constant_16bit(type_id, id, low as u32)
2610            }
2611            crate::Literal::U16(value) => Instruction::constant_16bit(type_id, id, value as u32),
2612            crate::Literal::I16(value) => {
2613                // Sign-extend into the 32-bit word so that `spirv-as` can
2614                // round-trip the disassembly (it expects signed values for
2615                // signed types).
2616                Instruction::constant_16bit(type_id, id, value as i32 as u32)
2617            }
2618            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
2619            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
2620            crate::Literal::U64(value) => {
2621                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2622            }
2623            crate::Literal::I64(value) => {
2624                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2625            }
2626            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
2627            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
2628            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2629                unreachable!("Abstract types should not appear in IR presented to backends");
2630            }
2631        };
2632
2633        instruction.to_words(&mut self.logical_layout.declarations);
2634    }
2635
2636    pub(super) fn get_constant_composite(
2637        &mut self,
2638        ty: LookupType,
2639        constituent_ids: &[Word],
2640    ) -> Word {
2641        let composite = CachedConstant::Composite {
2642            ty,
2643            constituent_ids: constituent_ids.to_vec(),
2644        };
2645        if let Some(&id) = self.cached_constants.get(&composite) {
2646            return id;
2647        }
2648        let id = self.id_gen.next();
2649        self.write_constant_composite(id, ty, constituent_ids, None);
2650        self.cached_constants.insert(composite, id);
2651        id
2652    }
2653
2654    fn write_constant_composite(
2655        &mut self,
2656        id: Word,
2657        ty: LookupType,
2658        constituent_ids: &[Word],
2659        debug_name: Option<&String>,
2660    ) {
2661        if self.flags.contains(WriterFlags::DEBUG) {
2662            if let Some(name) = debug_name {
2663                self.debugs.push(Instruction::name(id, name));
2664            }
2665        }
2666        let type_id = self.get_type_id(ty);
2667        Instruction::constant_composite(type_id, id, constituent_ids)
2668            .to_words(&mut self.logical_layout.declarations);
2669    }
2670
2671    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
2672        let null = CachedConstant::ZeroValue(type_id);
2673        if let Some(&id) = self.cached_constants.get(&null) {
2674            return id;
2675        }
2676        let id = self.write_constant_null(type_id);
2677        self.cached_constants.insert(null, id);
2678        id
2679    }
2680
2681    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
2682        let null_id = self.id_gen.next();
2683        Instruction::constant_null(type_id, null_id)
2684            .to_words(&mut self.logical_layout.declarations);
2685        null_id
2686    }
2687
2688    fn write_constant_expr(
2689        &mut self,
2690        handle: Handle<crate::Expression>,
2691        ir_module: &crate::Module,
2692        mod_info: &ModuleInfo,
2693    ) -> Result<Word, Error> {
2694        let id = match ir_module.global_expressions[handle] {
2695            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
2696            crate::Expression::Constant(constant) => {
2697                let constant = &ir_module.constants[constant];
2698                self.constant_ids[constant.init]
2699            }
2700            crate::Expression::ZeroValue(ty) => {
2701                let type_id = self.get_handle_type_id(ty);
2702                self.get_constant_null(type_id)
2703            }
2704            crate::Expression::Compose { ty, ref components } => {
2705                let component_ids: Vec<_> = crate::proc::flatten_compose(
2706                    ty,
2707                    components,
2708                    &ir_module.global_expressions,
2709                    &ir_module.types,
2710                )
2711                .map(|component| self.constant_ids[component])
2712                .collect();
2713                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
2714            }
2715            crate::Expression::Splat { size, value } => {
2716                let value_id = self.constant_ids[value];
2717                let component_ids = &[value_id; 4][..size as usize];
2718
2719                let ty = self.get_expression_lookup_type(&mod_info[handle]);
2720
2721                self.get_constant_composite(ty, component_ids)
2722            }
2723            _ => {
2724                return Err(Error::Override);
2725            }
2726        };
2727
2728        self.constant_ids[handle] = id;
2729
2730        Ok(id)
2731    }
2732
2733    pub(super) fn write_control_barrier(
2734        &mut self,
2735        flags: crate::Barrier,
2736        body: &mut Vec<Instruction>,
2737    ) {
2738        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
2739            spirv::Scope::Device
2740        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2741            spirv::Scope::Subgroup
2742        } else {
2743            spirv::Scope::Workgroup
2744        };
2745        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2746        semantics.set(
2747            spirv::MemorySemantics::UNIFORM_MEMORY,
2748            flags.contains(crate::Barrier::STORAGE),
2749        );
2750        semantics.set(
2751            spirv::MemorySemantics::WORKGROUP_MEMORY,
2752            flags.contains(crate::Barrier::WORK_GROUP),
2753        );
2754        semantics.set(
2755            spirv::MemorySemantics::SUBGROUP_MEMORY,
2756            flags.contains(crate::Barrier::SUB_GROUP),
2757        );
2758        semantics.set(
2759            spirv::MemorySemantics::IMAGE_MEMORY,
2760            flags.contains(crate::Barrier::TEXTURE),
2761        );
2762        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
2763            self.get_index_constant(spirv::Scope::Subgroup as u32)
2764        } else {
2765            self.get_index_constant(spirv::Scope::Workgroup as u32)
2766        };
2767        let mem_scope_id = self.get_index_constant(memory_scope as u32);
2768        let semantics_id = self.get_index_constant(semantics.bits());
2769        body.push(Instruction::control_barrier(
2770            exec_scope_id,
2771            mem_scope_id,
2772            semantics_id,
2773        ));
2774    }
2775
2776    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
2777        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2778        semantics.set(
2779            spirv::MemorySemantics::UNIFORM_MEMORY,
2780            flags.contains(crate::Barrier::STORAGE),
2781        );
2782        semantics.set(
2783            spirv::MemorySemantics::WORKGROUP_MEMORY,
2784            flags.contains(crate::Barrier::WORK_GROUP),
2785        );
2786        semantics.set(
2787            spirv::MemorySemantics::SUBGROUP_MEMORY,
2788            flags.contains(crate::Barrier::SUB_GROUP),
2789        );
2790        semantics.set(
2791            spirv::MemorySemantics::IMAGE_MEMORY,
2792            flags.contains(crate::Barrier::TEXTURE),
2793        );
2794        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
2795            self.get_index_constant(spirv::Scope::Device as u32)
2796        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2797            self.get_index_constant(spirv::Scope::Subgroup as u32)
2798        } else {
2799            self.get_index_constant(spirv::Scope::Workgroup as u32)
2800        };
2801        let semantics_id = self.get_index_constant(semantics.bits());
2802        block
2803            .body
2804            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
2805    }
2806
2807    fn generate_workgroup_vars_init_block(
2808        &mut self,
2809        entry_id: Word,
2810        ir_module: &crate::Module,
2811        info: &FunctionInfo,
2812        local_invocation_index: Option<Word>,
2813        interface: &mut FunctionInterface,
2814        function: &mut Function,
2815    ) -> Option<Word> {
2816        let body = ir_module
2817            .global_variables
2818            .iter()
2819            .filter(|&(handle, var)| {
2820                let task_exception = (var.space == crate::AddressSpace::TaskPayload)
2821                    && interface.stage == crate::ShaderStage::Task;
2822                !info[handle].is_empty()
2823                    && (var.space == crate::AddressSpace::WorkGroup || task_exception)
2824            })
2825            .map(|(handle, var)| {
2826                // It's safe to use `var_id` here, not `access_id`, because only
2827                // variables in the `Uniform` and `StorageBuffer` address spaces
2828                // get wrapped, and we're initializing `WorkGroup` variables.
2829                let var_id = self.global_variables[handle].var_id;
2830                let var_type_id = self.get_handle_type_id(var.ty);
2831                let init_word = self.get_constant_null(var_type_id);
2832                Instruction::store(var_id, init_word, None)
2833            })
2834            .collect::<Vec<_>>();
2835
2836        if body.is_empty() {
2837            return None;
2838        }
2839
2840        let mut pre_if_block = Block::new(entry_id);
2841
2842        let local_invocation_index = if let Some(local_invocation_index) = local_invocation_index {
2843            local_invocation_index
2844        } else {
2845            let varying_id = self.id_gen.next();
2846            let class = spirv::StorageClass::Input;
2847            let u32_ty_id = self.get_u32_type_id();
2848            let pointer_type_id = self.get_pointer_type_id(u32_ty_id, class);
2849
2850            Instruction::variable(pointer_type_id, varying_id, class, None)
2851                .to_words(&mut self.logical_layout.declarations);
2852
2853            self.decorate(
2854                varying_id,
2855                spirv::Decoration::BuiltIn,
2856                &[spirv::BuiltIn::LocalInvocationIndex as u32],
2857            );
2858
2859            interface.varying_ids.push(varying_id);
2860            let id = self.id_gen.next();
2861            pre_if_block
2862                .body
2863                .push(Instruction::load(u32_ty_id, id, varying_id, None));
2864
2865            id
2866        };
2867
2868        let zero_id = self.get_constant_scalar(crate::Literal::U32(0));
2869
2870        let eq_id = self.id_gen.next();
2871        pre_if_block.body.push(Instruction::binary(
2872            spirv::Op::IEqual,
2873            self.get_bool_type_id(),
2874            eq_id,
2875            local_invocation_index,
2876            zero_id,
2877        ));
2878
2879        let merge_id = self.id_gen.next();
2880        pre_if_block.body.push(Instruction::selection_merge(
2881            merge_id,
2882            spirv::SelectionControl::NONE,
2883        ));
2884
2885        let accept_id = self.id_gen.next();
2886        function.consume(
2887            pre_if_block,
2888            Instruction::branch_conditional(eq_id, accept_id, merge_id),
2889        );
2890
2891        let accept_block = Block {
2892            label_id: accept_id,
2893            body,
2894        };
2895        function.consume(accept_block, Instruction::branch(merge_id));
2896
2897        let mut post_if_block = Block::new(merge_id);
2898
2899        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);
2900
2901        let next_id = self.id_gen.next();
2902        function.consume(post_if_block, Instruction::branch(next_id));
2903        Some(next_id)
2904    }
2905
2906    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
2907    ///
2908    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
2909    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
2910    /// interface is represented by global variables in the `Input` and `Output`
2911    /// storage classes, with decorations indicating which builtin or location
2912    /// each variable corresponds to.
2913    ///
2914    /// This function emits a single global `OpVariable` for a single value from
2915    /// the interface, and adds appropriate decorations to indicate which
2916    /// builtin or location it represents, how it should be interpolated, and so
2917    /// on. The `class` argument gives the variable's SPIR-V storage class,
2918    /// which should be either [`Input`] or [`Output`].
2919    ///
2920    /// [`Binding`]: crate::Binding
2921    /// [`Function`]: crate::Function
2922    /// [`EntryPoint`]: crate::EntryPoint
2923    /// [`Input`]: spirv::StorageClass::Input
2924    /// [`Output`]: spirv::StorageClass::Output
2925    fn write_varying(
2926        &mut self,
2927        ir_module: &crate::Module,
2928        stage: crate::ShaderStage,
2929        class: spirv::StorageClass,
2930        debug_name: Option<&str>,
2931        ty: Handle<crate::Type>,
2932        binding: &crate::Binding,
2933    ) -> Result<Word, Error> {
2934        let id = self.id_gen.next();
2935        let ty_inner = &ir_module.types[ty].inner;
2936        let needs_polyfill = self.needs_f16_polyfill(ty_inner);
2937
2938        let pointer_type_id = if needs_polyfill {
2939            let f32_value_local =
2940                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
2941                    .expect("needs_polyfill returned true but create_polyfill_type returned None");
2942
2943            let f32_type_id = self.get_localtype_id(f32_value_local);
2944            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
2945            self.io_f16_polyfills.register_io_var(id, f32_type_id);
2946
2947            ptr_id
2948        } else {
2949            self.get_handle_pointer_type_id(ty, class)
2950        };
2951
2952        Instruction::variable(pointer_type_id, id, class, None)
2953            .to_words(&mut self.logical_layout.declarations);
2954
2955        if self
2956            .flags
2957            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
2958        {
2959            if let Some(name) = debug_name {
2960                self.debugs.push(Instruction::name(id, name));
2961            }
2962        }
2963
2964        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
2965        self.write_binding(id, binding);
2966
2967        Ok(id)
2968    }
2969
2970    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
2971        match binding {
2972            BindingDecorations::None => (),
2973            BindingDecorations::BuiltIn(bi, others) => {
2974                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
2975                for other in others {
2976                    self.decorate(id, other, &[]);
2977                }
2978            }
2979            BindingDecorations::Location {
2980                location,
2981                others,
2982                blend_src,
2983            } => {
2984                self.decorate(id, spirv::Decoration::Location, &[location]);
2985                for other in others {
2986                    self.decorate(id, other, &[]);
2987                }
2988                if let Some(blend_src) = blend_src {
2989                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
2990                }
2991            }
2992        }
2993    }
2994
2995    pub fn write_binding_struct_member(
2996        &mut self,
2997        struct_id: Word,
2998        member_idx: Word,
2999        binding_info: BindingDecorations,
3000    ) {
3001        match binding_info {
3002            BindingDecorations::None => (),
3003            BindingDecorations::BuiltIn(bi, others) => {
3004                self.annotations.push(Instruction::member_decorate(
3005                    struct_id,
3006                    member_idx,
3007                    spirv::Decoration::BuiltIn,
3008                    &[bi as Word],
3009                ));
3010                for other in others {
3011                    self.annotations.push(Instruction::member_decorate(
3012                        struct_id,
3013                        member_idx,
3014                        other,
3015                        &[],
3016                    ));
3017                }
3018            }
3019            BindingDecorations::Location {
3020                location,
3021                others,
3022                blend_src,
3023            } => {
3024                self.annotations.push(Instruction::member_decorate(
3025                    struct_id,
3026                    member_idx,
3027                    spirv::Decoration::Location,
3028                    &[location],
3029                ));
3030                for other in others {
3031                    self.annotations.push(Instruction::member_decorate(
3032                        struct_id,
3033                        member_idx,
3034                        other,
3035                        &[],
3036                    ));
3037                }
3038                if let Some(blend_src) = blend_src {
3039                    self.annotations.push(Instruction::member_decorate(
3040                        struct_id,
3041                        member_idx,
3042                        spirv::Decoration::Index,
3043                        &[blend_src],
3044                    ));
3045                }
3046            }
3047        }
3048    }
3049
3050    pub fn map_binding(
3051        &mut self,
3052        ir_module: &crate::Module,
3053        stage: crate::ShaderStage,
3054        class: spirv::StorageClass,
3055        ty: Handle<crate::Type>,
3056        binding: &crate::Binding,
3057    ) -> Result<BindingDecorations, Error> {
3058        use spirv::BuiltIn;
3059        use spirv::Decoration;
3060        match *binding {
3061            crate::Binding::Location {
3062                location,
3063                interpolation,
3064                sampling,
3065                blend_src,
3066                per_primitive,
3067            } => {
3068                let mut others = ArrayVec::new();
3069
3070                let no_decorations =
3071                    // VUID-StandaloneSpirv-Flat-06202
3072                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3073                    // > must not be used on variables with the Input storage class in a vertex shader
3074                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
3075                    // VUID-StandaloneSpirv-Flat-06201
3076                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3077                    // > must not be used on variables with the Output storage class in a fragment shader
3078                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
3079
3080                if !no_decorations {
3081                    match interpolation {
3082                        // Perspective-correct interpolation is the default in SPIR-V.
3083                        None | Some(crate::Interpolation::Perspective) => (),
3084                        Some(crate::Interpolation::Flat) => {
3085                            others.push(Decoration::Flat);
3086                        }
3087                        Some(crate::Interpolation::Linear) => {
3088                            others.push(Decoration::NoPerspective);
3089                        }
3090                        Some(crate::Interpolation::PerVertex) => {
3091                            others.push(Decoration::PerVertexKHR);
3092                            self.require_any(
3093                                "`per_vertex` interpolation",
3094                                &[spirv::Capability::FragmentBarycentricKHR],
3095                            )?;
3096                            self.use_extension("SPV_KHR_fragment_shader_barycentric");
3097                        }
3098                    }
3099                    match sampling {
3100                        // Center sampling is the default in SPIR-V.
3101                        None
3102                        | Some(
3103                            crate::Sampling::Center
3104                            | crate::Sampling::First
3105                            | crate::Sampling::Either,
3106                        ) => (),
3107                        Some(crate::Sampling::Centroid) => {
3108                            others.push(Decoration::Centroid);
3109                        }
3110                        Some(crate::Sampling::Sample) => {
3111                            self.require_any(
3112                                "per-sample interpolation",
3113                                &[spirv::Capability::SampleRateShading],
3114                            )?;
3115                            others.push(Decoration::Sample);
3116                        }
3117                    }
3118                }
3119                if per_primitive && stage == crate::ShaderStage::Fragment {
3120                    others.push(Decoration::PerPrimitiveEXT);
3121                }
3122                Ok(BindingDecorations::Location {
3123                    location,
3124                    others,
3125                    blend_src,
3126                })
3127            }
3128            crate::Binding::BuiltIn(built_in) => {
3129                use crate::BuiltIn as Bi;
3130                let mut others = ArrayVec::new();
3131
3132                let built_in = match built_in {
3133                    Bi::Position { invariant } => {
3134                        if invariant {
3135                            others.push(Decoration::Invariant);
3136                        }
3137
3138                        if class == spirv::StorageClass::Output {
3139                            BuiltIn::Position
3140                        } else {
3141                            BuiltIn::FragCoord
3142                        }
3143                    }
3144                    Bi::ViewIndex => {
3145                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
3146                        BuiltIn::ViewIndex
3147                    }
3148                    // vertex
3149                    Bi::BaseInstance => BuiltIn::BaseInstance,
3150                    Bi::BaseVertex => BuiltIn::BaseVertex,
3151                    Bi::ClipDistances => {
3152                        self.require_any(
3153                            "`clip_distances` built-in",
3154                            &[spirv::Capability::ClipDistance],
3155                        )?;
3156                        BuiltIn::ClipDistance
3157                    }
3158                    Bi::CullDistance => {
3159                        self.require_any(
3160                            "`cull_distance` built-in",
3161                            &[spirv::Capability::CullDistance],
3162                        )?;
3163                        BuiltIn::CullDistance
3164                    }
3165                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
3166                    Bi::PointSize => BuiltIn::PointSize,
3167                    Bi::VertexIndex => BuiltIn::VertexIndex,
3168                    Bi::DrawIndex => {
3169                        self.use_extension("SPV_KHR_shader_draw_parameters");
3170                        self.require_any(
3171                            "`draw_index built-in",
3172                            &[spirv::Capability::DrawParameters],
3173                        )?;
3174                        BuiltIn::DrawIndex
3175                    }
3176                    // fragment
3177                    Bi::FragDepth => BuiltIn::FragDepth,
3178                    Bi::PointCoord => BuiltIn::PointCoord,
3179                    Bi::FrontFacing => BuiltIn::FrontFacing,
3180                    Bi::PrimitiveIndex => {
3181                        // Geometry shader capability is required for primitive index
3182                        self.require_any(
3183                            "`primitive_index` built-in",
3184                            &[spirv::Capability::Geometry],
3185                        )?;
3186                        if stage == crate::ShaderStage::Mesh {
3187                            others.push(Decoration::PerPrimitiveEXT);
3188                        }
3189                        BuiltIn::PrimitiveId
3190                    }
3191                    Bi::Barycentric { perspective } => {
3192                        self.require_any(
3193                            "`barycentric` built-in",
3194                            &[spirv::Capability::FragmentBarycentricKHR],
3195                        )?;
3196                        self.use_extension("SPV_KHR_fragment_shader_barycentric");
3197                        if perspective {
3198                            BuiltIn::BaryCoordKHR
3199                        } else {
3200                            BuiltIn::BaryCoordNoPerspKHR
3201                        }
3202                    }
3203                    Bi::SampleIndex => {
3204                        self.require_any(
3205                            "`sample_index` built-in",
3206                            &[spirv::Capability::SampleRateShading],
3207                        )?;
3208
3209                        BuiltIn::SampleId
3210                    }
3211                    Bi::SampleMask => BuiltIn::SampleMask,
3212                    // compute
3213                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
3214                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
3215                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
3216                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
3217                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
3218                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
3219                    // Subgroup
3220                    Bi::NumSubgroups => {
3221                        self.require_any(
3222                            "`num_subgroups` built-in",
3223                            &[spirv::Capability::GroupNonUniform],
3224                        )?;
3225                        BuiltIn::NumSubgroups
3226                    }
3227                    Bi::SubgroupId => {
3228                        self.require_any(
3229                            "`subgroup_id` built-in",
3230                            &[spirv::Capability::GroupNonUniform],
3231                        )?;
3232                        BuiltIn::SubgroupId
3233                    }
3234                    Bi::SubgroupSize => {
3235                        self.require_any(
3236                            "`subgroup_size` built-in",
3237                            &[
3238                                spirv::Capability::GroupNonUniform,
3239                                spirv::Capability::SubgroupBallotKHR,
3240                            ],
3241                        )?;
3242                        BuiltIn::SubgroupSize
3243                    }
3244                    Bi::SubgroupInvocationId => {
3245                        self.require_any(
3246                            "`subgroup_invocation_id` built-in",
3247                            &[
3248                                spirv::Capability::GroupNonUniform,
3249                                spirv::Capability::SubgroupBallotKHR,
3250                            ],
3251                        )?;
3252                        BuiltIn::SubgroupLocalInvocationId
3253                    }
3254                    Bi::CullPrimitive => {
3255                        others.push(Decoration::PerPrimitiveEXT);
3256                        BuiltIn::CullPrimitiveEXT
3257                    }
3258                    Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT,
3259                    Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT,
3260                    Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT,
3261                    // No decoration, this EmitMeshTasksEXT is called at function return
3262                    Bi::MeshTaskSize => return Ok(BindingDecorations::None),
3263                    // These aren't normal builtins and don't occur in function output
3264                    Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => {
3265                        unreachable!()
3266                    }
3267                    // ray tracing pipeline
3268                    Bi::RayInvocationId => BuiltIn::LaunchIdKHR,
3269                    Bi::NumRayInvocations => BuiltIn::LaunchSizeKHR,
3270                    Bi::InstanceCustomData => BuiltIn::InstanceCustomIndexKHR,
3271                    Bi::GeometryIndex => BuiltIn::RayGeometryIndexKHR,
3272                    Bi::WorldRayOrigin => BuiltIn::WorldRayOriginKHR,
3273                    Bi::WorldRayDirection => BuiltIn::WorldRayDirectionKHR,
3274                    Bi::ObjectRayOrigin => BuiltIn::ObjectRayOriginKHR,
3275                    Bi::ObjectRayDirection => BuiltIn::ObjectRayDirectionKHR,
3276                    Bi::RayTmin => BuiltIn::RayTminKHR,
3277                    Bi::RayTCurrentMax => BuiltIn::RayTmaxKHR,
3278                    Bi::ObjectToWorld => BuiltIn::ObjectToWorldKHR,
3279                    Bi::WorldToObject => BuiltIn::WorldToObjectKHR,
3280                    Bi::HitKind => BuiltIn::HitKindKHR,
3281                };
3282
3283                use crate::ScalarKind as Sk;
3284
3285                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
3286                //
3287                // > Any variable with integer or double-precision floating-
3288                // > point type and with Input storage class in a fragment
3289                // > shader, must be decorated Flat
3290                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
3291                    let is_flat = match ir_module.types[ty].inner {
3292                        crate::TypeInner::Scalar(scalar)
3293                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
3294                            Sk::Uint | Sk::Sint | Sk::Bool => true,
3295                            Sk::Float => false,
3296                            Sk::AbstractInt | Sk::AbstractFloat => {
3297                                return Err(Error::Validation(
3298                                    "Abstract types should not appear in IR presented to backends",
3299                                ))
3300                            }
3301                        },
3302                        _ => false,
3303                    };
3304
3305                    if is_flat {
3306                        others.push(Decoration::Flat);
3307                    }
3308                }
3309                Ok(BindingDecorations::BuiltIn(built_in, others))
3310            }
3311        }
3312    }
3313
3314    /// Load an IO variable, converting from `f32` to `f16` if polyfill is active.
3315    /// Returns the id of the loaded value matching `target_type_id`.
3316    pub(super) fn load_io_with_f16_polyfill(
3317        &mut self,
3318        body: &mut Vec<Instruction>,
3319        varying_id: Word,
3320        target_type_id: Word,
3321    ) -> Word {
3322        let tmp = self.id_gen.next();
3323        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3324            body.push(Instruction::load(f32_ty, tmp, varying_id, None));
3325            let converted = self.id_gen.next();
3326            super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion(
3327                tmp,
3328                target_type_id,
3329                converted,
3330                body,
3331            );
3332            converted
3333        } else {
3334            body.push(Instruction::load(target_type_id, tmp, varying_id, None));
3335            tmp
3336        }
3337    }
3338
3339    /// Store an IO variable, converting from `f16` to `f32` if polyfill is active.
3340    pub(super) fn store_io_with_f16_polyfill(
3341        &mut self,
3342        body: &mut Vec<Instruction>,
3343        varying_id: Word,
3344        value_id: Word,
3345    ) {
3346        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3347            let converted = self.id_gen.next();
3348            super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion(
3349                value_id, f32_ty, converted, body,
3350            );
3351            body.push(Instruction::store(varying_id, converted, None));
3352        } else {
3353            body.push(Instruction::store(varying_id, value_id, None));
3354        }
3355    }
3356
3357    fn write_global_variable(
3358        &mut self,
3359        ir_module: &crate::Module,
3360        global_variable: &crate::GlobalVariable,
3361    ) -> Result<Word, Error> {
3362        use spirv::Decoration;
3363
3364        let id = self.id_gen.next();
3365        let class = map_storage_class(global_variable.space);
3366
3367        if let crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload =
3368            global_variable.space
3369        {
3370            self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
3371        }
3372
3373        //self.check(class.required_capabilities())?;
3374
3375        if global_variable
3376            .memory_decorations
3377            .contains(crate::MemoryDecorations::COHERENT)
3378        {
3379            self.decorate(id, Decoration::Coherent, &[]);
3380        }
3381        if global_variable
3382            .memory_decorations
3383            .contains(crate::MemoryDecorations::VOLATILE)
3384        {
3385            self.decorate(id, Decoration::Volatile, &[]);
3386        }
3387
3388        if self.flags.contains(WriterFlags::DEBUG) {
3389            if let Some(ref name) = global_variable.name {
3390                self.debugs.push(Instruction::name(id, name));
3391            }
3392        }
3393
3394        let storage_access = match global_variable.space {
3395            crate::AddressSpace::Storage { access } => Some(access),
3396            _ => match ir_module.types[global_variable.ty].inner {
3397                crate::TypeInner::Image {
3398                    class: crate::ImageClass::Storage { access, .. },
3399                    ..
3400                } => Some(access),
3401                _ => None,
3402            },
3403        };
3404        if let Some(storage_access) = storage_access {
3405            if !storage_access.contains(crate::StorageAccess::LOAD) {
3406                self.decorate(id, Decoration::NonReadable, &[]);
3407            }
3408            if !storage_access.contains(crate::StorageAccess::STORE) {
3409                self.decorate(id, Decoration::NonWritable, &[]);
3410            }
3411        }
3412
3413        // Note: we should be able to substitute `binding_array<Foo, 0>`,
3414        // but there is still code that tries to register the pre-substituted type,
3415        // and it is failing on 0.
3416        let mut substitute_inner_type_lookup = None;
3417        if let Some(ref res_binding) = global_variable.binding {
3418            let bind_target = self.resolve_resource_binding(res_binding)?;
3419            self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]);
3420            self.decorate(id, Decoration::Binding, &[bind_target.binding]);
3421
3422            if let Some(remapped_binding_array_size) = bind_target.binding_array_size {
3423                if let crate::TypeInner::BindingArray { base, .. } =
3424                    ir_module.types[global_variable.ty].inner
3425                {
3426                    let binding_array_type_id =
3427                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
3428                            base,
3429                            size: remapped_binding_array_size,
3430                        }));
3431                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
3432                        base: binding_array_type_id,
3433                        class,
3434                    }));
3435                }
3436            }
3437        };
3438
3439        let init_word = global_variable
3440            .init
3441            .map(|constant| self.constant_ids[constant]);
3442        let inner_type_id = self.get_type_id(
3443            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
3444        );
3445
3446        // generate the wrapping structure if needed
3447        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
3448            let wrapper_type_id = self.id_gen.next();
3449
3450            self.decorate(wrapper_type_id, Decoration::Block, &[]);
3451
3452            match self.std140_compat_uniform_types.get(&global_variable.ty) {
3453                Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => {
3454                    self.annotations.push(Instruction::member_decorate(
3455                        wrapper_type_id,
3456                        0,
3457                        Decoration::Offset,
3458                        &[0],
3459                    ));
3460                    Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id])
3461                        .to_words(&mut self.logical_layout.declarations);
3462                }
3463                _ => {
3464                    let member = crate::StructMember {
3465                        name: None,
3466                        ty: global_variable.ty,
3467                        binding: None,
3468                        offset: 0,
3469                    };
3470                    self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
3471
3472                    Instruction::type_struct(wrapper_type_id, &[inner_type_id])
3473                        .to_words(&mut self.logical_layout.declarations);
3474                }
3475            }
3476
3477            let pointer_type_id = self.id_gen.next();
3478            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
3479                .to_words(&mut self.logical_layout.declarations);
3480
3481            pointer_type_id
3482        } else {
3483            // This is a global variable in the Storage address space. The only
3484            // way it could have `global_needs_wrapper() == false` is if it has
3485            // a runtime-sized or binding array.
3486            // Runtime-sized arrays were decorated when iterating through struct content.
3487            // Now binding arrays require Block decorating.
3488            if let crate::AddressSpace::Storage { .. } = global_variable.space {
3489                match ir_module.types[global_variable.ty].inner {
3490                    crate::TypeInner::BindingArray { base, .. } => {
3491                        let ty = &ir_module.types[base];
3492                        let mut should_decorate = true;
3493                        // Check if the type has a runtime array.
3494                        // A normal runtime array gets validated out,
3495                        // so only structs can be with runtime arrays
3496                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
3497                            // only the last member in a struct can be dynamically sized
3498                            if let Some(last_member) = members.last() {
3499                                if let &crate::TypeInner::Array {
3500                                    size: crate::ArraySize::Dynamic,
3501                                    ..
3502                                } = &ir_module.types[last_member.ty].inner
3503                                {
3504                                    should_decorate = false;
3505                                }
3506                            }
3507                        }
3508                        if should_decorate {
3509                            let decorated_id = self.get_handle_type_id(base);
3510                            self.decorate(decorated_id, Decoration::Block, &[]);
3511                        }
3512                    }
3513                    _ => (),
3514                };
3515            }
3516            if substitute_inner_type_lookup.is_some() {
3517                inner_type_id
3518            } else {
3519                self.get_handle_pointer_type_id(global_variable.ty, class)
3520            }
3521        };
3522
3523        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
3524            (crate::AddressSpace::Private, _)
3525            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
3526                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
3527            }
3528            _ => init_word,
3529        };
3530
3531        Instruction::variable(pointer_type_id, id, class, init_word)
3532            .to_words(&mut self.logical_layout.declarations);
3533        Ok(id)
3534    }
3535
3536    /// Write the necessary decorations for a struct member.
3537    ///
3538    /// Emit decorations for the `index`'th member of the struct type
3539    /// designated by `struct_id`, described by `member`.
3540    fn decorate_struct_member(
3541        &mut self,
3542        struct_id: Word,
3543        index: usize,
3544        member: &crate::StructMember,
3545        arena: &UniqueArena<crate::Type>,
3546    ) -> Result<(), Error> {
3547        use spirv::Decoration;
3548
3549        self.annotations.push(Instruction::member_decorate(
3550            struct_id,
3551            index as u32,
3552            Decoration::Offset,
3553            &[member.offset],
3554        ));
3555
3556        if self.flags.contains(WriterFlags::DEBUG) {
3557            if let Some(ref name) = member.name {
3558                self.debugs
3559                    .push(Instruction::member_name(struct_id, index as u32, name));
3560            }
3561        }
3562
3563        // Matrices and (potentially nested) arrays of matrices both require decorations,
3564        // so "see through" any arrays to determine if they're needed.
3565        let mut member_array_subty_inner = &arena[member.ty].inner;
3566        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
3567            member_array_subty_inner = &arena[base].inner;
3568        }
3569
3570        if let crate::TypeInner::Matrix {
3571            columns: _,
3572            rows,
3573            scalar,
3574        } = *member_array_subty_inner
3575        {
3576            let byte_stride = Alignment::from(rows) * scalar.width as u32;
3577            self.annotations.push(Instruction::member_decorate(
3578                struct_id,
3579                index as u32,
3580                Decoration::ColMajor,
3581                &[],
3582            ));
3583            self.annotations.push(Instruction::member_decorate(
3584                struct_id,
3585                index as u32,
3586                Decoration::MatrixStride,
3587                &[byte_stride],
3588            ));
3589        }
3590
3591        Ok(())
3592    }
3593
3594    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
3595        match self
3596            .lookup_function_type
3597            .entry(lookup_function_type.clone())
3598        {
3599            Entry::Occupied(e) => *e.get(),
3600            Entry::Vacant(_) => {
3601                let id = self.id_gen.next();
3602                let instruction = Instruction::type_function(
3603                    id,
3604                    lookup_function_type.return_type_id,
3605                    &lookup_function_type.parameter_type_ids,
3606                );
3607                instruction.to_words(&mut self.logical_layout.declarations);
3608                self.lookup_function_type.insert(lookup_function_type, id);
3609                id
3610            }
3611        }
3612    }
3613
3614    const fn write_physical_layout(&mut self) {
3615        self.physical_layout.bound = self.id_gen.0 + 1;
3616    }
3617
3618    fn write_logical_layout(
3619        &mut self,
3620        ir_module: &crate::Module,
3621        mod_info: &ModuleInfo,
3622        ep_index: Option<usize>,
3623        debug_info: &Option<DebugInfo>,
3624    ) -> Result<(), Error> {
3625        fn has_view_index_check(
3626            ir_module: &crate::Module,
3627            binding: Option<&crate::Binding>,
3628            ty: Handle<crate::Type>,
3629        ) -> bool {
3630            match ir_module.types[ty].inner {
3631                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
3632                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
3633                }),
3634                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
3635            }
3636        }
3637
3638        let has_storage_buffers =
3639            ir_module
3640                .global_variables
3641                .iter()
3642                .any(|(_, var)| match var.space {
3643                    crate::AddressSpace::Storage { .. } => true,
3644                    _ => false,
3645                });
3646        let has_view_index = ir_module
3647            .entry_points
3648            .iter()
3649            .flat_map(|entry| entry.function.arguments.iter())
3650            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
3651        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
3652
3653        let rt_uses = ir_module.uses_ray_tracing(ep_index);
3654        let has_ray_query = rt_uses.queries;
3655        let has_ray_tracing_pipeline = rt_uses.pipelines;
3656
3657        self.has_ray_tracing_pipeline = has_ray_tracing_pipeline;
3658
3659        if self.physical_layout.version < 0x10300 && has_storage_buffers {
3660            // enable the storage buffer class on < SPV-1.3
3661            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
3662                .to_words(&mut self.logical_layout.extensions);
3663        }
3664        if has_view_index {
3665            Instruction::extension("SPV_KHR_multiview")
3666                .to_words(&mut self.logical_layout.extensions)
3667        }
3668        if has_ray_query {
3669            Instruction::extension("SPV_KHR_ray_query")
3670                .to_words(&mut self.logical_layout.extensions)
3671        }
3672        if has_vertex_return {
3673            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
3674                .to_words(&mut self.logical_layout.extensions);
3675        }
3676        if ir_module.uses_mesh_shaders() {
3677            self.use_extension("SPV_EXT_mesh_shader");
3678            self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?;
3679            let lang_version = self.lang_version();
3680            if lang_version.0 <= 1 && lang_version.1 < 4 {
3681                return Err(Error::SpirvVersionTooLow(1, 4));
3682            }
3683        }
3684        if has_ray_tracing_pipeline {
3685            Instruction::extension("SPV_KHR_ray_tracing")
3686                .to_words(&mut self.logical_layout.extensions)
3687        }
3688        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
3689        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
3690            .to_words(&mut self.logical_layout.ext_inst_imports);
3691
3692        let mut debug_info_inner = None;
3693        if self.flags.contains(WriterFlags::DEBUG) {
3694            if let Some(debug_info) = debug_info.as_ref() {
3695                let source_file_id = self.id_gen.next();
3696                self.debugs
3697                    .push(Instruction::string(debug_info.file_name, source_file_id));
3698
3699                debug_info_inner = Some(DebugInfoInner {
3700                    source_code: debug_info.source_code,
3701                    source_file_id,
3702                });
3703                self.debugs.append(&mut Instruction::source_auto_continued(
3704                    debug_info.language,
3705                    0,
3706                    &debug_info_inner,
3707                ));
3708            }
3709        }
3710
3711        // write all types
3712        for (handle, _) in ir_module.types.iter() {
3713            self.write_type_declaration_arena(ir_module, handle)?;
3714        }
3715
3716        // write std140 layout compatible types required by uniforms
3717        for (_, var) in ir_module.global_variables.iter() {
3718            if var.space == crate::AddressSpace::Uniform {
3719                self.write_std140_compat_type_declaration(ir_module, var.ty)?;
3720            }
3721        }
3722
3723        // write all const-expressions as constants
3724        self.constant_ids
3725            .resize(ir_module.global_expressions.len(), 0);
3726        for (handle, _) in ir_module.global_expressions.iter() {
3727            self.write_constant_expr(handle, ir_module, mod_info)?;
3728        }
3729        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
3730
3731        // write the name of constants on their respective const-expression initializer
3732        if self.flags.contains(WriterFlags::DEBUG) {
3733            for (_, constant) in ir_module.constants.iter() {
3734                if let Some(ref name) = constant.name {
3735                    let id = self.constant_ids[constant.init];
3736                    self.debugs.push(Instruction::name(id, name));
3737                }
3738            }
3739        }
3740
3741        // write all global variables
3742        for (handle, var) in ir_module.global_variables.iter() {
3743            // If a single entry point was specified, only write `OpVariable` instructions
3744            // for the globals it actually uses. Emit dummies for the others,
3745            // to preserve the indices in `global_variables`.
3746            let gvar = match ep_index {
3747                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
3748                    GlobalVariable::dummy()
3749                }
3750                _ => {
3751                    let id = self.write_global_variable(ir_module, var)?;
3752                    GlobalVariable::new(id)
3753                }
3754            };
3755            self.global_variables.insert(handle, gvar);
3756        }
3757
3758        // write all functions
3759        for (handle, ir_function) in ir_module.functions.iter() {
3760            let info = &mod_info[handle];
3761            if let Some(index) = ep_index {
3762                let ep_info = mod_info.get_entry_point(index);
3763                // If this function uses globals that we omitted from the SPIR-V
3764                // because the entry point and its callees didn't use them,
3765                // then we must skip it.
3766                if !ep_info.dominates_global_use(info) {
3767                    log::debug!("Skip function {:?}", ir_function.name);
3768                    continue;
3769                }
3770
3771                // Skip functions that that are not compatible with this entry point's stage.
3772                //
3773                // When validation is enabled, it rejects modules whose entry points try to call
3774                // incompatible functions, so if we got this far, then any functions incompatible
3775                // with our selected entry point must not be used.
3776                //
3777                // When validation is disabled, `fun_info.available_stages` is always just
3778                // `ShaderStages::all()`, so this will write all functions in the module, and
3779                // the downstream GLSL compiler will catch any problems.
3780                if !info.available_stages.contains(ep_info.available_stages) {
3781                    continue;
3782                }
3783            }
3784            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
3785            self.lookup_function.insert(handle, id);
3786        }
3787
3788        // write all or one entry points
3789        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
3790            if ep_index.is_some() && ep_index != Some(index) {
3791                continue;
3792            }
3793            let info = mod_info.get_entry_point(index);
3794            let ep_instruction =
3795                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
3796            ep_instruction.to_words(&mut self.logical_layout.entry_points);
3797        }
3798
3799        for capability in self.capabilities_used.iter() {
3800            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
3801        }
3802        for extension in self.extensions_used.iter() {
3803            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
3804        }
3805        if ir_module.entry_points.is_empty() {
3806            // SPIR-V doesn't like modules without entry points
3807            Instruction::capability(spirv::Capability::Linkage)
3808                .to_words(&mut self.logical_layout.capabilities);
3809        }
3810
3811        let addressing_model = spirv::AddressingModel::Logical;
3812        let memory_model = if self
3813            .capabilities_used
3814            .contains(&spirv::Capability::VulkanMemoryModel)
3815        {
3816            spirv::MemoryModel::Vulkan
3817        } else {
3818            spirv::MemoryModel::GLSL450
3819        };
3820        //self.check(addressing_model.required_capabilities())?;
3821        //self.check(memory_model.required_capabilities())?;
3822
3823        Instruction::memory_model(addressing_model, memory_model)
3824            .to_words(&mut self.logical_layout.memory_model);
3825
3826        for debug_string in self.debug_strings.iter() {
3827            debug_string.to_words(&mut self.logical_layout.debugs);
3828        }
3829
3830        if self.flags.contains(WriterFlags::DEBUG) {
3831            for debug in self.debugs.iter() {
3832                debug.to_words(&mut self.logical_layout.debugs);
3833            }
3834        }
3835
3836        for annotation in self.annotations.iter() {
3837            annotation.to_words(&mut self.logical_layout.annotations);
3838        }
3839
3840        Ok(())
3841    }
3842
3843    pub fn write(
3844        &mut self,
3845        ir_module: &crate::Module,
3846        info: &ModuleInfo,
3847        pipeline_options: Option<&PipelineOptions>,
3848        debug_info: &Option<DebugInfo>,
3849        words: &mut Vec<Word>,
3850    ) -> Result<(), Error> {
3851        self.reset();
3852
3853        // Try to find the entry point and corresponding index
3854        let ep_index = match pipeline_options {
3855            Some(po) => {
3856                let index = ir_module
3857                    .entry_points
3858                    .iter()
3859                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
3860                    .ok_or(Error::EntryPointNotFound)?;
3861                Some(index)
3862            }
3863            None => None,
3864        };
3865
3866        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
3867        self.write_physical_layout();
3868
3869        self.physical_layout.in_words(words);
3870        self.logical_layout.in_words(words);
3871        Ok(())
3872    }
3873
3874    /// Return the set of capabilities the last module written used.
3875    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
3876        &self.capabilities_used
3877    }
3878
3879    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
3880        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
3881        self.use_extension("SPV_EXT_descriptor_indexing");
3882        self.decorate(id, spirv::Decoration::NonUniform, &[]);
3883        Ok(())
3884    }
3885
3886    pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool {
3887        self.io_f16_polyfills.needs_polyfill(ty_inner)
3888    }
3889
3890    pub(super) fn write_debug_printf(
3891        &mut self,
3892        block: &mut Block,
3893        string: &str,
3894        format_params: &[Word],
3895    ) {
3896        if self.debug_printf.is_none() {
3897            self.use_extension("SPV_KHR_non_semantic_info");
3898            let import_id = self.id_gen.next();
3899            Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf")
3900                .to_words(&mut self.logical_layout.ext_inst_imports);
3901            self.debug_printf = Some(import_id)
3902        }
3903
3904        let import_id = self.debug_printf.unwrap();
3905
3906        let string_id = self.id_gen.next();
3907        self.debug_strings
3908            .push(Instruction::string(string, string_id));
3909
3910        let mut operands = Vec::with_capacity(1 + format_params.len());
3911        operands.push(string_id);
3912        operands.extend(format_params.iter());
3913
3914        let print_id = self.id_gen.next();
3915        block.body.push(Instruction::ext_inst(
3916            import_id,
3917            1,
3918            self.void_type,
3919            print_id,
3920            &operands,
3921        ));
3922    }
3923}
3924
/// `write_physical_layout` must turn the id counter of a fresh writer into
/// the id bound.
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();

    // Nothing has computed the bound yet.
    assert_eq!(writer.physical_layout.bound, 0);

    writer.write_physical_layout();

    // The bound is the id counter plus one, so 3 implies that a fresh writer
    // has already handed out two ids (presumably the void type and the
    // GLSL.std.450 import; confirm in `Writer::new`).
    assert_eq!(writer.physical_layout.bound, 3);
}