naga/back/spv/
writer.rs

1use alloc::{format, string::String, vec, vec::Vec};
2
3use arrayvec::ArrayVec;
4use hashbrown::hash_map::Entry;
5use spirv::Word;
6
7use super::{
8    block::DebugInfoInner,
9    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
10    Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
11    EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
12    LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
13    NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
14    BITS_PER_BYTE,
15};
16use crate::{
17    arena::{Handle, HandleVec, UniqueArena},
18    back::spv::{
19        helpers::{is_uniform_matcx2_struct_member_access, BindingDecorations},
20        BindingInfo, Std140CompatTypeInfo, WrappedFunction,
21    },
22    common::ForDebugWithTypes as _,
23    proc::{Alignment, TypeResolution},
24    valid::{FunctionInfo, ModuleInfo},
25};
26
/// Interface information handed to the function writer for an entry point.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names and types only; the code that fills in and consumes this struct is
/// outside this view — confirm against the callers.
pub struct FunctionInterface<'a> {
    /// Mutable list of SPIR-V ids, presumably the interface (varying)
    /// variables collected while writing the entry point — TODO confirm.
    pub varying_ids: &'a mut Vec<Word>,
    /// Shader stage of the entry point being written.
    pub stage: crate::ShaderStage,
    /// Task payload global variable, if any (presumably only set for
    /// task/mesh stages — verify).
    pub task_payload: Option<Handle<crate::GlobalVariable>>,
    /// Mesh-stage information, if any.
    pub mesh_info: Option<crate::MeshStageInfo>,
    /// Workgroup size as three components (presumably x, y, z — confirm).
    pub workgroup_size: [u32; 3],
}
34
35impl Function {
36    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
37        self.signature.as_ref().unwrap().to_words(sink);
38        for argument in self.parameters.iter() {
39            argument.instruction.to_words(sink);
40        }
41        for (index, block) in self.blocks.iter().enumerate() {
42            Instruction::label(block.label_id).to_words(sink);
43            if index == 0 {
44                for local_var in self.variables.values() {
45                    local_var.instruction.to_words(sink);
46                }
47                for local_var in self.ray_query_initialization_tracker_variables.values() {
48                    local_var.instruction.to_words(sink);
49                }
50                for local_var in self.ray_query_t_max_tracker_variables.values() {
51                    local_var.instruction.to_words(sink);
52                }
53                for local_var in self.force_loop_bounding_vars.iter() {
54                    local_var.instruction.to_words(sink);
55                }
56                for internal_var in self.spilled_composites.values() {
57                    internal_var.instruction.to_words(sink);
58                }
59            }
60            for instruction in block.body.iter() {
61                instruction.to_words(sink);
62            }
63        }
64        Instruction::function_end().to_words(sink);
65    }
66}
67
68impl Writer {
69    pub fn new(options: &Options) -> Result<Self, Error> {
70        let (major, minor) = options.lang_version;
71        if major != 1 {
72            return Err(Error::UnsupportedVersion(major, minor));
73        }
74
75        let mut capabilities_used = crate::FastIndexSet::default();
76        capabilities_used.insert(spirv::Capability::Shader);
77
78        let mut id_gen = IdGenerator::default();
79        let gl450_ext_inst_id = id_gen.next();
80        let void_type = id_gen.next();
81
82        Ok(Writer {
83            physical_layout: PhysicalLayout::new(major, minor),
84            logical_layout: LogicalLayout::default(),
85            id_gen,
86            capabilities_available: options.capabilities.clone(),
87            capabilities_used,
88            extensions_used: crate::FastIndexSet::default(),
89            debug_strings: vec![],
90            debugs: vec![],
91            annotations: vec![],
92            flags: options.flags,
93            bounds_check_policies: options.bounds_check_policies,
94            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
95            force_loop_bounding: options.force_loop_bounding,
96            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
97            trace_ray_argument_validation: options.trace_ray_argument_validation,
98            use_storage_input_output_16: options.use_storage_input_output_16,
99            emit_int_div_checks: options.emit_int_div_checks,
100            void_type,
101            tuple_of_u32s_ty_id: None,
102            lookup_type: crate::FastHashMap::default(),
103            lookup_function: crate::FastHashMap::default(),
104            lookup_function_type: crate::FastHashMap::default(),
105            wrapped_functions: crate::FastHashMap::default(),
106            constant_ids: HandleVec::new(),
107            cached_constants: crate::FastHashMap::default(),
108            global_variables: HandleVec::new(),
109            std140_compat_uniform_types: crate::FastHashMap::default(),
110            fake_missing_bindings: options.fake_missing_bindings,
111            binding_map: options.binding_map.clone(),
112            saved_cached: CachedExpressions::default(),
113            gl450_ext_inst_id,
114            temp_list: Vec::new(),
115            ray_query_functions: crate::FastHashMap::default(),
116            ray_tracing_functions: crate::FastHashMap::default(),
117            has_ray_tracing_pipeline: false,
118            io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new(
119                options.use_storage_input_output_16,
120            ),
121            debug_printf: None,
122            task_dispatch_limits: options.task_dispatch_limits,
123            mesh_shader_primitive_indices_clamp: options.mesh_shader_primitive_indices_clamp,
124        })
125    }
126
127    pub fn set_options(&mut self, options: &Options) -> Result<(), Error> {
128        let (major, minor) = options.lang_version;
129        if major != 1 {
130            return Err(Error::UnsupportedVersion(major, minor));
131        }
132        self.physical_layout = PhysicalLayout::new(major, minor);
133        self.capabilities_available = options.capabilities.clone();
134        self.flags = options.flags;
135        self.bounds_check_policies = options.bounds_check_policies;
136        self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory;
137        self.force_loop_bounding = options.force_loop_bounding;
138        self.use_storage_input_output_16 = options.use_storage_input_output_16;
139        self.binding_map = options.binding_map.clone();
140        self.io_f16_polyfills =
141            super::f16_polyfill::F16IoPolyfill::new(options.use_storage_input_output_16);
142        self.task_dispatch_limits = options.task_dispatch_limits;
143        self.mesh_shader_primitive_indices_clamp = options.mesh_shader_primitive_indices_clamp;
144        Ok(())
145    }
146
147    /// Returns `(major, minor)` of the SPIR-V language version.
148    pub const fn lang_version(&self) -> (u8, u8) {
149        self.physical_layout.lang_version()
150    }
151
152    /// Reset `Writer` to its initial state, retaining any allocations.
153    ///
154    /// Why not just implement `Reclaimable` for `Writer`? By design,
155    /// `Reclaimable::reclaim` requires ownership of the value, not just
156    /// `&mut`; see the trait documentation. But we need to use this method
157    /// from functions like `Writer::write`, which only have `&mut Writer`.
158    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
159    /// or something like a `Default` impl that returns an oddly-initialized
160    /// `Writer`, which is worse.
161    fn reset(&mut self) {
162        use super::reclaimable::Reclaimable;
163        use core::mem::take;
164
165        let mut id_gen = IdGenerator::default();
166        let gl450_ext_inst_id = id_gen.next();
167        let void_type = id_gen.next();
168
169        // Every field of the old writer that is not determined by the `Options`
170        // passed to `Writer::new` should be reset somehow.
171        let fresh = Writer {
172            // Copied from the old Writer:
173            flags: self.flags,
174            bounds_check_policies: self.bounds_check_policies,
175            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
176            force_loop_bounding: self.force_loop_bounding,
177            ray_query_initialization_tracking: self.ray_query_initialization_tracking,
178            trace_ray_argument_validation: self.trace_ray_argument_validation,
179            use_storage_input_output_16: self.use_storage_input_output_16,
180            capabilities_available: take(&mut self.capabilities_available),
181            fake_missing_bindings: self.fake_missing_bindings,
182            binding_map: take(&mut self.binding_map),
183            task_dispatch_limits: self.task_dispatch_limits,
184            mesh_shader_primitive_indices_clamp: self.mesh_shader_primitive_indices_clamp,
185            emit_int_div_checks: self.emit_int_div_checks,
186
187            // Initialized afresh:
188            id_gen,
189            void_type,
190            tuple_of_u32s_ty_id: None,
191            gl450_ext_inst_id,
192
193            // Reclaimed:
194            capabilities_used: take(&mut self.capabilities_used).reclaim(),
195            extensions_used: take(&mut self.extensions_used).reclaim(),
196            physical_layout: self.physical_layout.clone().reclaim(),
197            logical_layout: take(&mut self.logical_layout).reclaim(),
198            debug_strings: take(&mut self.debug_strings).reclaim(),
199            debugs: take(&mut self.debugs).reclaim(),
200            annotations: take(&mut self.annotations).reclaim(),
201            lookup_type: take(&mut self.lookup_type).reclaim(),
202            lookup_function: take(&mut self.lookup_function).reclaim(),
203            lookup_function_type: take(&mut self.lookup_function_type).reclaim(),
204            wrapped_functions: take(&mut self.wrapped_functions).reclaim(),
205            constant_ids: take(&mut self.constant_ids).reclaim(),
206            cached_constants: take(&mut self.cached_constants).reclaim(),
207            global_variables: take(&mut self.global_variables).reclaim(),
208            std140_compat_uniform_types: take(&mut self.std140_compat_uniform_types).reclaim(),
209            saved_cached: take(&mut self.saved_cached).reclaim(),
210            temp_list: take(&mut self.temp_list).reclaim(),
211            ray_query_functions: take(&mut self.ray_query_functions).reclaim(),
212            ray_tracing_functions: take(&mut self.ray_tracing_functions).reclaim(),
213            has_ray_tracing_pipeline: false,
214            io_f16_polyfills: take(&mut self.io_f16_polyfills).reclaim(),
215            debug_printf: None,
216        };
217
218        *self = fresh;
219
220        self.capabilities_used.insert(spirv::Capability::Shader);
221    }
222
223    /// Indicate that the code requires any one of the listed capabilities.
224    ///
225    /// If nothing in `capabilities` appears in the available capabilities
226    /// specified in the [`Options`] from which this `Writer` was created,
227    /// return an error. The `what` string is used in the error message to
228    /// explain what provoked the requirement. (If no available capabilities were
229    /// given, assume everything is available.)
230    ///
231    /// The first acceptable capability will be added to this `Writer`'s
232    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
233    /// result. For this reason, more specific capabilities should be listed
234    /// before more general.
235    ///
236    /// [`capabilities_used`]: Writer::capabilities_used
237    pub(super) fn require_any(
238        &mut self,
239        what: &'static str,
240        capabilities: &[spirv::Capability],
241    ) -> Result<(), Error> {
242        match *capabilities {
243            [] => Ok(()),
244            [first, ..] => {
245                // Find the first acceptable capability, or return an error if
246                // there is none.
247                let selected = match self.capabilities_available {
248                    None => first,
249                    Some(ref available) => {
250                        match capabilities
251                            .iter()
252                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
253                            .find(|cap| available.contains::<spirv::Capability>(cap))
254                        {
255                            Some(&cap) => cap,
256                            None => {
257                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
258                            }
259                        }
260                    }
261                };
262                self.capabilities_used.insert(selected);
263                Ok(())
264            }
265        }
266    }
267
268    /// Indicate that the code requires all of the listed capabilities.
269    ///
270    /// If all entries of `capabilities` appear in the available capabilities
271    /// specified in the [`Options`] from which this `Writer` was created
272    /// (including the case where [`Options::capabilities`] is `None`), add
273    /// them all to this `Writer`'s [`capabilities_used`] table, and return
274    /// `Ok(())`. If at least one of the listed capabilities is not available,
275    /// do not add anything to the `capabilities_used` table, and return the
276    /// first unavailable requested capability, wrapped in `Err()`.
277    ///
278    /// This method is does not return an [`enum@Error`] in case of failure
279    /// because it may be used in cases where the caller can recover (e.g.,
280    /// with a polyfill) if the requested capabilities are not available. In
281    /// this case, it would be unnecessary work to find *all* the unavailable
282    /// requested capabilities, and to allocate a `Vec` for them, just so we
283    /// could return an [`Error::MissingCapabilities`]).
284    ///
285    /// [`capabilities_used`]: Writer::capabilities_used
286    pub(super) fn require_all(
287        &mut self,
288        capabilities: &[spirv::Capability],
289    ) -> Result<(), spirv::Capability> {
290        if let Some(ref available) = self.capabilities_available {
291            for requested in capabilities {
292                if !available.contains(requested) {
293                    return Err(*requested);
294                }
295            }
296        }
297
298        for requested in capabilities {
299            self.capabilities_used.insert(*requested);
300        }
301
302        Ok(())
303    }
304
305    /// Indicate that the code uses the given extension.
306    pub(super) fn use_extension(&mut self, extension: &'static str) {
307        self.extensions_used.insert(extension);
308    }
309
310    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
311        match self.lookup_type.entry(lookup_ty) {
312            Entry::Occupied(e) => *e.get(),
313            Entry::Vacant(e) => {
314                let local = match lookup_ty {
315                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
316                    LookupType::Local(local) => local,
317                };
318
319                let id = self.id_gen.next();
320                e.insert(id);
321                self.write_type_declaration_local(id, local);
322                id
323            }
324        }
325    }
326
327    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
328        self.get_type_id(LookupType::Handle(handle))
329    }
330
331    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
332        match *tr {
333            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
334            TypeResolution::Value(ref inner) => {
335                let inner_local_type = self.localtype_from_inner(inner).unwrap();
336                LookupType::Local(inner_local_type)
337            }
338        }
339    }
340
341    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
342        let lookup_ty = self.get_expression_lookup_type(tr);
343        self.get_type_id(lookup_ty)
344    }
345
346    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
347        self.get_type_id(LookupType::Local(local))
348    }
349
350    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
351        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
352    }
353
354    pub(super) fn get_handle_pointer_type_id(
355        &mut self,
356        base: Handle<crate::Type>,
357        class: spirv::StorageClass,
358    ) -> Word {
359        let base_id = self.get_handle_type_id(base);
360        self.get_pointer_type_id(base_id, class)
361    }
362
363    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
364        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
365        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
366    }
367
368    /// Return a SPIR-V type for a pointer to `resolution`.
369    ///
370    /// The given `resolution` must be one that we can represent
371    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
372    pub(super) fn get_resolution_pointer_id(
373        &mut self,
374        resolution: &TypeResolution,
375        class: spirv::StorageClass,
376    ) -> Word {
377        let resolution_type_id = self.get_expression_type_id(resolution);
378        self.get_pointer_type_id(resolution_type_id, class)
379    }
380
381    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
382        self.get_type_id(LocalType::Numeric(numeric).into())
383    }
384
385    pub(super) fn get_u32_type_id(&mut self) -> Word {
386        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
387    }
388
389    pub(super) fn get_f32_type_id(&mut self) -> Word {
390        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
391    }
392
393    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
394        self.get_numeric_type_id(NumericType::Vector {
395            size: crate::VectorSize::Bi,
396            scalar: crate::Scalar::U32,
397        })
398    }
399
400    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
401        self.get_numeric_type_id(NumericType::Vector {
402            size: crate::VectorSize::Bi,
403            scalar: crate::Scalar::F32,
404        })
405    }
406
407    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
408        self.get_numeric_type_id(NumericType::Vector {
409            size: crate::VectorSize::Tri,
410            scalar: crate::Scalar::U32,
411        })
412    }
413
414    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
415        let f32_id = self.get_f32_type_id();
416        self.get_pointer_type_id(f32_id, class)
417    }
418
419    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
420        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
421            size: crate::VectorSize::Bi,
422            scalar: crate::Scalar::U32,
423        });
424        self.get_pointer_type_id(vec2u_id, class)
425    }
426
427    pub(super) fn get_bool_type_id(&mut self) -> Word {
428        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
429    }
430
431    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
432        self.get_numeric_type_id(NumericType::Vector {
433            size: crate::VectorSize::Bi,
434            scalar: crate::Scalar::BOOL,
435        })
436    }
437
438    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
439        self.get_numeric_type_id(NumericType::Vector {
440            size: crate::VectorSize::Tri,
441            scalar: crate::Scalar::BOOL,
442        })
443    }
444
445    /// Used for "mulhi" to get the upper bits of multiplication.
446    ///
447    /// More specifically, `OpUMulExtended` multiplies 2 numbers and returns the lower and upper bits of the result
448    /// as a user-defined struct type with 2 u32s. This defines that struct.
449    pub(super) fn get_tuple_of_u32s_ty_id(&mut self) -> Word {
450        if let Some(val) = self.tuple_of_u32s_ty_id {
451            val
452        } else {
453            let id = self.id_gen.next();
454            let u32_id = self.get_u32_type_id();
455            let ins = Instruction::type_struct(id, &[u32_id, u32_id]);
456            ins.to_words(&mut self.logical_layout.declarations);
457            self.tuple_of_u32s_ty_id = Some(id);
458            id
459        }
460    }
461
462    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
463        self.annotations
464            .push(Instruction::decorate(id, decoration, operands));
465    }
466
467    /// Return `inner` as a `LocalType`, if that's possible.
468    ///
469    /// If `inner` can be represented as a `LocalType`, return
470    /// `Some(local_type)`.
471    ///
472    /// Otherwise, return `None`. In this case, the type must always be looked
473    /// up using a `LookupType::Handle`.
474    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
475        Some(match *inner {
476            crate::TypeInner::Scalar(_)
477            | crate::TypeInner::Atomic(_)
478            | crate::TypeInner::Vector { .. }
479            | crate::TypeInner::Matrix { .. } => {
480                // We expect `NumericType::from_inner` to handle all
481                // these cases, so unwrap.
482                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
483            }
484            crate::TypeInner::CooperativeMatrix { .. } => {
485                LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
486            }
487            crate::TypeInner::Pointer { base, space } => {
488                let base_type_id = self.get_handle_type_id(base);
489                LocalType::Pointer {
490                    base: base_type_id,
491                    class: map_storage_class(space),
492                }
493            }
494            crate::TypeInner::ValuePointer {
495                size,
496                scalar,
497                space,
498            } => {
499                let base_numeric_type = match size {
500                    Some(size) => NumericType::Vector { size, scalar },
501                    None => NumericType::Scalar(scalar),
502                };
503                LocalType::Pointer {
504                    base: self.get_numeric_type_id(base_numeric_type),
505                    class: map_storage_class(space),
506                }
507            }
508            crate::TypeInner::Image {
509                dim,
510                arrayed,
511                class,
512            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
513            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
514            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
515            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
516            crate::TypeInner::Array { .. }
517            | crate::TypeInner::Struct { .. }
518            | crate::TypeInner::BindingArray { .. } => return None,
519        })
520    }
521
522    /// Resolve the [`BindingInfo`] for a [`crate::ResourceBinding`] from the
523    /// provided [`Writer::binding_map`].
524    ///
525    /// If the specified resource is not present in the binding map this will
526    /// return an error, unless [`Writer::fake_missing_bindings`] is set.
527    pub(super) fn resolve_resource_binding(
528        &self,
529        res_binding: &crate::ResourceBinding,
530    ) -> Result<BindingInfo, Error> {
531        match self.binding_map.get(res_binding) {
532            Some(target) => Ok(*target),
533            None if self.fake_missing_bindings => Ok(BindingInfo {
534                descriptor_set: res_binding.group,
535                binding: res_binding.binding,
536                binding_array_size: None,
537            }),
538            None => Err(Error::MissingBinding(*res_binding)),
539        }
540    }
541
542    /// Emits code for any wrapper functions required by the expressions in ir_function.
543    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
544    fn write_wrapped_functions(
545        &mut self,
546        ir_function: &crate::Function,
547        info: &FunctionInfo,
548        ir_module: &crate::Module,
549    ) -> Result<(), Error> {
550        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
551
552        for (expr_handle, expr) in ir_function.expressions.iter() {
553            match *expr {
554                crate::Expression::Binary { op, left, right } => {
555                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
556                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
557                        match (op, expr_ty.scalar().kind) {
558                            // Division and modulo are undefined behaviour when the
559                            // dividend is the minimum representable value and the divisor
560                            // is negative one, or when the divisor is zero. These wrapped
561                            // functions override the divisor to one in these cases,
562                            // matching the WGSL spec.
563                            //
564                            // Signed `%` is additionally always wrapped (even without
565                            // `emit_int_div_checks`) so it can be lowered to
566                            // `a - b * (a / b)`: `OpSRem` produces a poison result for
567                            // negative operands in the Vulkan SPIR-V environment unless
568                            // `VK_KHR_maintenance8` is enabled. See
569                            // <https://github.com/gfx-rs/wgpu/issues/8191>.
570                            (
571                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
572                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
573                            ) if self.emit_int_div_checks
574                                || matches!(
575                                    (op, expr_ty.scalar().kind),
576                                    (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint)
577                                ) =>
578                            {
579                                self.write_wrapped_binary_op(
580                                    op,
581                                    expr_ty,
582                                    &info[left].ty,
583                                    &info[right].ty,
584                                )?;
585                            }
586                            _ => {}
587                        }
588                    }
589                }
590                crate::Expression::Load { pointer } => {
591                    if let crate::TypeInner::Pointer {
592                        base: pointer_type,
593                        space: crate::AddressSpace::Uniform,
594                    } = *info[pointer].ty.inner_with(&ir_module.types)
595                    {
596                        if self.std140_compat_uniform_types.contains_key(&pointer_type) {
597                            // Loading a std140 compat type requires the wrapper function
598                            // to convert to the regular type.
599                            self.write_wrapped_convert_from_std140_compat_type(
600                                ir_module,
601                                pointer_type,
602                            )?;
603                        }
604                    }
605                }
606                crate::Expression::Access { base, .. } => {
607                    if let crate::TypeInner::Pointer {
608                        base: base_type,
609                        space: crate::AddressSpace::Uniform,
610                    } = *info[base].ty.inner_with(&ir_module.types)
611                    {
612                        // Dynamic accesses of a two-row matrix's columns require a
613                        // wrapper function.
614                        if let crate::TypeInner::Matrix {
615                            rows: crate::VectorSize::Bi,
616                            ..
617                        } = ir_module.types[base_type].inner
618                        {
619                            self.write_wrapped_matcx2_get_column(ir_module, base_type)?;
620                            // If the matrix is *not* directly a member of a struct, then
621                            // we additionally require a wrapper function to convert from
622                            // the std140 compat type to the regular type.
623                            if !is_uniform_matcx2_struct_member_access(
624                                ir_function,
625                                info,
626                                ir_module,
627                                base,
628                            ) {
629                                self.write_wrapped_convert_from_std140_compat_type(
630                                    ir_module, base_type,
631                                )?;
632                            }
633                        }
634                    }
635                }
636                _ => {}
637            }
638        }
639
640        Ok(())
641    }
642
643    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
644    ///
645    /// Define a function that performs an integer division or modulo operation,
646    /// except that using a divisor of zero or causing signed overflow with a
647    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
648    /// undefined behavior.
649    ///
650    /// Store the generated function's id in the [`wrapped_functions`] table.
651    ///
652    /// The operator `op` must be either [`Divide`] or [`Modulo`].
653    ///
654    /// # Panics
655    ///
656    /// The `return_type`, `left_type` or `right_type` arguments must all be
657    /// integer scalars or vectors. If not, this function panics.
658    ///
659    /// [`wrapped_functions`]: Writer::wrapped_functions
660    /// [`Divide`]: crate::BinaryOperator::Divide
661    /// [`Modulo`]: crate::BinaryOperator::Modulo
662    fn write_wrapped_binary_op(
663        &mut self,
664        op: crate::BinaryOperator,
665        return_type: NumericType,
666        left_type: &TypeResolution,
667        right_type: &TypeResolution,
668    ) -> Result<(), Error> {
669        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
670        let left_type_id = self.get_expression_type_id(left_type);
671        let right_type_id = self.get_expression_type_id(right_type);
672
673        // Check if we've already emitted this function.
674        let wrapped = WrappedFunction::BinaryOp {
675            op,
676            left_type_id,
677            right_type_id,
678        };
679        let function_id = match self.wrapped_functions.entry(wrapped) {
680            Entry::Occupied(_) => return Ok(()),
681            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
682        };
683
684        let scalar = return_type.scalar();
685
686        if self.flags.contains(WriterFlags::DEBUG) {
687            let function_name = match op {
688                crate::BinaryOperator::Divide => "naga_div",
689                crate::BinaryOperator::Modulo => "naga_mod",
690                _ => unreachable!(),
691            };
692            self.debugs
693                .push(Instruction::name(function_id, function_name));
694        }
695        let mut function = Function::default();
696
697        let function_type_id = self.get_function_type(LookupFunctionType {
698            parameter_type_ids: vec![left_type_id, right_type_id],
699            return_type_id,
700        });
701        function.signature = Some(Instruction::function(
702            return_type_id,
703            function_id,
704            spirv::FunctionControl::empty(),
705            function_type_id,
706        ));
707
708        let lhs_id = self.id_gen.next();
709        let rhs_id = self.id_gen.next();
710        if self.flags.contains(WriterFlags::DEBUG) {
711            self.debugs.push(Instruction::name(lhs_id, "lhs"));
712            self.debugs.push(Instruction::name(rhs_id, "rhs"));
713        }
714        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
715        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
716        for instruction in [left_par, right_par] {
717            function.parameters.push(FunctionArgument {
718                instruction,
719                handle_id: 0,
720            });
721        }
722
723        let label_id = self.id_gen.next();
724        let mut block = Block::new(label_id);
725
726        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
727        let bool_type_id = self.get_numeric_type_id(bool_type);
728
729        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
730            NumericType::Scalar(_) => const_id,
731            NumericType::Vector { size, .. } => {
732                let constituent_ids = [const_id; crate::VectorSize::MAX];
733                writer.get_constant_composite(
734                    LookupType::Local(LocalType::Numeric(return_type)),
735                    &constituent_ids[..size as usize],
736                )
737            }
738            NumericType::Matrix { .. } => unreachable!(),
739        };
740
741        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
742        let composite_zero_id = maybe_splat_const(self, const_zero_id);
743        let rhs_eq_zero_id = self.id_gen.next();
744        block.body.push(Instruction::binary(
745            spirv::Op::IEqual,
746            bool_type_id,
747            rhs_eq_zero_id,
748            rhs_id,
749            composite_zero_id,
750        ));
751        let divisor_selector_id = match scalar.kind {
752            crate::ScalarKind::Sint => {
753                let (const_min_id, const_neg_one_id) = match scalar.width {
754                    2 => Ok((
755                        self.get_constant_scalar(crate::Literal::I16(i16::MIN)),
756                        self.get_constant_scalar(crate::Literal::I16(-1i16)),
757                    )),
758                    4 => Ok((
759                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
760                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
761                    )),
762                    8 => Ok((
763                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
764                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
765                    )),
766                    _ => Err(Error::Validation("Unexpected scalar width")),
767                }?;
768                let composite_min_id = maybe_splat_const(self, const_min_id);
769                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
770
771                let lhs_eq_int_min_id = self.id_gen.next();
772                block.body.push(Instruction::binary(
773                    spirv::Op::IEqual,
774                    bool_type_id,
775                    lhs_eq_int_min_id,
776                    lhs_id,
777                    composite_min_id,
778                ));
779                let rhs_eq_neg_one_id = self.id_gen.next();
780                block.body.push(Instruction::binary(
781                    spirv::Op::IEqual,
782                    bool_type_id,
783                    rhs_eq_neg_one_id,
784                    rhs_id,
785                    composite_neg_one_id,
786                ));
787                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
788                block.body.push(Instruction::binary(
789                    spirv::Op::LogicalAnd,
790                    bool_type_id,
791                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
792                    lhs_eq_int_min_id,
793                    rhs_eq_neg_one_id,
794                ));
795                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
796                block.body.push(Instruction::binary(
797                    spirv::Op::LogicalOr,
798                    bool_type_id,
799                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
800                    rhs_eq_zero_id,
801                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
802                ));
803                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
804            }
805            crate::ScalarKind::Uint => rhs_eq_zero_id,
806            _ => unreachable!(),
807        };
808
809        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
810        let composite_one_id = maybe_splat_const(self, const_one_id);
811        let divisor_id = self.id_gen.next();
812        block.body.push(Instruction::select(
813            right_type_id,
814            divisor_id,
815            divisor_selector_id,
816            composite_one_id,
817            rhs_id,
818        ));
819        let return_id = if matches!(op, crate::BinaryOperator::Modulo)
820            && matches!(scalar.kind, crate::ScalarKind::Sint)
821        {
822            // `OpSRem` produces a poison result for negative operands in the Vulkan
823            // environment without the `maintenance8` feature. `OpSDiv` is not poisoned, so
824            // reconstruct the remainder as `a - b * (a / b)`, which is well-defined for
825            // negative operands. `divisor_id` is the zero/overflow-guarded divisor selected
826            // above, so the degenerate cases still match `OpSRem`'s guarded result (0).
827            let quotient_id = self.id_gen.next();
828            block.body.push(Instruction::binary(
829                spirv::Op::SDiv,
830                return_type_id,
831                quotient_id,
832                lhs_id,
833                divisor_id,
834            ));
835            let product_id = self.id_gen.next();
836            block.body.push(Instruction::binary(
837                spirv::Op::IMul,
838                return_type_id,
839                product_id,
840                quotient_id,
841                divisor_id,
842            ));
843            let remainder_id = self.id_gen.next();
844            block.body.push(Instruction::binary(
845                spirv::Op::ISub,
846                return_type_id,
847                remainder_id,
848                lhs_id,
849                product_id,
850            ));
851            remainder_id
852        } else {
853            let spv_op = match (op, scalar.kind) {
854                (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
855                (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
856                (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
857                (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
858                _ => unreachable!(),
859            };
860            let return_id = self.id_gen.next();
861            block.body.push(Instruction::binary(
862                spv_op,
863                return_type_id,
864                return_id,
865                lhs_id,
866                divisor_id,
867            ));
868            return_id
869        };
870
871        function.consume(block, Instruction::return_value(return_id));
872        function.to_words(&mut self.logical_layout.function_definitions);
873        Ok(())
874    }
875
876    /// Writes a wrapper function to convert from a std140 compat type to its
877    /// corresponding regular type.
878    ///
879    /// See [`Self::write_std140_compat_type_declaration`] for more details.
880    fn write_wrapped_convert_from_std140_compat_type(
881        &mut self,
882        ir_module: &crate::Module,
883        r#type: Handle<crate::Type>,
884    ) -> Result<(), Error> {
885        if !self.std140_compat_uniform_types.contains_key(&r#type) {
886            return Ok(());
887        }
888        // Check if we've already emitted this function.
889        let wrapped = WrappedFunction::ConvertFromStd140CompatType { r#type };
890        let function_id = match self.wrapped_functions.entry(wrapped) {
891            Entry::Occupied(_) => return Ok(()),
892            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
893        };
894        if self.flags.contains(WriterFlags::DEBUG) {
895            self.debugs.push(Instruction::name(
896                function_id,
897                &format!("{:?}_from_std140", r#type.for_debug(&ir_module.types)),
898            ));
899        }
900        let param_type_id = self.std140_compat_uniform_types[&r#type].type_id;
901        let return_type_id = self.get_handle_type_id(r#type);
902
903        let mut function = Function::default();
904        let function_type_id = self.get_function_type(LookupFunctionType {
905            parameter_type_ids: vec![param_type_id],
906            return_type_id,
907        });
908        function.signature = Some(Instruction::function(
909            return_type_id,
910            function_id,
911            spirv::FunctionControl::empty(),
912            function_type_id,
913        ));
914        let param_id = self.id_gen.next();
915        function.parameters.push(FunctionArgument {
916            instruction: Instruction::function_parameter(param_type_id, param_id),
917            handle_id: 0,
918        });
919
920        let label_id = self.id_gen.next();
921        let mut block = Block::new(label_id);
922
923        let result_id = match ir_module.types[r#type].inner {
924            // Param is struct containing a vector member for each of the
925            // matrix's columns. Extract each column from the struct then
926            // composite into a matrix.
927            crate::TypeInner::Matrix {
928                columns,
929                rows: rows @ crate::VectorSize::Bi,
930                scalar,
931            } => {
932                let column_type_id =
933                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
934
935                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
936                for column in 0..columns as u32 {
937                    let column_id = self.id_gen.next();
938                    block.body.push(Instruction::composite_extract(
939                        column_type_id,
940                        column_id,
941                        param_id,
942                        &[column],
943                    ));
944                    column_ids.push(column_id);
945                }
946                let result_id = self.id_gen.next();
947                block.body.push(Instruction::composite_construct(
948                    return_type_id,
949                    result_id,
950                    &column_ids,
951                ));
952                result_id
953            }
954            // Param is an array where the base type is the std140 compatible
955            // type corresponding to `base`. Iterate through each element and
956            // call its conversion function, then composite into a new array.
957            crate::TypeInner::Array { base, size, .. } => {
958                // Ensure the conversion function for the array's base type is
959                // declared.
960                self.write_wrapped_convert_from_std140_compat_type(ir_module, base)?;
961
962                let element_type_id = self.get_handle_type_id(base);
963                let std140_info = self.std140_compat_uniform_types.get(&base);
964                let mut element_ids = Vec::new();
965                let size = match size.resolve(ir_module.to_ctx())? {
966                    crate::proc::IndexableLength::Known(size) => size,
967                    crate::proc::IndexableLength::Dynamic => {
968                        return Err(Error::Validation(
969                            "Uniform buffers cannot contain dynamic arrays",
970                        ))
971                    }
972                };
973                for i in 0..size {
974                    let std140_element_id = self.id_gen.next();
975                    let std140_element_type_id =
976                        std140_info.map_or(element_type_id, |info| info.type_id);
977                    block.body.push(Instruction::composite_extract(
978                        std140_element_type_id,
979                        std140_element_id,
980                        param_id,
981                        &[i],
982                    ));
983
984                    // Only call the conversion function if a compatibility mapping actually exists.
985                    let final_element_id = if std140_info.is_some() {
986                        let conversion_fn_id = self.wrapped_functions
987                            [&WrappedFunction::ConvertFromStd140CompatType { r#type: base }];
988                        let id = self.id_gen.next();
989                        block.body.push(Instruction::function_call(
990                            element_type_id,
991                            id,
992                            conversion_fn_id,
993                            &[std140_element_id],
994                        ));
995                        id
996                    } else {
997                        std140_element_type_id
998                    };
999                    element_ids.push(final_element_id);
1000                }
1001                let result_id = self.id_gen.next();
1002                block.body.push(Instruction::composite_construct(
1003                    return_type_id,
1004                    result_id,
1005                    &element_ids,
1006                ));
1007                result_id
1008            }
1009            // Param is a struct where each two-row matrix member has been
1010            // decomposed in to separate vector members for each column.
1011            // Other members use their std140 compatible type if one exists, or
1012            // else their regular type. Iterate through each member, converting
1013            // or composing any matrices if required, then finally compose into
1014            // the struct.
1015            crate::TypeInner::Struct { ref members, .. } => {
1016                let mut member_ids = Vec::new();
1017                let mut next_index = 0;
1018                for member in members {
1019                    let member_id = self.id_gen.next();
1020                    let member_type_id = self.get_handle_type_id(member.ty);
1021                    match ir_module.types[member.ty].inner {
1022                        crate::TypeInner::Matrix {
1023                            columns,
1024                            rows: rows @ crate::VectorSize::Bi,
1025                            scalar,
1026                        } => {
1027                            let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
1028                            let column_type_id = self
1029                                .get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1030                            for _ in 0..columns as u32 {
1031                                let column_id = self.id_gen.next();
1032                                block.body.push(Instruction::composite_extract(
1033                                    column_type_id,
1034                                    column_id,
1035                                    param_id,
1036                                    &[next_index],
1037                                ));
1038                                column_ids.push(column_id);
1039                                next_index += 1;
1040                            }
1041                            block.body.push(Instruction::composite_construct(
1042                                member_type_id,
1043                                member_id,
1044                                &column_ids,
1045                            ));
1046                        }
1047                        _ => {
1048                            // Ensure the conversion function for the member's
1049                            // type is declared.
1050                            self.write_wrapped_convert_from_std140_compat_type(
1051                                ir_module, member.ty,
1052                            )?;
1053                            match self.std140_compat_uniform_types.get(&member.ty) {
1054                                Some(std140_type_info) => {
1055                                    let std140_member_id = self.id_gen.next();
1056                                    block.body.push(Instruction::composite_extract(
1057                                        std140_type_info.type_id,
1058                                        std140_member_id,
1059                                        param_id,
1060                                        &[next_index],
1061                                    ));
1062                                    let function_id = self.wrapped_functions
1063                                        [&WrappedFunction::ConvertFromStd140CompatType {
1064                                            r#type: member.ty,
1065                                        }];
1066                                    block.body.push(Instruction::function_call(
1067                                        member_type_id,
1068                                        member_id,
1069                                        function_id,
1070                                        &[std140_member_id],
1071                                    ));
1072                                    next_index += 1;
1073                                }
1074                                None => {
1075                                    block.body.push(Instruction::composite_extract(
1076                                        member_type_id,
1077                                        member_id,
1078                                        param_id,
1079                                        &[next_index],
1080                                    ));
1081                                    next_index += 1;
1082                                }
1083                            }
1084                        }
1085                    }
1086                    member_ids.push(member_id);
1087                }
1088                let result_id = self.id_gen.next();
1089                block.body.push(Instruction::composite_construct(
1090                    return_type_id,
1091                    result_id,
1092                    &member_ids,
1093                ));
1094                result_id
1095            }
1096            _ => unreachable!(),
1097        };
1098
1099        function.consume(block, Instruction::return_value(result_id));
1100        function.to_words(&mut self.logical_layout.function_definitions);
1101        Ok(())
1102    }
1103
    /// Writes a wrapper function to get an `OpTypeVector` column from an
    /// `OpTypeMatrix` with a dynamic index.
    ///
    /// This is used when accessing a column of a [`TypeInner::Matrix`] through
    /// a [`Uniform`] address space pointer. In such cases, the matrix will have
    /// been declared in SPIR-V using an alternative type where each column is a
    /// member of a containing struct. SPIR-V is unable to dynamically access
    /// struct members, so instead we load the matrix then call this function to
    /// access a column from the loaded value.
    ///
    /// The generated function takes the matrix value and a `u32` column index,
    /// and returns the selected column vector. It is an `OpSwitch` on the
    /// index with one case per column, each extracting its column with a
    /// constant index, joined by an `OpPhi` in the merge block. An
    /// out-of-bounds index is handled according to the writer's index
    /// [`BoundsCheckPolicy`]:
    ///
    /// - `Restrict`: returns the last column.
    /// - `ReadZeroSkipWrite`: returns a null (zero) vector.
    /// - `Unchecked`: branches to a block containing `OpUnreachable`.
    ///
    /// Does nothing if the function has already been emitted for `r#type`.
    ///
    /// # Panics
    ///
    /// Panics if `r#type` is not a matrix with two rows.
    ///
    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
    /// [`Uniform`]: crate::AddressSpace::Uniform
    /// [`BoundsCheckPolicy`]: crate::proc::BoundsCheckPolicy
    fn write_wrapped_matcx2_get_column(
        &mut self,
        ir_module: &crate::Module,
        r#type: Handle<crate::Type>,
    ) -> Result<(), Error> {
        // Check if we've already emitted this function.
        let wrapped = WrappedFunction::MatCx2GetColumn { r#type };
        let function_id = match self.wrapped_functions.entry(wrapped) {
            Entry::Occupied(_) => return Ok(()),
            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
        };
        if self.flags.contains(WriterFlags::DEBUG) {
            self.debugs.push(Instruction::name(
                function_id,
                &format!("{:?}_get_column", r#type.for_debug(&ir_module.types)),
            ));
        }

        // Callers must only request this wrapper for two-row matrices.
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = ir_module.types[r#type].inner
        else {
            unreachable!();
        };

        // Signature: (matrix, u32 column index) -> column vector.
        let mut function = Function::default();
        let matrix_type_id = self.get_handle_type_id(r#type);
        let column_index_type_id = self.get_u32_type_id();
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let matrix_param_id = self.id_gen.next();
        let column_index_param_id = self.id_gen.next();
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(matrix_type_id, matrix_param_id),
            handle_id: 0,
        });
        function.parameters.push(FunctionArgument {
            instruction: Instruction::function_parameter(
                column_index_type_id,
                column_index_param_id,
            ),
            handle_id: 0,
        });
        let function_type_id = self.get_function_type(LookupFunctionType {
            parameter_type_ids: vec![matrix_type_id, column_index_type_id],
            return_type_id: column_type_id,
        });
        function.signature = Some(Instruction::function(
            column_type_id,
            function_id,
            spirv::FunctionControl::empty(),
            function_type_id,
        ));

        let label_id = self.id_gen.next();
        let mut block = Block::new(label_id);

        // Create a switch case for each column in the matrix, where each case
        // extracts its column from the matrix. Finally we use OpPhi to return
        // the correct column.
        let merge_id = self.id_gen.next();
        block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));
        let cases = (0..columns as u32)
            .map(|i| super::instructions::Case {
                value: i,
                label_id: self.id_gen.next(),
            })
            .collect::<ArrayVec<_, 4>>();

        // Which label we branch to in the default (column index out-of-bounds)
        // case depends on our bounds check policy.
        let default_id = match self.bounds_check_policies.index {
            // For `Restrict`, treat the same as the final column.
            crate::proc::BoundsCheckPolicy::Restrict => cases.last().unwrap().label_id,
            // For `ReadZeroSkipWrite`, branch directly to the merge block. This
            // will be handled in the `OpPhi` below to produce a zero value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => merge_id,
            // For `Unchecked` we create a new block containing an
            // `OpUnreachable`.
            crate::proc::BoundsCheckPolicy::Unchecked => self.id_gen.next(),
        };
        function.consume(
            block,
            Instruction::switch(column_index_param_id, default_id, &cases),
        );

        // Emit a block for each case, and produce a list of variable and parent
        // block IDs that will be used in an `OpPhi` below to select the right
        // value.
        let mut var_parent_pairs = cases
            .into_iter()
            .map(|case| {
                let mut block = Block::new(case.label_id);
                let column_id = self.id_gen.next();
                block.body.push(Instruction::composite_extract(
                    column_type_id,
                    column_id,
                    matrix_param_id,
                    &[case.value],
                ));
                function.consume(block, Instruction::branch(merge_id));
                (column_id, case.label_id)
            })
            // Need capacity for up to 4 columns plus possibly a default case.
            .collect::<ArrayVec<_, 5>>();

        // Emit a block or append the variable and parent `OpPhi` pair for the
        // column index out-of-bounds case, if required.
        match self.bounds_check_policies.index {
            // Don't need to do anything for `Restrict` as we have branched from
            // the final column case's block.
            crate::proc::BoundsCheckPolicy::Restrict => {}
            // For `ReadZeroSkipWrite` we have branched directly from the block
            // containing the `OpSwitch`. The `OpPhi` should produce a zero
            // value.
            crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
                var_parent_pairs.push((self.get_constant_null(column_type_id), label_id));
            }
            // For `Unchecked` create a new block containing `OpUnreachable`.
            // This does not need to be handled by the `OpPhi`.
            crate::proc::BoundsCheckPolicy::Unchecked => {
                function.consume(
                    Block::new(default_id),
                    Instruction::new(spirv::Op::Unreachable),
                );
            }
        }

        // Merge block: select the incoming column value and return it.
        let mut block = Block::new(merge_id);
        let result_id = self.id_gen.next();
        block.body.push(Instruction::phi(
            column_type_id,
            result_id,
            &var_parent_pairs,
        ));

        function.consume(block, Instruction::return_value(result_id));
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(())
    }
1259
1260    fn write_function(
1261        &mut self,
1262        ir_function: &crate::Function,
1263        info: &FunctionInfo,
1264        ir_module: &crate::Module,
1265        mut interface: Option<FunctionInterface>,
1266        debug_info: &Option<DebugInfoInner>,
1267    ) -> Result<Word, Error> {
1268        self.write_wrapped_functions(ir_function, info, ir_module)?;
1269
1270        log::trace!("Generating code for {:?}", ir_function.name);
1271        let mut function = Function::default();
1272
1273        let prelude_id = self.id_gen.next();
1274        let mut prelude = Block::new(prelude_id);
1275        let mut ep_context = EntryPointContext {
1276            argument_ids: Vec::new(),
1277            results: Vec::new(),
1278            task_payload_variable_id: if let Some(ref i) = interface {
1279                i.task_payload.map(|a| self.global_variables[a].var_id)
1280            } else {
1281                None
1282            },
1283            mesh_state: None,
1284        };
1285
1286        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
1287
1288        let mut local_invocation_index_var_id = None;
1289        let mut local_invocation_index_id = None;
1290
1291        for argument in ir_function.arguments.iter() {
1292            let class = spirv::StorageClass::Input;
1293            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
1294            let argument_type_id = if handle_ty {
1295                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
1296            } else {
1297                self.get_handle_type_id(argument.ty)
1298            };
1299
1300            if let Some(ref mut iface) = interface {
1301                let id = if let Some(ref binding) = argument.binding {
1302                    let name = argument.name.as_deref();
1303
1304                    let varying_id = self.write_varying(
1305                        ir_module,
1306                        iface.stage,
1307                        class,
1308                        name,
1309                        argument.ty,
1310                        binding,
1311                    )?;
1312                    iface.varying_ids.push(varying_id);
1313                    let id = self.load_io_with_f16_polyfill(
1314                        &mut prelude.body,
1315                        varying_id,
1316                        argument_type_id,
1317                    );
1318                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
1319                        local_invocation_index_id = Some(id);
1320                        local_invocation_index_var_id = Some(varying_id);
1321                    }
1322
1323                    id
1324                } else if let crate::TypeInner::Struct { ref members, .. } =
1325                    ir_module.types[argument.ty].inner
1326                {
1327                    let struct_id = self.id_gen.next();
1328                    let mut constituent_ids = Vec::with_capacity(members.len());
1329                    for member in members {
1330                        let type_id = self.get_handle_type_id(member.ty);
1331                        let name = member.name.as_deref();
1332                        let binding = member.binding.as_ref().unwrap();
1333                        let varying_id = self.write_varying(
1334                            ir_module,
1335                            iface.stage,
1336                            class,
1337                            name,
1338                            member.ty,
1339                            binding,
1340                        )?;
1341                        iface.varying_ids.push(varying_id);
1342                        let id =
1343                            self.load_io_with_f16_polyfill(&mut prelude.body, varying_id, type_id);
1344                        constituent_ids.push(id);
1345                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)
1346                        {
1347                            local_invocation_index_id = Some(id);
1348                            local_invocation_index_var_id = Some(varying_id);
1349                        }
1350                    }
1351                    prelude.body.push(Instruction::composite_construct(
1352                        argument_type_id,
1353                        struct_id,
1354                        &constituent_ids,
1355                    ));
1356                    struct_id
1357                } else {
1358                    unreachable!("Missing argument binding on an entry point");
1359                };
1360                ep_context.argument_ids.push(id);
1361            } else {
1362                let argument_id = self.id_gen.next();
1363                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
1364                if self.flags.contains(WriterFlags::DEBUG) {
1365                    if let Some(ref name) = argument.name {
1366                        self.debugs.push(Instruction::name(argument_id, name));
1367                    }
1368                }
1369                function.parameters.push(FunctionArgument {
1370                    instruction,
1371                    handle_id: if handle_ty {
1372                        let id = self.id_gen.next();
1373                        prelude.body.push(Instruction::load(
1374                            self.get_handle_type_id(argument.ty),
1375                            id,
1376                            argument_id,
1377                            None,
1378                        ));
1379                        id
1380                    } else {
1381                        0
1382                    },
1383                });
1384                parameter_type_ids.push(argument_type_id);
1385            };
1386        }
1387
1388        let return_type_id = match ir_function.result {
1389            Some(ref result) => {
1390                if let Some(ref mut iface) = interface {
1391                    let mut has_point_size = false;
1392                    let class = spirv::StorageClass::Output;
1393                    if let Some(ref binding) = result.binding {
1394                        has_point_size |=
1395                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1396                        let type_id = self.get_handle_type_id(result.ty);
1397                        let varying_id =
1398                            if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) {
1399                                0
1400                            } else {
1401                                let varying_id = self.write_varying(
1402                                    ir_module,
1403                                    iface.stage,
1404                                    class,
1405                                    None,
1406                                    result.ty,
1407                                    binding,
1408                                )?;
1409                                iface.varying_ids.push(varying_id);
1410                                varying_id
1411                            };
1412                        ep_context.results.push(ResultMember {
1413                            id: varying_id,
1414                            type_id,
1415                            built_in: binding.to_built_in(),
1416                        });
1417                    } else if let crate::TypeInner::Struct { ref members, .. } =
1418                        ir_module.types[result.ty].inner
1419                    {
1420                        for member in members {
1421                            let type_id = self.get_handle_type_id(member.ty);
1422                            let name = member.name.as_deref();
1423                            let binding = member.binding.as_ref().unwrap();
1424                            has_point_size |=
1425                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
1426                            // This isn't an actual builtin in SPIR-V. It can only appear as the
1427                            // output of a task shader and the output is used when writing the
1428                            // entry point return, in which case the id is ignored anyway.
1429                            let varying_id = if *binding
1430                                == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize)
1431                            {
1432                                0
1433                            } else {
1434                                let varying_id = self.write_varying(
1435                                    ir_module,
1436                                    iface.stage,
1437                                    class,
1438                                    name,
1439                                    member.ty,
1440                                    binding,
1441                                )?;
1442                                iface.varying_ids.push(varying_id);
1443                                varying_id
1444                            };
1445                            ep_context.results.push(ResultMember {
1446                                id: varying_id,
1447                                type_id,
1448                                built_in: binding.to_built_in(),
1449                            });
1450                        }
1451                    } else {
1452                        unreachable!("Missing result binding on an entry point");
1453                    }
1454
1455                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
1456                        && iface.stage == crate::ShaderStage::Vertex
1457                        && !has_point_size
1458                    {
1459                        // add point size artificially
1460                        let varying_id = self.id_gen.next();
1461                        let pointer_type_id = self.get_f32_pointer_type_id(class);
1462                        Instruction::variable(pointer_type_id, varying_id, class, None)
1463                            .to_words(&mut self.logical_layout.declarations);
1464                        self.decorate(
1465                            varying_id,
1466                            spirv::Decoration::BuiltIn,
1467                            &[spirv::BuiltIn::PointSize as u32],
1468                        );
1469                        iface.varying_ids.push(varying_id);
1470
1471                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
1472                        prelude
1473                            .body
1474                            .push(Instruction::store(varying_id, default_value_id, None));
1475                    }
1476                    if iface.stage == crate::ShaderStage::Task {
1477                        self.get_vec3u_type_id()
1478                    } else {
1479                        self.void_type
1480                    }
1481                } else {
1482                    self.get_handle_type_id(result.ty)
1483                }
1484            }
1485            None => self.void_type,
1486        };
1487
1488        if let Some(ref mut iface) = interface {
1489            if let Some(task_payload) = iface.task_payload {
1490                iface
1491                    .varying_ids
1492                    .push(self.global_variables[task_payload].var_id);
1493            }
1494            self.write_entry_point_mesh_shader_info(
1495                iface,
1496                local_invocation_index_var_id,
1497                ir_module,
1498                &mut ep_context,
1499            )?;
1500        }
1501
1502        let lookup_function_type = LookupFunctionType {
1503            parameter_type_ids,
1504            return_type_id,
1505        };
1506
1507        let function_id = self.id_gen.next();
1508        if self.flags.contains(WriterFlags::DEBUG) {
1509            if let Some(ref name) = ir_function.name {
1510                self.debugs.push(Instruction::name(function_id, name));
1511            }
1512        }
1513
1514        let function_type = self.get_function_type(lookup_function_type);
1515        function.signature = Some(Instruction::function(
1516            return_type_id,
1517            function_id,
1518            spirv::FunctionControl::empty(),
1519            function_type,
1520        ));
1521
1522        if interface.is_some() {
1523            function.entry_point_context = Some(ep_context);
1524        }
1525
1526        // fill up the `GlobalVariable::access_id`
1527        for gv in self.global_variables.iter_mut() {
1528            gv.reset_for_function();
1529        }
1530        for (handle, var) in ir_module.global_variables.iter() {
1531            if info[handle].is_empty() {
1532                continue;
1533            }
1534
1535            let mut gv = self.global_variables[handle].clone();
1536            if let Some(ref mut iface) = interface {
1537                // Have to include global variables in the interface
1538                if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) {
1539                    iface.varying_ids.push(gv.var_id);
1540                }
1541            }
1542
1543            match ir_module.types[var.ty].inner {
1544                // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
1545                crate::TypeInner::BindingArray { .. } => {
1546                    gv.access_id = gv.var_id;
1547                }
1548                _ => {
1549                    // Handle globals are pre-emitted and should be loaded automatically.
1550                    if var.space == crate::AddressSpace::Handle {
1551                        let var_type_id = self.get_handle_type_id(var.ty);
1552                        let id = self.id_gen.next();
1553                        prelude
1554                            .body
1555                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
1556                        gv.access_id = gv.var_id;
1557                        gv.handle_id = id;
1558                    } else if global_needs_wrapper(ir_module, var) {
1559                        let class = map_storage_class(var.space);
1560                        let pointer_type_id = match self.std140_compat_uniform_types.get(&var.ty) {
1561                            Some(std140_type_info) if var.space == crate::AddressSpace::Uniform => {
1562                                self.get_pointer_type_id(std140_type_info.type_id, class)
1563                            }
1564                            _ => self.get_handle_pointer_type_id(var.ty, class),
1565                        };
1566                        let index_id = self.get_index_constant(0);
1567                        let id = self.id_gen.next();
1568                        prelude.body.push(Instruction::access_chain(
1569                            pointer_type_id,
1570                            id,
1571                            gv.var_id,
1572                            &[index_id],
1573                        ));
1574                        gv.access_id = id;
1575                    } else {
1576                        // by default, the variable ID is accessed as is
1577                        gv.access_id = gv.var_id;
1578                    };
1579                }
1580            }
1581
1582            // work around borrow checking in the presence of `self.xxx()` calls
1583            self.global_variables[handle] = gv;
1584        }
1585
1586        // Create a `BlockContext` for generating SPIR-V for the function's
1587        // body.
1588        let mut context = BlockContext {
1589            ir_module,
1590            ir_function,
1591            fun_info: info,
1592            function: &mut function,
1593            // Re-use the cached expression table from prior functions.
1594            cached: core::mem::take(&mut self.saved_cached),
1595
1596            // Steal the Writer's temp list for a bit.
1597            temp_list: core::mem::take(&mut self.temp_list),
1598            force_loop_bounding: self.force_loop_bounding,
1599            writer: self,
1600            expression_constness: super::ExpressionConstnessTracker::from_arena(
1601                &ir_function.expressions,
1602            ),
1603            ray_query_tracker_expr: crate::FastHashMap::default(),
1604        };
1605
1606        // fill up the pre-emitted and const expressions
1607        context.cached.reset(ir_function.expressions.len());
1608        for (handle, expr) in ir_function.expressions.iter() {
1609            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
1610                || context.expression_constness.is_const(handle)
1611            {
1612                context.cache_expression_value(handle, &mut prelude)?;
1613            }
1614        }
1615
1616        for (handle, variable) in ir_function.local_variables.iter() {
1617            let id = context.gen_id();
1618
1619            if context.writer.flags.contains(WriterFlags::DEBUG) {
1620                if let Some(ref name) = variable.name {
1621                    context.writer.debugs.push(Instruction::name(id, name));
1622                }
1623            }
1624
1625            let init_word = variable.init.map(|constant| context.cached[constant]);
1626            let pointer_type_id = context
1627                .writer
1628                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1629            let instruction = Instruction::variable(
1630                pointer_type_id,
1631                id,
1632                spirv::StorageClass::Function,
1633                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1634                    crate::TypeInner::RayQuery { .. } => None,
1635                    _ => {
1636                        let type_id = context.get_handle_type_id(variable.ty);
1637                        Some(context.writer.write_constant_null(type_id))
1638                    }
1639                }),
1640            );
1641
1642            context
1643                .function
1644                .variables
1645                .insert(handle, LocalVariable { id, instruction });
1646
1647            if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner {
1648                // Don't refactor this into a struct: Although spirv itself allows opaque types in structs,
1649                // the vulkan environment for spirv does not. Putting ray queries into structs can cause
1650                // confusing bugs.
1651                let u32_type_id = context.writer.get_u32_type_id();
1652                let ptr_u32_type_id = context
1653                    .writer
1654                    .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function);
1655                let tracker_id = context.gen_id();
1656                let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32(
1657                    crate::back::RayQueryPoint::empty().bits(),
1658                ));
1659                let tracker_instruction = Instruction::variable(
1660                    ptr_u32_type_id,
1661                    tracker_id,
1662                    spirv::StorageClass::Function,
1663                    Some(tracker_init_id),
1664                );
1665
1666                context
1667                    .function
1668                    .ray_query_initialization_tracker_variables
1669                    .insert(
1670                        handle,
1671                        LocalVariable {
1672                            id: tracker_id,
1673                            instruction: tracker_instruction,
1674                        },
1675                    );
1676                let f32_type_id = context.writer.get_f32_type_id();
1677                let ptr_f32_type_id = context
1678                    .writer
1679                    .get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1680                let t_max_tracker_id = context.gen_id();
1681                let t_max_tracker_init_id =
1682                    context.writer.get_constant_scalar(crate::Literal::F32(0.0));
1683                let t_max_tracker_instruction = Instruction::variable(
1684                    ptr_f32_type_id,
1685                    t_max_tracker_id,
1686                    spirv::StorageClass::Function,
1687                    Some(t_max_tracker_init_id),
1688                );
1689
1690                context.function.ray_query_t_max_tracker_variables.insert(
1691                    handle,
1692                    LocalVariable {
1693                        id: t_max_tracker_id,
1694                        instruction: t_max_tracker_instruction,
1695                    },
1696                );
1697            }
1698        }
1699
1700        for (handle, expr) in ir_function.expressions.iter() {
1701            match *expr {
1702                crate::Expression::LocalVariable(_) => {
1703                    // Cache the `OpVariable` instruction we generated above as
1704                    // the value of this expression.
1705                    context.cache_expression_value(handle, &mut prelude)?;
1706                }
1707                crate::Expression::Access { base, .. }
1708                | crate::Expression::AccessIndex { base, .. } => {
1709                    // Count references to `base` by `Access` and `AccessIndex`
1710                    // instructions. See `access_uses` for details.
1711                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1712                }
1713                _ => {}
1714            }
1715        }
1716
1717        let next_id = context.gen_id();
1718
1719        context
1720            .function
1721            .consume(prelude, Instruction::branch(next_id));
1722
1723        let workgroup_vars_init_exit_block_id =
1724            match (context.writer.zero_initialize_workgroup_memory, interface) {
1725                (
1726                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1727                    Some(
1728                        ref mut interface @ FunctionInterface {
1729                            stage:
1730                                crate::ShaderStage::Compute
1731                                | crate::ShaderStage::Mesh
1732                                | crate::ShaderStage::Task,
1733                            ..
1734                        },
1735                    ),
1736                ) => context.writer.generate_workgroup_vars_init_block(
1737                    next_id,
1738                    ir_module,
1739                    info,
1740                    local_invocation_index_id,
1741                    interface,
1742                    context.function,
1743                ),
1744                _ => None,
1745            };
1746
1747        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1748            exit_id
1749        } else {
1750            next_id
1751        };
1752
1753        context.write_function_body(main_id, debug_info.as_ref())?;
1754
1755        // Consume the `BlockContext`, ending its borrows and letting the
1756        // `Writer` steal back its cached expression table and temp_list.
1757        let BlockContext {
1758            cached, temp_list, ..
1759        } = context;
1760        self.saved_cached = cached;
1761        self.temp_list = temp_list;
1762
1763        function.to_words(&mut self.logical_layout.function_definitions);
1764
1765        if let Some(EntryPointContext {
1766            mesh_state: Some(ref mesh_state),
1767            ..
1768        }) = function.entry_point_context
1769        {
1770            self.write_mesh_shader_wrapper(mesh_state, function_id)
1771        } else if let Some(EntryPointContext {
1772            task_payload_variable_id: Some(tp),
1773            ..
1774        }) = function.entry_point_context
1775        {
1776            self.write_task_shader_wrapper(tp, function_id)
1777        } else {
1778            Ok(function_id)
1779        }
1780    }
1781
1782    fn write_execution_mode(
1783        &mut self,
1784        function_id: Word,
1785        mode: spirv::ExecutionMode,
1786    ) -> Result<(), Error> {
1787        //self.check(mode.required_capabilities())?;
1788        Instruction::execution_mode(function_id, mode, &[])
1789            .to_words(&mut self.logical_layout.execution_modes);
1790        Ok(())
1791    }
1792
1793    // TODO Move to instructions module
1794    fn write_entry_point(
1795        &mut self,
1796        entry_point: &crate::EntryPoint,
1797        info: &FunctionInfo,
1798        ir_module: &crate::Module,
1799        debug_info: &Option<DebugInfoInner>,
1800    ) -> Result<Instruction, Error> {
1801        let mut interface_ids = Vec::new();
1802
1803        let function_id = self.write_function(
1804            &entry_point.function,
1805            info,
1806            ir_module,
1807            Some(FunctionInterface {
1808                varying_ids: &mut interface_ids,
1809                stage: entry_point.stage,
1810                task_payload: entry_point.task_payload,
1811                mesh_info: entry_point.mesh_info.clone(),
1812                workgroup_size: entry_point.workgroup_size,
1813            }),
1814            debug_info,
1815        )?;
1816
1817        let exec_model = match entry_point.stage {
1818            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1819            crate::ShaderStage::Fragment => {
1820                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1821                match entry_point.early_depth_test {
1822                    Some(crate::EarlyDepthTest::Force) => {
1823                        self.write_execution_mode(
1824                            function_id,
1825                            spirv::ExecutionMode::EarlyFragmentTests,
1826                        )?;
1827                    }
1828                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1829                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1830                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1831                        // This permits early depth tests even if the shader writes to a storage
1832                        // binding
1833                        match conservative {
1834                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1835                                function_id,
1836                                spirv::ExecutionMode::DepthGreater,
1837                            )?,
1838                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1839                                function_id,
1840                                spirv::ExecutionMode::DepthLess,
1841                            )?,
1842                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1843                                function_id,
1844                                spirv::ExecutionMode::DepthUnchanged,
1845                            )?,
1846                        }
1847                    }
1848                    None => {}
1849                }
1850                if let Some(ref result) = entry_point.function.result {
1851                    if contains_builtin(
1852                        result.binding.as_ref(),
1853                        result.ty,
1854                        &ir_module.types,
1855                        crate::BuiltIn::FragDepth,
1856                    ) {
1857                        self.write_execution_mode(
1858                            function_id,
1859                            spirv::ExecutionMode::DepthReplacing,
1860                        )?;
1861                    }
1862                }
1863                spirv::ExecutionModel::Fragment
1864            }
1865            crate::ShaderStage::Compute => {
1866                let execution_mode = spirv::ExecutionMode::LocalSize;
1867                Instruction::execution_mode(
1868                    function_id,
1869                    execution_mode,
1870                    &entry_point.workgroup_size,
1871                )
1872                .to_words(&mut self.logical_layout.execution_modes);
1873                spirv::ExecutionModel::GLCompute
1874            }
1875            crate::ShaderStage::Task => {
1876                let execution_mode = spirv::ExecutionMode::LocalSize;
1877                Instruction::execution_mode(
1878                    function_id,
1879                    execution_mode,
1880                    &entry_point.workgroup_size,
1881                )
1882                .to_words(&mut self.logical_layout.execution_modes);
1883                spirv::ExecutionModel::TaskEXT
1884            }
1885            crate::ShaderStage::Mesh => {
1886                let execution_mode = spirv::ExecutionMode::LocalSize;
1887                Instruction::execution_mode(
1888                    function_id,
1889                    execution_mode,
1890                    &entry_point.workgroup_size,
1891                )
1892                .to_words(&mut self.logical_layout.execution_modes);
1893                let mesh_info = entry_point.mesh_info.as_ref().unwrap();
1894                Instruction::execution_mode(
1895                    function_id,
1896                    match mesh_info.topology {
1897                        crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints,
1898                        crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT,
1899                        crate::MeshOutputTopology::Triangles => {
1900                            spirv::ExecutionMode::OutputTrianglesEXT
1901                        }
1902                    },
1903                    &[],
1904                )
1905                .to_words(&mut self.logical_layout.execution_modes);
1906                Instruction::execution_mode(
1907                    function_id,
1908                    spirv::ExecutionMode::OutputVertices,
1909                    core::slice::from_ref(&mesh_info.max_vertices),
1910                )
1911                .to_words(&mut self.logical_layout.execution_modes);
1912                Instruction::execution_mode(
1913                    function_id,
1914                    spirv::ExecutionMode::OutputPrimitivesEXT,
1915                    core::slice::from_ref(&mesh_info.max_primitives),
1916                )
1917                .to_words(&mut self.logical_layout.execution_modes);
1918                spirv::ExecutionModel::MeshEXT
1919            }
1920            crate::ShaderStage::RayGeneration => {
1921                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1922                spirv::ExecutionModel::RayGenerationKHR
1923            }
1924            crate::ShaderStage::AnyHit => {
1925                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1926                spirv::ExecutionModel::AnyHitKHR
1927            }
1928            crate::ShaderStage::ClosestHit => {
1929                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1930                spirv::ExecutionModel::ClosestHitKHR
1931            }
1932            crate::ShaderStage::Miss => {
1933                self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
1934                spirv::ExecutionModel::MissKHR
1935            }
1936        };
1937        //self.check(exec_model.required_capabilities())?;
1938
1939        Ok(Instruction::entry_point(
1940            exec_model,
1941            function_id,
1942            &entry_point.name,
1943            interface_ids.as_slice(),
1944        ))
1945    }
1946
1947    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1948        use crate::ScalarKind as Sk;
1949
1950        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1951        match scalar.kind {
1952            Sk::Sint | Sk::Uint => {
1953                let signedness = if scalar.kind == Sk::Sint {
1954                    super::instructions::Signedness::Signed
1955                } else {
1956                    super::instructions::Signedness::Unsigned
1957                };
1958                let cap = match bits {
1959                    8 => Some(spirv::Capability::Int8),
1960                    16 => Some(spirv::Capability::Int16),
1961                    64 => Some(spirv::Capability::Int64),
1962                    _ => None,
1963                };
1964                if let Some(cap) = cap {
1965                    self.capabilities_used.insert(cap);
1966                }
1967                if bits == 16 {
1968                    self.capabilities_used
1969                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1970                    self.capabilities_used
1971                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1972                    if self.use_storage_input_output_16 {
1973                        self.capabilities_used
1974                            .insert(spirv::Capability::StorageInputOutput16);
1975                    }
1976                }
1977                Instruction::type_int(id, bits, signedness)
1978            }
1979            Sk::Float => {
1980                if bits == 64 {
1981                    self.capabilities_used.insert(spirv::Capability::Float64);
1982                }
1983                if bits == 16 {
1984                    self.capabilities_used.insert(spirv::Capability::Float16);
1985                    self.capabilities_used
1986                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1987                    self.capabilities_used
1988                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1989                    if self.use_storage_input_output_16 {
1990                        self.capabilities_used
1991                            .insert(spirv::Capability::StorageInputOutput16);
1992                    }
1993                }
1994                Instruction::type_float(id, bits)
1995            }
1996            Sk::Bool => Instruction::type_bool(id),
1997            Sk::AbstractInt | Sk::AbstractFloat => {
1998                unreachable!("abstract types should never reach the backend");
1999            }
2000        }
2001    }
2002
2003    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
2004        match *inner {
2005            crate::TypeInner::Image {
2006                dim,
2007                arrayed,
2008                class,
2009            } => {
2010                let sampled = match class {
2011                    crate::ImageClass::Sampled { .. } => true,
2012                    crate::ImageClass::Depth { .. } => true,
2013                    crate::ImageClass::Storage { format, .. } => {
2014                        self.request_image_format_capabilities(format.into())?;
2015                        false
2016                    }
2017                    crate::ImageClass::External => unimplemented!(),
2018                };
2019
2020                match dim {
2021                    crate::ImageDimension::D1 => {
2022                        if sampled {
2023                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
2024                        } else {
2025                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
2026                        }
2027                    }
2028                    crate::ImageDimension::Cube if arrayed => {
2029                        if sampled {
2030                            self.require_any(
2031                                "sampled cube array images",
2032                                &[spirv::Capability::SampledCubeArray],
2033                            )?;
2034                        } else {
2035                            self.require_any(
2036                                "cube array storage images",
2037                                &[spirv::Capability::ImageCubeArray],
2038                            )?;
2039                        }
2040                    }
2041                    _ => {}
2042                }
2043            }
2044            crate::TypeInner::AccelerationStructure { .. } => {
2045                self.require_any(
2046                    "Acceleration Structure",
2047                    // unless we use this conditional, the ray query snapshot
2048                    // tests pick the wrong capability
2049                    &[if self.has_ray_tracing_pipeline {
2050                        spirv::Capability::RayTracingKHR
2051                    } else {
2052                        spirv::Capability::RayQueryKHR
2053                    }],
2054                )?;
2055            }
2056            crate::TypeInner::RayQuery { .. } => {
2057                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
2058            }
2059            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
2060                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
2061            }
2062            crate::TypeInner::Atomic(crate::Scalar {
2063                width: 4,
2064                kind: crate::ScalarKind::Float,
2065            }) => {
2066                self.require_any(
2067                    "32 bit floating-point atomics",
2068                    &[spirv::Capability::AtomicFloat32AddEXT],
2069                )?;
2070                self.use_extension("SPV_EXT_shader_atomic_float_add");
2071            }
2072            // 16 bit floating-point support requires Float16 capability
2073            crate::TypeInner::Matrix {
2074                scalar: crate::Scalar::F16,
2075                ..
2076            }
2077            | crate::TypeInner::Vector {
2078                scalar: crate::Scalar::F16,
2079                ..
2080            }
2081            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
2082                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
2083                self.use_extension("SPV_KHR_16bit_storage");
2084            }
2085            // 16 bit integer support requires Int16 capability
2086            crate::TypeInner::Vector {
2087                scalar:
2088                    crate::Scalar {
2089                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2090                        width: 2,
2091                    },
2092                ..
2093            }
2094            | crate::TypeInner::Scalar(crate::Scalar {
2095                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2096                width: 2,
2097            }) => {
2098                self.require_any("16 bit integer", &[spirv::Capability::Int16])?;
2099                self.use_extension("SPV_KHR_16bit_storage");
2100            }
2101            // Cooperative types and ops
2102            crate::TypeInner::CooperativeMatrix { .. } => {
2103                self.require_any(
2104                    "cooperative matrix",
2105                    &[spirv::Capability::CooperativeMatrixKHR],
2106                )?;
2107                self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
2108                self.use_extension("SPV_KHR_cooperative_matrix");
2109                self.use_extension("SPV_KHR_vulkan_memory_model");
2110            }
2111            _ => {}
2112        }
2113        Ok(())
2114    }
2115
2116    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
2117        let instruction = match numeric {
2118            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
2119            NumericType::Vector { size, scalar } => {
2120                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
2121                Instruction::type_vector(id, scalar_id, size)
2122            }
2123            NumericType::Matrix {
2124                columns,
2125                rows,
2126                scalar,
2127            } => {
2128                let column_id =
2129                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2130                Instruction::type_matrix(id, column_id, columns)
2131            }
2132        };
2133
2134        instruction.to_words(&mut self.logical_layout.declarations);
2135    }
2136
2137    fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
2138        let instruction = match coop {
2139            CooperativeType::Matrix {
2140                columns,
2141                rows,
2142                scalar,
2143                role,
2144            } => {
2145                let scalar_id =
2146                    self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
2147                let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
2148                let columns_id = self.get_index_constant(columns as u32);
2149                let rows_id = self.get_index_constant(rows as u32);
2150                let role_id =
2151                    self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
2152                Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
2153            }
2154        };
2155
2156        instruction.to_words(&mut self.logical_layout.declarations);
2157    }
2158
2159    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
2160        let instruction = match local_ty {
2161            LocalType::Numeric(numeric) => {
2162                self.write_numeric_type_declaration_local(id, numeric);
2163                return;
2164            }
2165            LocalType::Cooperative(coop) => {
2166                self.write_cooperative_type_declaration_local(id, coop);
2167                return;
2168            }
2169            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
2170            LocalType::Image(image) => {
2171                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
2172                let type_id = self.get_localtype_id(local_type);
2173                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
2174            }
2175            LocalType::Sampler => Instruction::type_sampler(id),
2176            LocalType::SampledImage { image_type_id } => {
2177                Instruction::type_sampled_image(id, image_type_id)
2178            }
2179            LocalType::BindingArray { base, size } => {
2180                let inner_ty = self.get_handle_type_id(base);
2181                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
2182                Instruction::type_array(id, inner_ty, scalar_id)
2183            }
2184            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
2185            LocalType::RayQuery => Instruction::type_ray_query(id),
2186        };
2187
2188        instruction.to_words(&mut self.logical_layout.declarations);
2189    }
2190
2191    fn write_type_declaration_arena(
2192        &mut self,
2193        module: &crate::Module,
2194        handle: Handle<crate::Type>,
2195    ) -> Result<Word, Error> {
2196        let ty = &module.types[handle];
2197        // If it's a type that needs SPIR-V capabilities, request them now.
2198        // This needs to happen regardless of the LocalType lookup succeeding,
2199        // because some types which map to the same LocalType have different
2200        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
2201        self.request_type_capabilities(&ty.inner)?;
2202        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
2203            // This type can be represented as a `LocalType`, so check if we've
2204            // already written an instruction for it. If not, do so now, with
2205            // `write_type_declaration_local`.
2206            match self.lookup_type.entry(LookupType::Local(local)) {
2207                // We already have an id for this `LocalType`.
2208                Entry::Occupied(e) => *e.get(),
2209
2210                // It's a type we haven't seen before.
2211                Entry::Vacant(e) => {
2212                    let id = self.id_gen.next();
2213                    e.insert(id);
2214
2215                    self.write_type_declaration_local(id, local);
2216
2217                    id
2218                }
2219            }
2220        } else {
2221            use spirv::Decoration;
2222
2223            let id = self.id_gen.next();
2224            let instruction = match ty.inner {
2225                crate::TypeInner::Array { base, size, stride } => {
2226                    self.decorate(id, Decoration::ArrayStride, &[stride]);
2227
2228                    let type_id = self.get_handle_type_id(base);
2229                    match size.resolve(module.to_ctx())? {
2230                        crate::proc::IndexableLength::Known(length) => {
2231                            let length_id = self.get_index_constant(length);
2232                            Instruction::type_array(id, type_id, length_id)
2233                        }
2234                        crate::proc::IndexableLength::Dynamic => {
2235                            Instruction::type_runtime_array(id, type_id)
2236                        }
2237                    }
2238                }
2239                crate::TypeInner::BindingArray { base, size } => {
2240                    let type_id = self.get_handle_type_id(base);
2241                    match size.resolve(module.to_ctx())? {
2242                        crate::proc::IndexableLength::Known(length) => {
2243                            let length_id = self.get_index_constant(length);
2244                            Instruction::type_array(id, type_id, length_id)
2245                        }
2246                        crate::proc::IndexableLength::Dynamic => {
2247                            Instruction::type_runtime_array(id, type_id)
2248                        }
2249                    }
2250                }
2251                crate::TypeInner::Struct {
2252                    ref members,
2253                    span: _,
2254                } => {
2255                    let mut has_runtime_array = false;
2256                    let mut member_ids = Vec::with_capacity(members.len());
2257                    for (index, member) in members.iter().enumerate() {
2258                        let member_ty = &module.types[member.ty];
2259                        match member_ty.inner {
2260                            crate::TypeInner::Array {
2261                                base: _,
2262                                size: crate::ArraySize::Dynamic,
2263                                stride: _,
2264                            } => {
2265                                has_runtime_array = true;
2266                            }
2267                            _ => (),
2268                        }
2269                        self.decorate_struct_member(id, index, member, &module.types)?;
2270                        let member_id = self.get_handle_type_id(member.ty);
2271                        member_ids.push(member_id);
2272                    }
2273                    if has_runtime_array {
2274                        self.decorate(id, Decoration::Block, &[]);
2275                    }
2276                    Instruction::type_struct(id, member_ids.as_slice())
2277                }
2278
2279                // These all have TypeLocal representations, so they should have been
2280                // handled by `write_type_declaration_local` above.
2281                crate::TypeInner::Scalar(_)
2282                | crate::TypeInner::Atomic(_)
2283                | crate::TypeInner::Vector { .. }
2284                | crate::TypeInner::Matrix { .. }
2285                | crate::TypeInner::CooperativeMatrix { .. }
2286                | crate::TypeInner::Pointer { .. }
2287                | crate::TypeInner::ValuePointer { .. }
2288                | crate::TypeInner::Image { .. }
2289                | crate::TypeInner::Sampler { .. }
2290                | crate::TypeInner::AccelerationStructure { .. }
2291                | crate::TypeInner::RayQuery { .. } => unreachable!(),
2292            };
2293
2294            instruction.to_words(&mut self.logical_layout.declarations);
2295            id
2296        };
2297
2298        // Add this handle as a new alias for that type.
2299        self.lookup_type.insert(LookupType::Handle(handle), id);
2300
2301        if self.flags.contains(WriterFlags::DEBUG) {
2302            if let Some(ref name) = ty.name {
2303                self.debugs.push(Instruction::name(id, name));
2304            }
2305        }
2306
2307        Ok(id)
2308    }
2309
2310    /// Writes a std140 layout compatible type declaration for a type. Returns
2311    /// the ID of the declared type, or None if no declaration is required.
2312    ///
2313    /// This should be called for any type for which there exists a
2314    /// [`GlobalVariable`] in the [`Uniform`] address space. If the type already
2315    /// adheres to std140 layout rules it will return without declaring any
2316    /// types. If the type contains another type which requires a std140
2317    /// compatible type declaration, it will recursively call itself.
2318    ///
2319    /// When `handle` refers to a [`TypeInner::Matrix`] with 2 rows, the
2320    /// declared type will be an `OpTypeStruct` containing an `OpVector` for
2321    /// each of the matrix's columns.
2322    ///
2323    /// When `handle` refers to a [`TypeInner::Array`] whose base type is a
2324    /// matrix with 2 rows, this will declare an `OpTypeArray` whose element
2325    /// type is the matrix's corresponding std140 compatible type.
2326    ///
2327    /// When `handle` refers to a [`TypeInner::Struct`] and any of its members
2328    /// require a std140 compatible type declaration, this will declare a new
2329    /// struct with the following rules:
2330    /// * Struct or array members will be declared with their std140 compatible
2331    ///   type declaration, if one is required.
2332    /// * Two-row matrix members will have each of their columns hoisted
2333    ///   directly into the struct as 2-component vector members.
2334    /// * All other members will be declared with their normal type.
2335    ///
2336    /// Note that this means the Naga IR index of a struct member may not match
2337    /// the index in the generated SPIR-V. The mapping can be obtained via
2338    /// `Std140TypeInfo::member_indices`.
2339    ///
2340    /// [`GlobalVariable`]: crate::GlobalVariable
2341    /// [`Uniform`]: crate::AddressSpace::Uniform
2342    /// [`TypeInner::Matrix`]: crate::TypeInner::Matrix
2343    /// [`TypeInner::Array`]: crate::TypeInner::Array
2344    /// [`TypeInner::Struct`]: crate::TypeInner::Struct
2345    fn write_std140_compat_type_declaration(
2346        &mut self,
2347        module: &crate::Module,
2348        handle: Handle<crate::Type>,
2349    ) -> Result<Option<Word>, Error> {
2350        if let Some(std140_type_info) = self.std140_compat_uniform_types.get(&handle) {
2351            return Ok(Some(std140_type_info.type_id));
2352        }
2353
2354        let type_inner = &module.types[handle].inner;
2355        let std140_type_id = match *type_inner {
2356            crate::TypeInner::Matrix {
2357                columns,
2358                rows: rows @ crate::VectorSize::Bi,
2359                scalar,
2360            } => {
2361                let std140_type_id = self.id_gen.next();
2362                let mut member_type_ids: ArrayVec<Word, 4> = ArrayVec::new();
2363                let column_type_id =
2364                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
2365                for column in 0..columns as u32 {
2366                    member_type_ids.push(column_type_id);
2367                    self.annotations.push(Instruction::member_decorate(
2368                        std140_type_id,
2369                        column,
2370                        spirv::Decoration::Offset,
2371                        &[column * rows as u32 * scalar.width as u32],
2372                    ));
2373                    if self.flags.contains(WriterFlags::DEBUG) {
2374                        self.debugs.push(Instruction::member_name(
2375                            std140_type_id,
2376                            column,
2377                            &format!("col{column}"),
2378                        ));
2379                    }
2380                }
2381                Instruction::type_struct(std140_type_id, &member_type_ids)
2382                    .to_words(&mut self.logical_layout.declarations);
2383                self.std140_compat_uniform_types.insert(
2384                    handle,
2385                    Std140CompatTypeInfo {
2386                        type_id: std140_type_id,
2387                        member_indices: Vec::new(),
2388                    },
2389                );
2390                Some(std140_type_id)
2391            }
2392            crate::TypeInner::Array { base, size, stride } => {
2393                match self.write_std140_compat_type_declaration(module, base)? {
2394                    Some(std140_base_type_id) => {
2395                        let std140_type_id = self.id_gen.next();
2396                        self.decorate(std140_type_id, spirv::Decoration::ArrayStride, &[stride]);
2397                        let instruction = match size.resolve(module.to_ctx())? {
2398                            crate::proc::IndexableLength::Known(length) => {
2399                                let length_id = self.get_index_constant(length);
2400                                Instruction::type_array(
2401                                    std140_type_id,
2402                                    std140_base_type_id,
2403                                    length_id,
2404                                )
2405                            }
2406                            crate::proc::IndexableLength::Dynamic => {
2407                                unreachable!()
2408                            }
2409                        };
2410                        instruction.to_words(&mut self.logical_layout.declarations);
2411                        self.std140_compat_uniform_types.insert(
2412                            handle,
2413                            Std140CompatTypeInfo {
2414                                type_id: std140_type_id,
2415                                member_indices: Vec::new(),
2416                            },
2417                        );
2418                        Some(std140_type_id)
2419                    }
2420                    None => None,
2421                }
2422            }
2423            crate::TypeInner::Struct { ref members, .. } => {
2424                let mut needs_std140_type = false;
2425                for member in members {
2426                    match module.types[member.ty].inner {
2427                        // We don't need to write a std140 type for the matrix itself as
2428                        // it will be decomposed into the parent struct. As a result, the
2429                        // struct does need a std140 type, however.
2430                        crate::TypeInner::Matrix {
2431                            rows: crate::VectorSize::Bi,
2432                            ..
2433                        } => needs_std140_type = true,
2434                        // If an array member needs a std140 type, because it is an array
2435                        // (of an array, etc) of `matCx2`s, then the struct also needs
2436                        // a std140 type which uses the std140 type for this member.
2437                        crate::TypeInner::Array { .. }
2438                            if self
2439                                .write_std140_compat_type_declaration(module, member.ty)?
2440                                .is_some() =>
2441                        {
2442                            needs_std140_type = true;
2443                        }
2444                        _ => {}
2445                    }
2446                }
2447
2448                if needs_std140_type {
2449                    let std140_type_id = self.id_gen.next();
2450                    let mut member_ids = Vec::new();
2451                    let mut member_indices = Vec::new();
2452                    let mut next_index = 0;
2453
2454                    for member in members {
2455                        member_indices.push(next_index);
2456                        match module.types[member.ty].inner {
2457                            crate::TypeInner::Matrix {
2458                                columns,
2459                                rows: rows @ crate::VectorSize::Bi,
2460                                scalar,
2461                            } => {
2462                                let vector_type_id =
2463                                    self.get_numeric_type_id(NumericType::Vector {
2464                                        size: rows,
2465                                        scalar,
2466                                    });
2467                                for column in 0..columns as u32 {
2468                                    self.annotations.push(Instruction::member_decorate(
2469                                        std140_type_id,
2470                                        next_index,
2471                                        spirv::Decoration::Offset,
2472                                        &[member.offset
2473                                            + column * rows as u32 * scalar.width as u32],
2474                                    ));
2475                                    if self.flags.contains(WriterFlags::DEBUG) {
2476                                        if let Some(ref name) = member.name {
2477                                            self.debugs.push(Instruction::member_name(
2478                                                std140_type_id,
2479                                                next_index,
2480                                                &format!("{name}_col{column}"),
2481                                            ));
2482                                        }
2483                                    }
2484                                    member_ids.push(vector_type_id);
2485                                    next_index += 1;
2486                                }
2487                            }
2488                            _ => {
2489                                let member_id =
2490                                    match self.std140_compat_uniform_types.get(&member.ty) {
2491                                        Some(std140_member_type_info) => {
2492                                            self.annotations.push(Instruction::member_decorate(
2493                                                std140_type_id,
2494                                                next_index,
2495                                                spirv::Decoration::Offset,
2496                                                &[member.offset],
2497                                            ));
2498                                            if self.flags.contains(WriterFlags::DEBUG) {
2499                                                if let Some(ref name) = member.name {
2500                                                    self.debugs.push(Instruction::member_name(
2501                                                        std140_type_id,
2502                                                        next_index,
2503                                                        name,
2504                                                    ));
2505                                                }
2506                                            }
2507                                            std140_member_type_info.type_id
2508                                        }
2509                                        None => {
2510                                            self.decorate_struct_member(
2511                                                std140_type_id,
2512                                                next_index as usize,
2513                                                member,
2514                                                &module.types,
2515                                            )?;
2516                                            self.get_handle_type_id(member.ty)
2517                                        }
2518                                    };
2519                                member_ids.push(member_id);
2520                                next_index += 1;
2521                            }
2522                        }
2523                    }
2524
2525                    Instruction::type_struct(std140_type_id, &member_ids)
2526                        .to_words(&mut self.logical_layout.declarations);
2527                    self.std140_compat_uniform_types.insert(
2528                        handle,
2529                        Std140CompatTypeInfo {
2530                            type_id: std140_type_id,
2531                            member_indices,
2532                        },
2533                    );
2534                    Some(std140_type_id)
2535                } else {
2536                    None
2537                }
2538            }
2539            _ => None,
2540        };
2541
2542        if let Some(std140_type_id) = std140_type_id {
2543            if self.flags.contains(WriterFlags::DEBUG) {
2544                let name = format!("std140_{:?}", handle.for_debug(&module.types));
2545                self.debugs.push(Instruction::name(std140_type_id, &name));
2546            }
2547        }
2548        Ok(std140_type_id)
2549    }
2550
2551    fn request_image_format_capabilities(
2552        &mut self,
2553        format: spirv::ImageFormat,
2554    ) -> Result<(), Error> {
2555        use spirv::ImageFormat as If;
2556        match format {
2557            If::Rg32f
2558            | If::Rg16f
2559            | If::R11fG11fB10f
2560            | If::R16f
2561            | If::Rgba16
2562            | If::Rgb10A2
2563            | If::Rg16
2564            | If::Rg8
2565            | If::R16
2566            | If::R8
2567            | If::Rgba16Snorm
2568            | If::Rg16Snorm
2569            | If::Rg8Snorm
2570            | If::R16Snorm
2571            | If::R8Snorm
2572            | If::Rg32i
2573            | If::Rg16i
2574            | If::Rg8i
2575            | If::R16i
2576            | If::R8i
2577            | If::Rgb10a2ui
2578            | If::Rg32ui
2579            | If::Rg16ui
2580            | If::Rg8ui
2581            | If::R16ui
2582            | If::R8ui => self.require_any(
2583                "storage image format",
2584                &[spirv::Capability::StorageImageExtendedFormats],
2585            ),
2586            If::R64ui | If::R64i => {
2587                self.use_extension("SPV_EXT_shader_image_int64");
2588                self.require_any(
2589                    "64-bit integer storage image format",
2590                    &[spirv::Capability::Int64ImageEXT],
2591                )
2592            }
2593            If::Unknown
2594            | If::Rgba32f
2595            | If::Rgba16f
2596            | If::R32f
2597            | If::Rgba8
2598            | If::Rgba8Snorm
2599            | If::Rgba32i
2600            | If::Rgba16i
2601            | If::Rgba8i
2602            | If::R32i
2603            | If::Rgba32ui
2604            | If::Rgba16ui
2605            | If::Rgba8ui
2606            | If::R32ui => Ok(()),
2607        }
2608    }
2609
2610    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
2611        self.get_constant_scalar(crate::Literal::U32(index))
2612    }
2613
2614    pub(super) fn get_constant_scalar_with(
2615        &mut self,
2616        value: u8,
2617        scalar: crate::Scalar,
2618    ) -> Result<Word, Error> {
2619        Ok(
2620            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
2621                Error::Validation("Unexpected kind and/or width for Literal"),
2622            )?),
2623        )
2624    }
2625
2626    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
2627        let scalar = CachedConstant::Literal(value.into());
2628        if let Some(&id) = self.cached_constants.get(&scalar) {
2629            return id;
2630        }
2631        let id = self.id_gen.next();
2632        self.write_constant_scalar(id, &value, None);
2633        self.cached_constants.insert(scalar, id);
2634        id
2635    }
2636
2637    fn write_constant_scalar(
2638        &mut self,
2639        id: Word,
2640        value: &crate::Literal,
2641        debug_name: Option<&String>,
2642    ) {
2643        if self.flags.contains(WriterFlags::DEBUG) {
2644            if let Some(name) = debug_name {
2645                self.debugs.push(Instruction::name(id, name));
2646            }
2647        }
2648        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
2649        let instruction = match *value {
2650            crate::Literal::F64(value) => {
2651                let bits = value.to_bits();
2652                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
2653            }
2654            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
2655            crate::Literal::F16(value) => {
2656                let low = value.to_bits();
2657                Instruction::constant_16bit(type_id, id, low as u32)
2658            }
2659            crate::Literal::U16(value) => Instruction::constant_16bit(type_id, id, value as u32),
2660            crate::Literal::I16(value) => {
2661                // Sign-extend into the 32-bit word so that `spirv-as` can
2662                // round-trip the disassembly (it expects signed values for
2663                // signed types).
2664                Instruction::constant_16bit(type_id, id, value as i32 as u32)
2665            }
2666            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
2667            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
2668            crate::Literal::U64(value) => {
2669                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2670            }
2671            crate::Literal::I64(value) => {
2672                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
2673            }
2674            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
2675            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
2676            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2677                unreachable!("Abstract types should not appear in IR presented to backends");
2678            }
2679        };
2680
2681        instruction.to_words(&mut self.logical_layout.declarations);
2682    }
2683
2684    pub(super) fn get_constant_composite(
2685        &mut self,
2686        ty: LookupType,
2687        constituent_ids: &[Word],
2688    ) -> Word {
2689        let composite = CachedConstant::Composite {
2690            ty,
2691            constituent_ids: constituent_ids.to_vec(),
2692        };
2693        if let Some(&id) = self.cached_constants.get(&composite) {
2694            return id;
2695        }
2696        let id = self.id_gen.next();
2697        self.write_constant_composite(id, ty, constituent_ids, None);
2698        self.cached_constants.insert(composite, id);
2699        id
2700    }
2701
2702    fn write_constant_composite(
2703        &mut self,
2704        id: Word,
2705        ty: LookupType,
2706        constituent_ids: &[Word],
2707        debug_name: Option<&String>,
2708    ) {
2709        if self.flags.contains(WriterFlags::DEBUG) {
2710            if let Some(name) = debug_name {
2711                self.debugs.push(Instruction::name(id, name));
2712            }
2713        }
2714        let type_id = self.get_type_id(ty);
2715        Instruction::constant_composite(type_id, id, constituent_ids)
2716            .to_words(&mut self.logical_layout.declarations);
2717    }
2718
2719    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
2720        let null = CachedConstant::ZeroValue(type_id);
2721        if let Some(&id) = self.cached_constants.get(&null) {
2722            return id;
2723        }
2724        let id = self.write_constant_null(type_id);
2725        self.cached_constants.insert(null, id);
2726        id
2727    }
2728
2729    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
2730        let null_id = self.id_gen.next();
2731        Instruction::constant_null(type_id, null_id)
2732            .to_words(&mut self.logical_layout.declarations);
2733        null_id
2734    }
2735
2736    fn write_constant_expr(
2737        &mut self,
2738        handle: Handle<crate::Expression>,
2739        ir_module: &crate::Module,
2740        mod_info: &ModuleInfo,
2741    ) -> Result<Word, Error> {
2742        let id = match ir_module.global_expressions[handle] {
2743            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
2744            crate::Expression::Constant(constant) => {
2745                let constant = &ir_module.constants[constant];
2746                self.constant_ids[constant.init]
2747            }
2748            crate::Expression::ZeroValue(ty) => {
2749                let type_id = self.get_handle_type_id(ty);
2750                self.get_constant_null(type_id)
2751            }
2752            crate::Expression::Compose { ty, ref components } => {
2753                let component_ids: Vec<_> = crate::proc::flatten_compose(
2754                    ty,
2755                    components,
2756                    &ir_module.global_expressions,
2757                    &ir_module.types,
2758                )
2759                .map(|component| self.constant_ids[component])
2760                .collect();
2761                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
2762            }
2763            crate::Expression::Splat { size, value } => {
2764                let value_id = self.constant_ids[value];
2765                let component_ids = &[value_id; 4][..size as usize];
2766
2767                let ty = self.get_expression_lookup_type(&mod_info[handle]);
2768
2769                self.get_constant_composite(ty, component_ids)
2770            }
2771            _ => {
2772                return Err(Error::Override);
2773            }
2774        };
2775
2776        self.constant_ids[handle] = id;
2777
2778        Ok(id)
2779    }
2780
2781    pub(super) fn write_control_barrier(
2782        &mut self,
2783        flags: crate::Barrier,
2784        body: &mut Vec<Instruction>,
2785    ) {
2786        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
2787            spirv::Scope::Device
2788        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2789            spirv::Scope::Subgroup
2790        } else {
2791            spirv::Scope::Workgroup
2792        };
2793        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2794        semantics.set(
2795            spirv::MemorySemantics::UNIFORM_MEMORY,
2796            flags.contains(crate::Barrier::STORAGE),
2797        );
2798        semantics.set(
2799            spirv::MemorySemantics::WORKGROUP_MEMORY,
2800            flags.contains(crate::Barrier::WORK_GROUP),
2801        );
2802        semantics.set(
2803            spirv::MemorySemantics::SUBGROUP_MEMORY,
2804            flags.contains(crate::Barrier::SUB_GROUP),
2805        );
2806        semantics.set(
2807            spirv::MemorySemantics::IMAGE_MEMORY,
2808            flags.contains(crate::Barrier::TEXTURE),
2809        );
2810        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
2811            self.get_index_constant(spirv::Scope::Subgroup as u32)
2812        } else {
2813            self.get_index_constant(spirv::Scope::Workgroup as u32)
2814        };
2815        let mem_scope_id = self.get_index_constant(memory_scope as u32);
2816        let semantics_id = self.get_index_constant(semantics.bits());
2817        body.push(Instruction::control_barrier(
2818            exec_scope_id,
2819            mem_scope_id,
2820            semantics_id,
2821        ));
2822    }
2823
2824    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
2825        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
2826        semantics.set(
2827            spirv::MemorySemantics::UNIFORM_MEMORY,
2828            flags.contains(crate::Barrier::STORAGE),
2829        );
2830        semantics.set(
2831            spirv::MemorySemantics::WORKGROUP_MEMORY,
2832            flags.contains(crate::Barrier::WORK_GROUP),
2833        );
2834        semantics.set(
2835            spirv::MemorySemantics::SUBGROUP_MEMORY,
2836            flags.contains(crate::Barrier::SUB_GROUP),
2837        );
2838        semantics.set(
2839            spirv::MemorySemantics::IMAGE_MEMORY,
2840            flags.contains(crate::Barrier::TEXTURE),
2841        );
2842        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
2843            self.get_index_constant(spirv::Scope::Device as u32)
2844        } else if flags.contains(crate::Barrier::SUB_GROUP) {
2845            self.get_index_constant(spirv::Scope::Subgroup as u32)
2846        } else {
2847            self.get_index_constant(spirv::Scope::Workgroup as u32)
2848        };
2849        let semantics_id = self.get_index_constant(semantics.bits());
2850        block
2851            .body
2852            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
2853    }
2854
    /// Emit blocks that zero-initialize the workgroup variables used by an
    /// entry point, then synchronize the workgroup.
    ///
    /// The variables initialized are the globals in the `WorkGroup` address
    /// space (plus `TaskPayload` globals when the stage is `Task`) for which
    /// `info[handle]` is non-empty, i.e. that this function's `FunctionInfo`
    /// reports a use of. Each gets an `OpStore` of an `OpConstantNull` of its
    /// type.
    ///
    /// The emitted control flow is:
    /// - a block labeled `entry_id` that compares `LocalInvocationIndex` with 0
    ///   and branches conditionally;
    /// - an "accept" block, taken only when the index is 0, holding the stores;
    /// - a merge block containing a workgroup control barrier, which branches
    ///   to a freshly allocated label.
    ///
    /// `local_invocation_index` is the id of an already-loaded `u32`
    /// `LocalInvocationIndex` value, if the caller has one. When it is `None`,
    /// this declares a new `Input` variable decorated `BuiltIn
    /// LocalInvocationIndex`, appends it to `interface.varying_ids`, and loads
    /// it in the first block.
    ///
    /// Returns `None` (emitting nothing) when there are no such variables.
    /// Otherwise returns `Some(label)`, where `label` is the branch target of
    /// the last emitted block; the caller must start its next block with it.
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_index: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // Build the stores first: if there is nothing to initialize we bail
        // out before emitting any blocks.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                // Task shaders also zero their task payload, which otherwise
                // would not be picked up by the `WorkGroup` check.
                let task_exception = (var.space == crate::AddressSpace::TaskPayload)
                    && interface.stage == crate::ShaderStage::Task;
                !info[handle].is_empty()
                    && (var.space == crate::AddressSpace::WorkGroup || task_exception)
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let mut pre_if_block = Block::new(entry_id);

        // Obtain the invocation's `LocalInvocationIndex` value, declaring and
        // loading the built-in ourselves if the entry point did not provide it.
        let local_invocation_index = if let Some(local_invocation_index) = local_invocation_index {
            local_invocation_index
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let u32_ty_id = self.get_u32_type_id();
            let pointer_type_id = self.get_pointer_type_id(u32_ty_id, class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationIndex as u32],
            );

            // The new input variable must be listed in the entry point's
            // interface.
            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(u32_ty_id, id, varying_id, None));

            id
        };

        let zero_id = self.get_constant_scalar(crate::Literal::U32(0));

        // Only the invocation with index 0 performs the stores.
        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            self.get_bool_type_id(),
            eq_id,
            local_invocation_index,
            zero_id,
        ));

        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(eq_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        // Every invocation waits here, so the zeroed memory is in place before
        // any of them proceeds into the entry point's body.
        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block.body);

        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
2953
2954    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
2955    ///
2956    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
2957    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
2958    /// interface is represented by global variables in the `Input` and `Output`
2959    /// storage classes, with decorations indicating which builtin or location
2960    /// each variable corresponds to.
2961    ///
2962    /// This function emits a single global `OpVariable` for a single value from
2963    /// the interface, and adds appropriate decorations to indicate which
2964    /// builtin or location it represents, how it should be interpolated, and so
2965    /// on. The `class` argument gives the variable's SPIR-V storage class,
2966    /// which should be either [`Input`] or [`Output`].
2967    ///
2968    /// [`Binding`]: crate::Binding
2969    /// [`Function`]: crate::Function
2970    /// [`EntryPoint`]: crate::EntryPoint
2971    /// [`Input`]: spirv::StorageClass::Input
2972    /// [`Output`]: spirv::StorageClass::Output
2973    fn write_varying(
2974        &mut self,
2975        ir_module: &crate::Module,
2976        stage: crate::ShaderStage,
2977        class: spirv::StorageClass,
2978        debug_name: Option<&str>,
2979        ty: Handle<crate::Type>,
2980        binding: &crate::Binding,
2981    ) -> Result<Word, Error> {
2982        let id = self.id_gen.next();
2983        let ty_inner = &ir_module.types[ty].inner;
2984        let needs_polyfill = self.needs_f16_polyfill(ty_inner);
2985
2986        let pointer_type_id = if needs_polyfill {
2987            let f32_value_local =
2988                super::f16_polyfill::F16IoPolyfill::create_polyfill_type(ty_inner)
2989                    .expect("needs_polyfill returned true but create_polyfill_type returned None");
2990
2991            let f32_type_id = self.get_localtype_id(f32_value_local);
2992            let ptr_id = self.get_pointer_type_id(f32_type_id, class);
2993            self.io_f16_polyfills.register_io_var(id, f32_type_id);
2994
2995            ptr_id
2996        } else {
2997            self.get_handle_pointer_type_id(ty, class)
2998        };
2999
3000        Instruction::variable(pointer_type_id, id, class, None)
3001            .to_words(&mut self.logical_layout.declarations);
3002
3003        if self
3004            .flags
3005            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
3006        {
3007            if let Some(name) = debug_name {
3008                self.debugs.push(Instruction::name(id, name));
3009            }
3010        }
3011
3012        let binding = self.map_binding(ir_module, stage, class, ty, binding)?;
3013        self.write_binding(id, binding);
3014
3015        Ok(id)
3016    }
3017
3018    pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) {
3019        match binding {
3020            BindingDecorations::None => (),
3021            BindingDecorations::BuiltIn(bi, others) => {
3022                self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]);
3023                for other in others {
3024                    self.decorate(id, other, &[]);
3025                }
3026            }
3027            BindingDecorations::Location {
3028                location,
3029                others,
3030                blend_src,
3031            } => {
3032                self.decorate(id, spirv::Decoration::Location, &[location]);
3033                for other in others {
3034                    self.decorate(id, other, &[]);
3035                }
3036                if let Some(blend_src) = blend_src {
3037                    self.decorate(id, spirv::Decoration::Index, &[blend_src]);
3038                }
3039            }
3040        }
3041    }
3042
3043    pub fn write_binding_struct_member(
3044        &mut self,
3045        struct_id: Word,
3046        member_idx: Word,
3047        binding_info: BindingDecorations,
3048    ) {
3049        match binding_info {
3050            BindingDecorations::None => (),
3051            BindingDecorations::BuiltIn(bi, others) => {
3052                self.annotations.push(Instruction::member_decorate(
3053                    struct_id,
3054                    member_idx,
3055                    spirv::Decoration::BuiltIn,
3056                    &[bi as Word],
3057                ));
3058                for other in others {
3059                    self.annotations.push(Instruction::member_decorate(
3060                        struct_id,
3061                        member_idx,
3062                        other,
3063                        &[],
3064                    ));
3065                }
3066            }
3067            BindingDecorations::Location {
3068                location,
3069                others,
3070                blend_src,
3071            } => {
3072                self.annotations.push(Instruction::member_decorate(
3073                    struct_id,
3074                    member_idx,
3075                    spirv::Decoration::Location,
3076                    &[location],
3077                ));
3078                for other in others {
3079                    self.annotations.push(Instruction::member_decorate(
3080                        struct_id,
3081                        member_idx,
3082                        other,
3083                        &[],
3084                    ));
3085                }
3086                if let Some(blend_src) = blend_src {
3087                    self.annotations.push(Instruction::member_decorate(
3088                        struct_id,
3089                        member_idx,
3090                        spirv::Decoration::Index,
3091                        &[blend_src],
3092                    ));
3093                }
3094            }
3095        }
3096    }
3097
3098    pub fn map_binding(
3099        &mut self,
3100        ir_module: &crate::Module,
3101        stage: crate::ShaderStage,
3102        class: spirv::StorageClass,
3103        ty: Handle<crate::Type>,
3104        binding: &crate::Binding,
3105    ) -> Result<BindingDecorations, Error> {
3106        use spirv::BuiltIn;
3107        use spirv::Decoration;
3108        match *binding {
3109            crate::Binding::Location {
3110                location,
3111                interpolation,
3112                sampling,
3113                blend_src,
3114                per_primitive,
3115            } => {
3116                let mut others = ArrayVec::new();
3117
3118                let no_decorations =
3119                    // VUID-StandaloneSpirv-Flat-06202
3120                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3121                    // > must not be used on variables with the Input storage class in a vertex shader
3122                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
3123                    // VUID-StandaloneSpirv-Flat-06201
3124                    // > The Flat, NoPerspective, Sample, and Centroid decorations
3125                    // > must not be used on variables with the Output storage class in a fragment shader
3126                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
3127
3128                if !no_decorations {
3129                    match interpolation {
3130                        // Perspective-correct interpolation is the default in SPIR-V.
3131                        None | Some(crate::Interpolation::Perspective) => (),
3132                        Some(crate::Interpolation::Flat) => {
3133                            others.push(Decoration::Flat);
3134                        }
3135                        Some(crate::Interpolation::Linear) => {
3136                            others.push(Decoration::NoPerspective);
3137                        }
3138                        Some(crate::Interpolation::PerVertex) => {
3139                            others.push(Decoration::PerVertexKHR);
3140                            self.require_any(
3141                                "`per_vertex` interpolation",
3142                                &[spirv::Capability::FragmentBarycentricKHR],
3143                            )?;
3144                            self.use_extension("SPV_KHR_fragment_shader_barycentric");
3145                        }
3146                    }
3147                    match sampling {
3148                        // Center sampling is the default in SPIR-V.
3149                        None
3150                        | Some(
3151                            crate::Sampling::Center
3152                            | crate::Sampling::First
3153                            | crate::Sampling::Either,
3154                        ) => (),
3155                        Some(crate::Sampling::Centroid) => {
3156                            others.push(Decoration::Centroid);
3157                        }
3158                        Some(crate::Sampling::Sample) => {
3159                            self.require_any(
3160                                "per-sample interpolation",
3161                                &[spirv::Capability::SampleRateShading],
3162                            )?;
3163                            others.push(Decoration::Sample);
3164                        }
3165                    }
3166                }
3167                if per_primitive && stage == crate::ShaderStage::Fragment {
3168                    others.push(Decoration::PerPrimitiveEXT);
3169                }
3170                Ok(BindingDecorations::Location {
3171                    location,
3172                    others,
3173                    blend_src,
3174                })
3175            }
3176            crate::Binding::BuiltIn(built_in) => {
3177                use crate::BuiltIn as Bi;
3178                let mut others = ArrayVec::new();
3179
3180                let built_in = match built_in {
3181                    Bi::Position { invariant } => {
3182                        if invariant {
3183                            others.push(Decoration::Invariant);
3184                        }
3185
3186                        if class == spirv::StorageClass::Output {
3187                            BuiltIn::Position
3188                        } else {
3189                            BuiltIn::FragCoord
3190                        }
3191                    }
3192                    Bi::ViewIndex => {
3193                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
3194                        BuiltIn::ViewIndex
3195                    }
3196                    // vertex
3197                    Bi::BaseInstance => BuiltIn::BaseInstance,
3198                    Bi::BaseVertex => BuiltIn::BaseVertex,
3199                    Bi::ClipDistances => {
3200                        self.require_any(
3201                            "`clip_distances` built-in",
3202                            &[spirv::Capability::ClipDistance],
3203                        )?;
3204                        BuiltIn::ClipDistance
3205                    }
3206                    Bi::CullDistance => {
3207                        self.require_any(
3208                            "`cull_distance` built-in",
3209                            &[spirv::Capability::CullDistance],
3210                        )?;
3211                        BuiltIn::CullDistance
3212                    }
3213                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
3214                    Bi::PointSize => BuiltIn::PointSize,
3215                    Bi::VertexIndex => BuiltIn::VertexIndex,
3216                    Bi::DrawIndex => {
3217                        self.use_extension("SPV_KHR_shader_draw_parameters");
3218                        self.require_any(
3219                            "`draw_index built-in",
3220                            &[spirv::Capability::DrawParameters],
3221                        )?;
3222                        BuiltIn::DrawIndex
3223                    }
3224                    // fragment
3225                    Bi::FragDepth => BuiltIn::FragDepth,
3226                    Bi::PointCoord => BuiltIn::PointCoord,
3227                    Bi::FrontFacing => BuiltIn::FrontFacing,
3228                    Bi::PrimitiveIndex => {
3229                        // Geometry shader capability is required for primitive index
3230                        self.require_any(
3231                            "`primitive_index` built-in",
3232                            &[spirv::Capability::Geometry],
3233                        )?;
3234                        if stage == crate::ShaderStage::Mesh {
3235                            others.push(Decoration::PerPrimitiveEXT);
3236                        }
3237                        BuiltIn::PrimitiveId
3238                    }
3239                    Bi::Barycentric { perspective } => {
3240                        self.require_any(
3241                            "`barycentric` built-in",
3242                            &[spirv::Capability::FragmentBarycentricKHR],
3243                        )?;
3244                        self.use_extension("SPV_KHR_fragment_shader_barycentric");
3245                        if perspective {
3246                            BuiltIn::BaryCoordKHR
3247                        } else {
3248                            BuiltIn::BaryCoordNoPerspKHR
3249                        }
3250                    }
3251                    Bi::SampleIndex => {
3252                        self.require_any(
3253                            "`sample_index` built-in",
3254                            &[spirv::Capability::SampleRateShading],
3255                        )?;
3256
3257                        BuiltIn::SampleId
3258                    }
3259                    Bi::SampleMask => BuiltIn::SampleMask,
3260                    // compute
3261                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
3262                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
3263                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
3264                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
3265                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
3266                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
3267                    // Subgroup
3268                    Bi::NumSubgroups => {
3269                        self.require_any(
3270                            "`num_subgroups` built-in",
3271                            &[spirv::Capability::GroupNonUniform],
3272                        )?;
3273                        BuiltIn::NumSubgroups
3274                    }
3275                    Bi::SubgroupId => {
3276                        self.require_any(
3277                            "`subgroup_id` built-in",
3278                            &[spirv::Capability::GroupNonUniform],
3279                        )?;
3280                        BuiltIn::SubgroupId
3281                    }
3282                    Bi::SubgroupSize => {
3283                        self.require_any(
3284                            "`subgroup_size` built-in",
3285                            &[
3286                                spirv::Capability::GroupNonUniform,
3287                                spirv::Capability::SubgroupBallotKHR,
3288                            ],
3289                        )?;
3290                        BuiltIn::SubgroupSize
3291                    }
3292                    Bi::SubgroupInvocationId => {
3293                        self.require_any(
3294                            "`subgroup_invocation_id` built-in",
3295                            &[
3296                                spirv::Capability::GroupNonUniform,
3297                                spirv::Capability::SubgroupBallotKHR,
3298                            ],
3299                        )?;
3300                        BuiltIn::SubgroupLocalInvocationId
3301                    }
3302                    Bi::CullPrimitive => {
3303                        others.push(Decoration::PerPrimitiveEXT);
3304                        BuiltIn::CullPrimitiveEXT
3305                    }
3306                    Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT,
3307                    Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT,
3308                    Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT,
3309                    // No decoration, this EmitMeshTasksEXT is called at function return
3310                    Bi::MeshTaskSize => return Ok(BindingDecorations::None),
3311                    // These aren't normal builtins and don't occur in function output
3312                    Bi::VertexCount | Bi::Vertices | Bi::PrimitiveCount | Bi::Primitives => {
3313                        unreachable!()
3314                    }
3315                    // ray tracing pipeline
3316                    Bi::RayInvocationId => BuiltIn::LaunchIdKHR,
3317                    Bi::NumRayInvocations => BuiltIn::LaunchSizeKHR,
3318                    Bi::InstanceCustomData => BuiltIn::InstanceCustomIndexKHR,
3319                    Bi::GeometryIndex => BuiltIn::RayGeometryIndexKHR,
3320                    Bi::WorldRayOrigin => BuiltIn::WorldRayOriginKHR,
3321                    Bi::WorldRayDirection => BuiltIn::WorldRayDirectionKHR,
3322                    Bi::ObjectRayOrigin => BuiltIn::ObjectRayOriginKHR,
3323                    Bi::ObjectRayDirection => BuiltIn::ObjectRayDirectionKHR,
3324                    Bi::RayTmin => BuiltIn::RayTminKHR,
3325                    Bi::RayTCurrentMax => BuiltIn::RayTmaxKHR,
3326                    Bi::ObjectToWorld => BuiltIn::ObjectToWorldKHR,
3327                    Bi::WorldToObject => BuiltIn::WorldToObjectKHR,
3328                    Bi::HitKind => BuiltIn::HitKindKHR,
3329                };
3330
3331                use crate::ScalarKind as Sk;
3332
3333                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
3334                //
3335                // > Any variable with integer or double-precision floating-
3336                // > point type and with Input storage class in a fragment
3337                // > shader, must be decorated Flat
3338                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
3339                    let is_flat = match ir_module.types[ty].inner {
3340                        crate::TypeInner::Scalar(scalar)
3341                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
3342                            Sk::Uint | Sk::Sint | Sk::Bool => true,
3343                            Sk::Float => false,
3344                            Sk::AbstractInt | Sk::AbstractFloat => {
3345                                return Err(Error::Validation(
3346                                    "Abstract types should not appear in IR presented to backends",
3347                                ))
3348                            }
3349                        },
3350                        _ => false,
3351                    };
3352
3353                    if is_flat {
3354                        others.push(Decoration::Flat);
3355                    }
3356                }
3357                Ok(BindingDecorations::BuiltIn(built_in, others))
3358            }
3359        }
3360    }
3361
3362    /// Load an IO variable, converting from `f32` to `f16` if polyfill is active.
3363    /// Returns the id of the loaded value matching `target_type_id`.
3364    pub(super) fn load_io_with_f16_polyfill(
3365        &mut self,
3366        body: &mut Vec<Instruction>,
3367        varying_id: Word,
3368        target_type_id: Word,
3369    ) -> Word {
3370        let tmp = self.id_gen.next();
3371        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3372            body.push(Instruction::load(f32_ty, tmp, varying_id, None));
3373            let converted = self.id_gen.next();
3374            super::f16_polyfill::F16IoPolyfill::emit_f32_to_f16_conversion(
3375                tmp,
3376                target_type_id,
3377                converted,
3378                body,
3379            );
3380            converted
3381        } else {
3382            body.push(Instruction::load(target_type_id, tmp, varying_id, None));
3383            tmp
3384        }
3385    }
3386
3387    /// Store an IO variable, converting from `f16` to `f32` if polyfill is active.
3388    pub(super) fn store_io_with_f16_polyfill(
3389        &mut self,
3390        body: &mut Vec<Instruction>,
3391        varying_id: Word,
3392        value_id: Word,
3393    ) {
3394        if let Some(f32_ty) = self.io_f16_polyfills.get_f32_io_type(varying_id) {
3395            let converted = self.id_gen.next();
3396            super::f16_polyfill::F16IoPolyfill::emit_f16_to_f32_conversion(
3397                value_id, f32_ty, converted, body,
3398            );
3399            body.push(Instruction::store(varying_id, converted, None));
3400        } else {
3401            body.push(Instruction::store(varying_id, value_id, None));
3402        }
3403    }
3404
    /// Emit the `OpVariable` for `global_variable`, together with any decorations,
    /// wrapper struct and pointer types it needs, and return the variable's id.
    ///
    /// # Errors
    ///
    /// Fails if a ray payload address space is used without the `RayTracingKHR`
    /// capability being available, if the resource binding cannot be resolved,
    /// or if decorating the wrapper struct member fails.
    fn write_global_variable(
        &mut self,
        ir_module: &crate::Module,
        global_variable: &crate::GlobalVariable,
    ) -> Result<Word, Error> {
        use spirv::Decoration;

        let id = self.id_gen.next();
        let class = map_storage_class(global_variable.space);

        // Ray payload address spaces are only available with ray tracing pipelines.
        if let crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload =
            global_variable.space
        {
            self.require_any("ray tracing pipelines", &[spirv::Capability::RayTracingKHR])?;
        }

        //self.check(class.required_capabilities())?;

        // Translate the IR memory decorations into SPIR-V decorations on the variable.
        if global_variable
            .memory_decorations
            .contains(crate::MemoryDecorations::COHERENT)
        {
            self.decorate(id, Decoration::Coherent, &[]);
        }
        if global_variable
            .memory_decorations
            .contains(crate::MemoryDecorations::VOLATILE)
        {
            self.decorate(id, Decoration::Volatile, &[]);
        }

        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = global_variable.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        // Access restrictions come from the Storage address space itself, or from
        // the type when the global is a storage image.
        let storage_access = match global_variable.space {
            crate::AddressSpace::Storage { access } => Some(access),
            _ => match ir_module.types[global_variable.ty].inner {
                crate::TypeInner::Image {
                    class: crate::ImageClass::Storage { access, .. },
                    ..
                } => Some(access),
                _ => None,
            },
        };
        if let Some(storage_access) = storage_access {
            if !storage_access.contains(crate::StorageAccess::LOAD) {
                self.decorate(id, Decoration::NonReadable, &[]);
            }
            if !storage_access.contains(crate::StorageAccess::STORE) {
                self.decorate(id, Decoration::NonWritable, &[]);
            }
        }

        // Note: we should be able to substitute `binding_array<Foo, 0>`,
        // but there is still code that tries to register the pre-substituted type,
        // and it is failing on 0.
        let mut substitute_inner_type_lookup = None;
        if let Some(ref res_binding) = global_variable.binding {
            let bind_target = self.resolve_resource_binding(res_binding)?;
            self.decorate(id, Decoration::DescriptorSet, &[bind_target.descriptor_set]);
            self.decorate(id, Decoration::Binding, &[bind_target.binding]);

            // The resolved binding may override a binding array's size; if so, build
            // a binding array type with that size and a pointer to it, and use the
            // pointer instead of the IR type below.
            if let Some(remapped_binding_array_size) = bind_target.binding_array_size {
                if let crate::TypeInner::BindingArray { base, .. } =
                    ir_module.types[global_variable.ty].inner
                {
                    let binding_array_type_id =
                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
                            base,
                            size: remapped_binding_array_size,
                        }));
                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
                        base: binding_array_type_id,
                        class,
                    }));
                }
            }
        };

        let init_word = global_variable
            .init
            .map(|constant| self.constant_ids[constant]);
        // NOTE(review): when a substitution happened, this is the id of a *pointer*
        // type, not the value type; it is returned directly as the pointer type in
        // the no-wrapper branch below.
        let inner_type_id = self.get_type_id(
            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
        );

        // generate the wrapping structure if needed
        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
            let wrapper_type_id = self.id_gen.next();

            self.decorate(wrapper_type_id, Decoration::Block, &[]);

            // Uniform globals with a registered std140-compatible type wrap that type
            // instead of the IR type; everything else wraps `inner_type_id`.
            match self.std140_compat_uniform_types.get(&global_variable.ty) {
                Some(std140_type_info) if global_variable.space == crate::AddressSpace::Uniform => {
                    self.annotations.push(Instruction::member_decorate(
                        wrapper_type_id,
                        0,
                        Decoration::Offset,
                        &[0],
                    ));
                    Instruction::type_struct(wrapper_type_id, &[std140_type_info.type_id])
                        .to_words(&mut self.logical_layout.declarations);
                }
                _ => {
                    let member = crate::StructMember {
                        name: None,
                        ty: global_variable.ty,
                        binding: None,
                        offset: 0,
                    };
                    self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;

                    Instruction::type_struct(wrapper_type_id, &[inner_type_id])
                        .to_words(&mut self.logical_layout.declarations);
                }
            }

            let pointer_type_id = self.id_gen.next();
            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
                .to_words(&mut self.logical_layout.declarations);

            pointer_type_id
        } else {
            // No wrapper struct is needed. If this is a global in the Storage address
            // space, the only way it could have `global_needs_wrapper() == false` is
            // if it has a runtime-sized or binding array.
            // Runtime-sized arrays were decorated when iterating through struct content.
            // Now binding arrays require Block decorating.
            if let crate::AddressSpace::Storage { .. } = global_variable.space {
                match ir_module.types[global_variable.ty].inner {
                    crate::TypeInner::BindingArray { base, .. } => {
                        let ty = &ir_module.types[base];
                        let mut should_decorate = true;
                        // Check if the type has a runtime array.
                        // A normal runtime array gets validated out,
                        // so only structs can be with runtime arrays
                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
                            // only the last member in a struct can be dynamically sized
                            if let Some(last_member) = members.last() {
                                if let &crate::TypeInner::Array {
                                    size: crate::ArraySize::Dynamic,
                                    ..
                                } = &ir_module.types[last_member.ty].inner
                                {
                                    should_decorate = false;
                                }
                            }
                        }
                        if should_decorate {
                            let decorated_id = self.get_handle_type_id(base);
                            self.decorate(decorated_id, Decoration::Block, &[]);
                        }
                    }
                    _ => (),
                };
            }
            // A substituted lookup already produced a pointer type (see above).
            if substitute_inner_type_lookup.is_some() {
                inner_type_id
            } else {
                self.get_handle_pointer_type_id(global_variable.ty, class)
            }
        };

        // Private globals, and Workgroup globals when zero-initialization is done
        // natively, get a null constant initializer if they have no initializer.
        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
            (crate::AddressSpace::Private, _)
            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
            }
            _ => init_word,
        };

        Instruction::variable(pointer_type_id, id, class, init_word)
            .to_words(&mut self.logical_layout.declarations);
        Ok(id)
    }
3583
3584    /// Write the necessary decorations for a struct member.
3585    ///
3586    /// Emit decorations for the `index`'th member of the struct type
3587    /// designated by `struct_id`, described by `member`.
3588    fn decorate_struct_member(
3589        &mut self,
3590        struct_id: Word,
3591        index: usize,
3592        member: &crate::StructMember,
3593        arena: &UniqueArena<crate::Type>,
3594    ) -> Result<(), Error> {
3595        use spirv::Decoration;
3596
3597        self.annotations.push(Instruction::member_decorate(
3598            struct_id,
3599            index as u32,
3600            Decoration::Offset,
3601            &[member.offset],
3602        ));
3603
3604        if self.flags.contains(WriterFlags::DEBUG) {
3605            if let Some(ref name) = member.name {
3606                self.debugs
3607                    .push(Instruction::member_name(struct_id, index as u32, name));
3608            }
3609        }
3610
3611        // Matrices and (potentially nested) arrays of matrices both require decorations,
3612        // so "see through" any arrays to determine if they're needed.
3613        let mut member_array_subty_inner = &arena[member.ty].inner;
3614        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
3615            member_array_subty_inner = &arena[base].inner;
3616        }
3617
3618        if let crate::TypeInner::Matrix {
3619            columns: _,
3620            rows,
3621            scalar,
3622        } = *member_array_subty_inner
3623        {
3624            let byte_stride = Alignment::from(rows) * scalar.width as u32;
3625            self.annotations.push(Instruction::member_decorate(
3626                struct_id,
3627                index as u32,
3628                Decoration::ColMajor,
3629                &[],
3630            ));
3631            self.annotations.push(Instruction::member_decorate(
3632                struct_id,
3633                index as u32,
3634                Decoration::MatrixStride,
3635                &[byte_stride],
3636            ));
3637        }
3638
3639        Ok(())
3640    }
3641
3642    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
3643        match self
3644            .lookup_function_type
3645            .entry(lookup_function_type.clone())
3646        {
3647            Entry::Occupied(e) => *e.get(),
3648            Entry::Vacant(_) => {
3649                let id = self.id_gen.next();
3650                let instruction = Instruction::type_function(
3651                    id,
3652                    lookup_function_type.return_type_id,
3653                    &lookup_function_type.parameter_type_ids,
3654                );
3655                instruction.to_words(&mut self.logical_layout.declarations);
3656                self.lookup_function_type.insert(lookup_function_type, id);
3657                id
3658            }
3659        }
3660    }
3661
3662    const fn write_physical_layout(&mut self) {
3663        self.physical_layout.bound = self.id_gen.0 + 1;
3664    }
3665
3666    fn write_logical_layout(
3667        &mut self,
3668        ir_module: &crate::Module,
3669        mod_info: &ModuleInfo,
3670        ep_index: Option<usize>,
3671        debug_info: &Option<DebugInfo>,
3672    ) -> Result<(), Error> {
3673        fn has_view_index_check(
3674            ir_module: &crate::Module,
3675            binding: Option<&crate::Binding>,
3676            ty: Handle<crate::Type>,
3677        ) -> bool {
3678            match ir_module.types[ty].inner {
3679                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
3680                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
3681                }),
3682                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
3683            }
3684        }
3685
3686        let has_storage_buffers =
3687            ir_module
3688                .global_variables
3689                .iter()
3690                .any(|(_, var)| match var.space {
3691                    crate::AddressSpace::Storage { .. } => true,
3692                    _ => false,
3693                });
3694        let has_view_index = ir_module
3695            .entry_points
3696            .iter()
3697            .flat_map(|entry| entry.function.arguments.iter())
3698            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
3699        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
3700
3701        let rt_uses = ir_module.uses_ray_tracing(ep_index);
3702        let has_ray_query = rt_uses.queries;
3703        let has_ray_tracing_pipeline = rt_uses.pipelines;
3704
3705        self.has_ray_tracing_pipeline = has_ray_tracing_pipeline;
3706
3707        if self.physical_layout.version < 0x10300 && has_storage_buffers {
3708            // enable the storage buffer class on < SPV-1.3
3709            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
3710                .to_words(&mut self.logical_layout.extensions);
3711        }
3712        if has_view_index {
3713            Instruction::extension("SPV_KHR_multiview")
3714                .to_words(&mut self.logical_layout.extensions)
3715        }
3716        if has_ray_query {
3717            Instruction::extension("SPV_KHR_ray_query")
3718                .to_words(&mut self.logical_layout.extensions)
3719        }
3720        if has_vertex_return {
3721            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
3722                .to_words(&mut self.logical_layout.extensions);
3723        }
3724        if ir_module.uses_mesh_shaders() {
3725            self.use_extension("SPV_EXT_mesh_shader");
3726            self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?;
3727            let lang_version = self.lang_version();
3728            if lang_version.0 <= 1 && lang_version.1 < 4 {
3729                return Err(Error::SpirvVersionTooLow(1, 4));
3730            }
3731        }
3732        if has_ray_tracing_pipeline {
3733            Instruction::extension("SPV_KHR_ray_tracing")
3734                .to_words(&mut self.logical_layout.extensions)
3735        }
3736        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
3737        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
3738            .to_words(&mut self.logical_layout.ext_inst_imports);
3739
3740        let mut debug_info_inner = None;
3741        if self.flags.contains(WriterFlags::DEBUG) {
3742            if let Some(debug_info) = debug_info.as_ref() {
3743                let source_file_id = self.id_gen.next();
3744                self.debugs
3745                    .push(Instruction::string(debug_info.file_name, source_file_id));
3746
3747                debug_info_inner = Some(DebugInfoInner {
3748                    source_code: debug_info.source_code,
3749                    source_file_id,
3750                });
3751                self.debugs.append(&mut Instruction::source_auto_continued(
3752                    debug_info.language,
3753                    0,
3754                    &debug_info_inner,
3755                ));
3756            }
3757        }
3758
3759        // write all types
3760        for (handle, _) in ir_module.types.iter() {
3761            self.write_type_declaration_arena(ir_module, handle)?;
3762        }
3763
3764        // write std140 layout compatible types required by uniforms
3765        for (_, var) in ir_module.global_variables.iter() {
3766            if var.space == crate::AddressSpace::Uniform {
3767                self.write_std140_compat_type_declaration(ir_module, var.ty)?;
3768            }
3769        }
3770
3771        // write all const-expressions as constants
3772        self.constant_ids
3773            .resize(ir_module.global_expressions.len(), 0);
3774        for (handle, _) in ir_module.global_expressions.iter() {
3775            self.write_constant_expr(handle, ir_module, mod_info)?;
3776        }
3777        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
3778
3779        // write the name of constants on their respective const-expression initializer
3780        if self.flags.contains(WriterFlags::DEBUG) {
3781            for (_, constant) in ir_module.constants.iter() {
3782                if let Some(ref name) = constant.name {
3783                    let id = self.constant_ids[constant.init];
3784                    self.debugs.push(Instruction::name(id, name));
3785                }
3786            }
3787        }
3788
3789        // write all global variables
3790        for (handle, var) in ir_module.global_variables.iter() {
3791            // If a single entry point was specified, only write `OpVariable` instructions
3792            // for the globals it actually uses. Emit dummies for the others,
3793            // to preserve the indices in `global_variables`.
3794            let gvar = match ep_index {
3795                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
3796                    GlobalVariable::dummy()
3797                }
3798                _ => {
3799                    let id = self.write_global_variable(ir_module, var)?;
3800                    GlobalVariable::new(id)
3801                }
3802            };
3803            self.global_variables.insert(handle, gvar);
3804        }
3805
3806        // write all functions
3807        for (handle, ir_function) in ir_module.functions.iter() {
3808            let info = &mod_info[handle];
3809            if let Some(index) = ep_index {
3810                let ep_info = mod_info.get_entry_point(index);
3811                // If this function uses globals that we omitted from the SPIR-V
3812                // because the entry point and its callees didn't use them,
3813                // then we must skip it.
3814                if !ep_info.dominates_global_use(info) {
3815                    log::debug!("Skip function {:?}", ir_function.name);
3816                    continue;
3817                }
3818
3819                // Skip functions that that are not compatible with this entry point's stage.
3820                //
3821                // When validation is enabled, it rejects modules whose entry points try to call
3822                // incompatible functions, so if we got this far, then any functions incompatible
3823                // with our selected entry point must not be used.
3824                //
3825                // When validation is disabled, `fun_info.available_stages` is always just
3826                // `ShaderStages::all()`, so this will write all functions in the module, and
3827                // the downstream GLSL compiler will catch any problems.
3828                if !info.available_stages.contains(ep_info.available_stages) {
3829                    continue;
3830                }
3831            }
3832            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
3833            self.lookup_function.insert(handle, id);
3834        }
3835
3836        // write all or one entry points
3837        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
3838            if ep_index.is_some() && ep_index != Some(index) {
3839                continue;
3840            }
3841            let info = mod_info.get_entry_point(index);
3842            let ep_instruction =
3843                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
3844            ep_instruction.to_words(&mut self.logical_layout.entry_points);
3845        }
3846
3847        for capability in self.capabilities_used.iter() {
3848            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
3849        }
3850        for extension in self.extensions_used.iter() {
3851            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
3852        }
3853        if ir_module.entry_points.is_empty() {
3854            // SPIR-V doesn't like modules without entry points
3855            Instruction::capability(spirv::Capability::Linkage)
3856                .to_words(&mut self.logical_layout.capabilities);
3857        }
3858
3859        let addressing_model = spirv::AddressingModel::Logical;
3860        let memory_model = if self
3861            .capabilities_used
3862            .contains(&spirv::Capability::VulkanMemoryModel)
3863        {
3864            spirv::MemoryModel::Vulkan
3865        } else {
3866            spirv::MemoryModel::GLSL450
3867        };
3868        //self.check(addressing_model.required_capabilities())?;
3869        //self.check(memory_model.required_capabilities())?;
3870
3871        Instruction::memory_model(addressing_model, memory_model)
3872            .to_words(&mut self.logical_layout.memory_model);
3873
3874        for debug_string in self.debug_strings.iter() {
3875            debug_string.to_words(&mut self.logical_layout.debugs);
3876        }
3877
3878        if self.flags.contains(WriterFlags::DEBUG) {
3879            for debug in self.debugs.iter() {
3880                debug.to_words(&mut self.logical_layout.debugs);
3881            }
3882        }
3883
3884        for annotation in self.annotations.iter() {
3885            annotation.to_words(&mut self.logical_layout.annotations);
3886        }
3887
3888        Ok(())
3889    }
3890
3891    pub fn write(
3892        &mut self,
3893        ir_module: &crate::Module,
3894        info: &ModuleInfo,
3895        pipeline_options: Option<&PipelineOptions>,
3896        debug_info: &Option<DebugInfo>,
3897        words: &mut Vec<Word>,
3898    ) -> Result<(), Error> {
3899        self.reset();
3900
3901        // Try to find the entry point and corresponding index
3902        let ep_index = match pipeline_options {
3903            Some(po) => {
3904                let index = ir_module
3905                    .entry_points
3906                    .iter()
3907                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
3908                    .ok_or(Error::EntryPointNotFound)?;
3909                Some(index)
3910            }
3911            None => None,
3912        };
3913
3914        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
3915        self.write_physical_layout();
3916
3917        self.physical_layout.in_words(words);
3918        self.logical_layout.in_words(words);
3919        Ok(())
3920    }
3921
    /// Return the set of capabilities the last module written used.
    ///
    /// These are the capabilities that `write` emitted as `OpCapability`
    /// instructions. The implicit `Linkage` capability added for modules
    /// without entry points is not part of this set.
    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
        &self.capabilities_used
    }
3926
3927    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
3928        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
3929        self.use_extension("SPV_EXT_descriptor_indexing");
3930        self.decorate(id, spirv::Decoration::NonUniform, &[]);
3931        Ok(())
3932    }
3933
3934    pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool {
3935        self.io_f16_polyfills.needs_polyfill(ty_inner)
3936    }
3937
3938    pub(super) fn write_debug_printf(
3939        &mut self,
3940        block: &mut Block,
3941        string: &str,
3942        format_params: &[Word],
3943    ) {
3944        if self.debug_printf.is_none() {
3945            self.use_extension("SPV_KHR_non_semantic_info");
3946            let import_id = self.id_gen.next();
3947            Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf")
3948                .to_words(&mut self.logical_layout.ext_inst_imports);
3949            self.debug_printf = Some(import_id)
3950        }
3951
3952        let import_id = self.debug_printf.unwrap();
3953
3954        let string_id = self.id_gen.next();
3955        self.debug_strings
3956            .push(Instruction::string(string, string_id));
3957
3958        let mut operands = Vec::with_capacity(1 + format_params.len());
3959        operands.push(string_id);
3960        operands.extend(format_params.iter());
3961
3962        let print_id = self.id_gen.next();
3963        block.body.push(Instruction::ext_inst(
3964            import_id,
3965            1,
3966            self.void_type,
3967            print_id,
3968            &operands,
3969        ));
3970    }
3971}
3972
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();
    // A fresh writer has not recorded an id bound yet.
    assert_eq!(writer.physical_layout.bound, 0);

    writer.write_physical_layout();
    // The bound is one past the highest id allocated so far, so this also pins
    // down how many ids `Writer::new` hands out up front.
    assert_eq!(writer.physical_layout.bound, 3);
}