naga/back/spv/mesh_shader.rs

1use alloc::vec::Vec;
2use spirv::Word;
3
4use crate::{
5    back::spv::{
6        helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error,
7        Instruction, WriterFlags,
8    },
9    non_max_u32::NonMaxU32,
10    Handle,
11};
12
/// One member of a mesh shader output struct, reduced to what the SPIR-V
/// backend needs to route it to the right output variable.
#[derive(Clone, Debug)]
pub struct MeshReturnMember {
    /// SPIR-V type id of the member's IR type.
    pub ty_id: u32,
    /// The member's IR binding (a builtin or a location).
    pub binding: crate::Binding,
}
18
/// Output bookkeeping for one kind of mesh output: either vertices or
/// primitives.
#[derive(Debug)]
struct PerOutputTypeMeshReturnInfo {
    /// Id of the `u32` constant holding the declared maximum number of
    /// vertices/primitives; the shader's requested count is clamped to it.
    max_length_constant: Word,
    /// Type id of the `vertices`/`primitives` array member of the workgroup
    /// output struct.
    array_type_id: Word,
    /// Members of the per-vertex/per-primitive IR struct, in declaration order.
    struct_members: Vec<MeshReturnMember>,

    // * Most builtins must be in the same block.
    // * All bindings must be in their own unique block.
    // * The primitive indices builtin family needs its own block.
    // * Cull primitive doesn't care about having its own block, but
    //   some older validation layers didn't respect this.
    /// Id of the `Output` array variable holding the builtin block, if any
    /// builtin (other than the primitive indices) is present.
    builtin_block: Option<Word>,
    /// Ids of the `Output` array variables for location-bound members, in the
    /// order those members appear in `struct_members`.
    bindings: Vec<Word>,
}
33
34#[derive(Debug)]
35pub struct MeshReturnInfo {
36    /// Id of the workgroup variable containing the data to be output
37    out_variable_id: Word,
38    /// All members of the output variable struct type
39    out_members: Vec<MeshReturnMember>,
40    /// Id of the input variable for local invocation id
41    local_invocation_index_var_id: Word,
42    /// Total workgroup size (product)
43    workgroup_size: u32,
44
45    /// Vertex-specific info
46    vertex_info: PerOutputTypeMeshReturnInfo,
47    /// Primitive-specific info
48    primitive_info: PerOutputTypeMeshReturnInfo,
49    /// Array variable for the primitive indices builtin
50    primitive_indices: Option<Word>,
51}
52
53impl super::Writer {
54    /// Sets up an output variable that will handle part of the mesh shader output
55    pub(super) fn write_mesh_return_global_variable(
56        &mut self,
57        ty: u32,
58        array_size_id: u32,
59    ) -> Result<Word, Error> {
60        let array_ty = self.id_gen.next();
61        Instruction::type_array(array_ty, ty, array_size_id)
62            .to_words(&mut self.logical_layout.declarations);
63        let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output);
64        let var_id = self.id_gen.next();
65        Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None)
66            .to_words(&mut self.logical_layout.declarations);
67        Ok(var_id)
68    }
69
70    /// This does various setup things to allow mesh shader entry points
71    /// to be properly written, such as creating the output variables
72    pub(super) fn write_entry_point_mesh_shader_info(
73        &mut self,
74        iface: &mut FunctionInterface,
75        local_invocation_index_id: Option<Word>,
76        ir_module: &crate::Module,
77        ep_context: &mut EntryPointContext,
78    ) -> Result<(), Error> {
79        let Some(ref mesh_info) = iface.mesh_info else {
80            return Ok(());
81        };
82        // Collect the members in the output structs
83        let out_members: Vec<MeshReturnMember> =
84            match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] {
85                &crate::Type {
86                    inner: crate::TypeInner::Struct { ref members, .. },
87                    ..
88                } => members
89                    .iter()
90                    .map(|a| MeshReturnMember {
91                        ty_id: self.get_handle_type_id(a.ty),
92                        binding: a.binding.clone().unwrap(),
93                    })
94                    .collect(),
95                _ => unreachable!(),
96            };
97        let vertex_array_type_id = out_members
98            .iter()
99            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices))
100            .unwrap()
101            .ty_id;
102        let primitive_array_type_id = out_members
103            .iter()
104            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
105            .unwrap()
106            .ty_id;
107        let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] {
108            &crate::Type {
109                inner: crate::TypeInner::Struct { ref members, .. },
110                ..
111            } => members
112                .iter()
113                .map(|a| MeshReturnMember {
114                    ty_id: self.get_handle_type_id(a.ty),
115                    binding: a.binding.clone().unwrap(),
116                })
117                .collect(),
118            _ => unreachable!(),
119        };
120        let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] {
121            &crate::Type {
122                inner: crate::TypeInner::Struct { ref members, .. },
123                ..
124            } => members
125                .iter()
126                .map(|a| MeshReturnMember {
127                    ty_id: self.get_handle_type_id(a.ty),
128                    binding: a.binding.clone().unwrap(),
129                })
130                .collect(),
131            _ => unreachable!(),
132        };
133        // In the final return, we do a giant memcpy, for which this is helpful
134        let local_invocation_index_var_id = match local_invocation_index_id {
135            Some(a) => a,
136            None => {
137                let u32_id = self.get_u32_type_id();
138                let var = self.id_gen.next();
139                Instruction::variable(
140                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Input),
141                    var,
142                    spirv::StorageClass::Input,
143                    None,
144                )
145                .to_words(&mut self.logical_layout.declarations);
146                Instruction::decorate(
147                    var,
148                    spirv::Decoration::BuiltIn,
149                    &[spirv::BuiltIn::LocalInvocationIndex as u32],
150                )
151                .to_words(&mut self.logical_layout.annotations);
152                iface.varying_ids.push(var);
153
154                var
155            }
156        };
157        // This is the information that is passed to the function writer
158        // so that it can write the final return logic
159        let mut mesh_return_info = MeshReturnInfo {
160            out_variable_id: self.global_variables[mesh_info.output_variable].var_id,
161            out_members,
162            local_invocation_index_var_id,
163            workgroup_size: self
164                .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())),
165
166            vertex_info: PerOutputTypeMeshReturnInfo {
167                array_type_id: vertex_array_type_id,
168                struct_members: vertex_members,
169                max_length_constant: self
170                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)),
171                bindings: Vec::new(),
172                builtin_block: None,
173            },
174            primitive_info: PerOutputTypeMeshReturnInfo {
175                array_type_id: primitive_array_type_id,
176                struct_members: primitive_members,
177                max_length_constant: self
178                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)),
179                bindings: Vec::new(),
180                builtin_block: None,
181            },
182            primitive_indices: None,
183        };
184        let vert_array_size_id =
185            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices));
186        let prim_array_size_id =
187            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives));
188
189        // Create the actual output variables and types.
190        // According to SPIR-V,
191        // * All builtins must be in the same output `Block` (except builtins for different output types like vertex/primitive)
192        // * Each member with `location` must be in its own `Block` decorated `struct`
193        // * Some builtins like CullPrimitiveEXT don't care as much (older validation layers don't know this! Wonderful!)
194        // * Some builtins like the indices ones need to be in their own output variable without a struct wrapper
195
196        // Write vertex builtin block
197        if mesh_return_info
198            .vertex_info
199            .struct_members
200            .iter()
201            .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..)))
202        {
203            let builtin_block_ty_id = self.id_gen.next();
204            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
205            let mut bi_index = 0;
206            let mut decorations = Vec::new();
207            for member in &mesh_return_info.vertex_info.struct_members {
208                if let crate::Binding::BuiltIn(_) = member.binding {
209                    ins.add_operand(member.ty_id);
210                    let binding = self.map_binding(
211                        ir_module,
212                        iface.stage,
213                        spirv::StorageClass::Output,
214                        // Unused except in fragment shaders with other conditions, so we can pass null
215                        Handle::new(NonMaxU32::new(0).unwrap()),
216                        &member.binding,
217                    )?;
218                    match binding {
219                        BindingDecorations::BuiltIn(bi, others) => {
220                            decorations.push(Instruction::member_decorate(
221                                builtin_block_ty_id,
222                                bi_index,
223                                spirv::Decoration::BuiltIn,
224                                &[bi as Word],
225                            ));
226                            for other in others {
227                                decorations.push(Instruction::member_decorate(
228                                    builtin_block_ty_id,
229                                    bi_index,
230                                    other,
231                                    &[],
232                                ));
233                            }
234                        }
235                        _ => unreachable!(),
236                    }
237                    bi_index += 1;
238                }
239            }
240            ins.to_words(&mut self.logical_layout.declarations);
241            decorations.push(Instruction::decorate(
242                builtin_block_ty_id,
243                spirv::Decoration::Block,
244                &[],
245            ));
246            for dec in decorations {
247                dec.to_words(&mut self.logical_layout.annotations);
248            }
249            let v =
250                self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?;
251            iface.varying_ids.push(v);
252            if self.flags.contains(WriterFlags::DEBUG) {
253                self.debugs
254                    .push(Instruction::name(v, "naga_vertex_builtin_outputs"));
255            }
256            mesh_return_info.vertex_info.builtin_block = Some(v);
257        }
258        // Write primitive builtin block
259        if mesh_return_info
260            .primitive_info
261            .struct_members
262            .iter()
263            .any(|a| {
264                !matches!(
265                    a.binding,
266                    crate::Binding::BuiltIn(
267                        crate::BuiltIn::PointIndex
268                            | crate::BuiltIn::LineIndices
269                            | crate::BuiltIn::TriangleIndices
270                    ) | crate::Binding::Location { .. }
271                )
272            })
273        {
274            let builtin_block_ty_id = self.id_gen.next();
275            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
276            let mut bi_index = 0;
277            let mut decorations = Vec::new();
278            for member in &mesh_return_info.primitive_info.struct_members {
279                if let crate::Binding::BuiltIn(bi) = member.binding {
280                    // These need to be in their own block, unlike other builtins
281                    if matches!(
282                        bi,
283                        crate::BuiltIn::PointIndex
284                            | crate::BuiltIn::LineIndices
285                            | crate::BuiltIn::TriangleIndices,
286                    ) {
287                        continue;
288                    }
289                    ins.add_operand(member.ty_id);
290                    let binding = self.map_binding(
291                        ir_module,
292                        iface.stage,
293                        spirv::StorageClass::Output,
294                        // Unused except in fragment shaders with other conditions, so we can pass null
295                        Handle::new(NonMaxU32::new(0).unwrap()),
296                        &member.binding,
297                    )?;
298                    match binding {
299                        BindingDecorations::BuiltIn(bi, others) => {
300                            decorations.push(Instruction::member_decorate(
301                                builtin_block_ty_id,
302                                bi_index,
303                                spirv::Decoration::BuiltIn,
304                                &[bi as Word],
305                            ));
306                            for other in others {
307                                decorations.push(Instruction::member_decorate(
308                                    builtin_block_ty_id,
309                                    bi_index,
310                                    other,
311                                    &[],
312                                ));
313                            }
314                        }
315                        _ => unreachable!(),
316                    }
317                    bi_index += 1;
318                }
319            }
320            ins.to_words(&mut self.logical_layout.declarations);
321            decorations.push(Instruction::decorate(
322                builtin_block_ty_id,
323                spirv::Decoration::Block,
324                &[],
325            ));
326            for dec in decorations {
327                dec.to_words(&mut self.logical_layout.annotations);
328            }
329            let v =
330                self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?;
331            Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
332                .to_words(&mut self.logical_layout.annotations);
333            iface.varying_ids.push(v);
334            if self.flags.contains(WriterFlags::DEBUG) {
335                self.debugs
336                    .push(Instruction::name(v, "naga_primitive_builtin_outputs"));
337            }
338            mesh_return_info.primitive_info.builtin_block = Some(v);
339        }
340
341        // Write vertex binding output blocks (1 array per output struct member)
342        for member in &mesh_return_info.vertex_info.struct_members {
343            match member.binding {
344                crate::Binding::Location { location, .. } => {
345                    // Create variable
346                    let v =
347                        self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?;
348                    // Decorate the variable with Location
349                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
350                        .to_words(&mut self.logical_layout.annotations);
351                    iface.varying_ids.push(v);
352                    mesh_return_info.vertex_info.bindings.push(v);
353                }
354                crate::Binding::BuiltIn(_) => (),
355            }
356        }
357        // Write primitive binding output blocks (1 array per output struct member)
358        // Also write indices output block
359        for member in &mesh_return_info.primitive_info.struct_members {
360            match member.binding {
361                crate::Binding::BuiltIn(
362                    crate::BuiltIn::PointIndex
363                    | crate::BuiltIn::LineIndices
364                    | crate::BuiltIn::TriangleIndices,
365                ) => {
366                    // This is written here instead of as part of the builtin block
367                    let v =
368                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
369                    // This shouldn't be marked as PerPrimitiveEXT
370                    Instruction::decorate(
371                        v,
372                        spirv::Decoration::BuiltIn,
373                        &[match member.binding.to_built_in().unwrap() {
374                            crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT,
375                            crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT,
376                            crate::BuiltIn::TriangleIndices => {
377                                spirv::BuiltIn::PrimitiveTriangleIndicesEXT
378                            }
379                            _ => unreachable!(),
380                        } as Word],
381                    )
382                    .to_words(&mut self.logical_layout.annotations);
383                    iface.varying_ids.push(v);
384                    if self.flags.contains(WriterFlags::DEBUG) {
385                        self.debugs
386                            .push(Instruction::name(v, "naga_primitive_indices_outputs"));
387                    }
388                    mesh_return_info.primitive_indices = Some(v);
389                }
390                crate::Binding::Location { location, .. } => {
391                    // Create variable
392                    let v =
393                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
394                    // Decorate the variable with Location
395                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
396                        .to_words(&mut self.logical_layout.annotations);
397                    // Decorate it with PerPrimitiveEXT
398                    Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
399                        .to_words(&mut self.logical_layout.annotations);
400                    iface.varying_ids.push(v);
401
402                    mesh_return_info.primitive_info.bindings.push(v);
403                }
404                crate::Binding::BuiltIn(_) => (),
405            }
406        }
407
408        // Store this where it can be read later during function write
409        ep_context.mesh_state = Some(mesh_return_info);
410
411        Ok(())
412    }
413
414    pub(super) fn write_entry_point_task_return(
415        &mut self,
416        value_id: Word,
417        body: &mut Vec<Instruction>,
418        task_payload: Word,
419    ) -> Result<Instruction, Error> {
420        // OpEmitMeshTasksEXT must be called right before exiting (after setting other
421        // output variables if there are any)
422
423        // Extract the vec3<u32> into 3 u32's
424        let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
425        for (i, &value) in values.iter().enumerate() {
426            let instruction = Instruction::composite_extract(
427                self.get_u32_type_id(),
428                value,
429                value_id,
430                &[i as Word],
431            );
432            body.push(instruction);
433        }
434        let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT);
435        for id in values {
436            instruction.add_operand(id);
437        }
438        // We have to include the task payload in our call
439        instruction.add_operand(task_payload);
440        Ok(instruction)
441    }
442
443    /// This writes the actual loop
444    #[allow(clippy::too_many_arguments)]
445    fn write_mesh_copy_loop(
446        &mut self,
447        body: &mut Vec<Instruction>,
448        mut loop_body_block: Vec<Instruction>,
449        loop_header: u32,
450        loop_merge: u32,
451        count_id: u32,
452        index_var: u32,
453        return_info: &MeshReturnInfo,
454    ) {
455        let u32_id = self.get_u32_type_id();
456        let condition_check = self.id_gen.next();
457        let loop_continue = self.id_gen.next();
458        let loop_body = self.id_gen.next();
459
460        // Loop header
461        {
462            body.push(Instruction::label(loop_header));
463            body.push(Instruction::loop_merge(
464                loop_merge,
465                loop_continue,
466                spirv::SelectionControl::empty(),
467            ));
468            body.push(Instruction::branch(condition_check));
469        }
470        // Condition check - check if i is less than num vertices to copy
471        {
472            body.push(Instruction::label(condition_check));
473
474            let val_i = self.id_gen.next();
475            body.push(Instruction::load(u32_id, val_i, index_var, None));
476
477            let cond = self.id_gen.next();
478            body.push(Instruction::binary(
479                spirv::Op::ULessThan,
480                self.get_bool_type_id(),
481                cond,
482                val_i,
483                count_id,
484            ));
485            body.push(Instruction::branch_conditional(cond, loop_body, loop_merge));
486        }
487        // Loop body
488        {
489            body.push(Instruction::label(loop_body));
490            body.append(&mut loop_body_block);
491            body.push(Instruction::branch(loop_continue));
492        }
493        // Loop continue - increment i
494        {
495            body.push(Instruction::label(loop_continue));
496
497            let prev_val_i = self.id_gen.next();
498            body.push(Instruction::load(u32_id, prev_val_i, index_var, None));
499            let new_val_i = self.id_gen.next();
500            body.push(Instruction::binary(
501                spirv::Op::IAdd,
502                u32_id,
503                new_val_i,
504                prev_val_i,
505                return_info.workgroup_size,
506            ));
507            body.push(Instruction::store(index_var, new_val_i, None));
508
509            body.push(Instruction::branch(loop_header));
510        }
511    }
512
513    /// This generates the instructions used to copy all parts of a single output vertex/primitive
514    /// to their individual output locations
515    fn write_mesh_copy_body(
516        &mut self,
517        is_primitive: bool,
518        return_info: &MeshReturnInfo,
519        index_var: u32,
520        vert_array_ptr: u32,
521        prim_array_ptr: u32,
522    ) -> Vec<Instruction> {
523        let u32_type_id = self.get_u32_type_id();
524        let mut body = Vec::new();
525        // Current index to copy
526        let val_i = self.id_gen.next();
527        body.push(Instruction::load(u32_type_id, val_i, index_var, None));
528
529        let info = if is_primitive {
530            &return_info.primitive_info
531        } else {
532            &return_info.vertex_info
533        };
534        let array_ptr = if is_primitive {
535            prim_array_ptr
536        } else {
537            vert_array_ptr
538        };
539
540        let mut builtin_index = 0;
541        let mut binding_index = 0;
542        // Write individual members of the vertex
543        for (member_id, member) in info.struct_members.iter().enumerate() {
544            let val_to_copy_ptr = self.id_gen.next();
545            body.push(Instruction::access_chain(
546                self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Workgroup),
547                val_to_copy_ptr,
548                array_ptr,
549                &[
550                    val_i,
551                    self.get_constant_scalar(crate::Literal::U32(member_id as u32)),
552                ],
553            ));
554            let val_to_copy = self.id_gen.next();
555            body.push(Instruction::load(
556                member.ty_id,
557                val_to_copy,
558                val_to_copy_ptr,
559                None,
560            ));
561            let mut needs_y_flip = false;
562            let ptr_to_copy_to = self.id_gen.next();
563            // Get a pointer to the struct member to copy
564            match member.binding {
565                crate::Binding::BuiltIn(
566                    crate::BuiltIn::PointIndex
567                    | crate::BuiltIn::LineIndices
568                    | crate::BuiltIn::TriangleIndices,
569                ) => {
570                    body.push(Instruction::access_chain(
571                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
572                        ptr_to_copy_to,
573                        return_info.primitive_indices.unwrap(),
574                        &[val_i],
575                    ));
576                }
577                crate::Binding::BuiltIn(bi) => {
578                    body.push(Instruction::access_chain(
579                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
580                        ptr_to_copy_to,
581                        info.builtin_block.unwrap(),
582                        &[
583                            val_i,
584                            self.get_constant_scalar(crate::Literal::U32(builtin_index)),
585                        ],
586                    ));
587                    needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. })
588                        && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE);
589                    builtin_index += 1;
590                }
591                crate::Binding::Location { .. } => {
592                    body.push(Instruction::access_chain(
593                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
594                        ptr_to_copy_to,
595                        info.bindings[binding_index],
596                        &[val_i],
597                    ));
598                    binding_index += 1;
599                }
600            }
601            body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None));
602            // Flip the vertex position y coordinate in some cases
603            // Can't use epilogue flip because can't read from this storage class
604            if needs_y_flip {
605                let prev_y = self.id_gen.next();
606                body.push(Instruction::composite_extract(
607                    self.get_f32_type_id(),
608                    prev_y,
609                    val_to_copy,
610                    &[1],
611                ));
612                let new_y = self.id_gen.next();
613                body.push(Instruction::unary(
614                    spirv::Op::FNegate,
615                    self.get_f32_type_id(),
616                    new_y,
617                    prev_y,
618                ));
619                let new_ptr_to_copy_to = self.id_gen.next();
620                body.push(Instruction::access_chain(
621                    self.get_f32_pointer_type_id(spirv::StorageClass::Output),
622                    new_ptr_to_copy_to,
623                    ptr_to_copy_to,
624                    &[self.get_constant_scalar(crate::Literal::U32(1))],
625                ));
626                body.push(Instruction::store(new_ptr_to_copy_to, new_y, None));
627            }
628        }
629        body
630    }
631
632    /// Writes the return call for a mesh shader, which involves copying previously
633    /// written vertices/primitives into the actual output location.
634    pub(super) fn write_mesh_shader_return(
635        &mut self,
636        return_info: &MeshReturnInfo,
637        block: &mut Block,
638        loop_counter_vertices: u32,
639        loop_counter_primitives: u32,
640        local_invocation_index_id: Word,
641    ) -> Result<(), Error> {
642        let u32_id = self.get_u32_type_id();
643
644        // Load the actual vertex and primitive counts
645        let mut load_u32_by_member_index =
646            |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| {
647                let member_index = members
648                    .iter()
649                    .position(|a| a.binding == crate::Binding::BuiltIn(bi))
650                    .unwrap() as u32;
651                let ptr_id = self.id_gen.next();
652                block.body.push(Instruction::access_chain(
653                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup),
654                    ptr_id,
655                    return_info.out_variable_id,
656                    &[self.get_constant_scalar(crate::Literal::U32(member_index))],
657                ));
658                let before_min_id = self.id_gen.next();
659                block
660                    .body
661                    .push(Instruction::load(u32_id, before_min_id, ptr_id, None));
662
663                // Clamp the values
664                let id = self.id_gen.next();
665                block.body.push(Instruction::ext_inst_gl_op(
666                    self.gl450_ext_inst_id,
667                    spirv::GlslStd450Op::UMin,
668                    u32_id,
669                    id,
670                    &[before_min_id, max],
671                ));
672                id
673            };
674        let vert_count_id = load_u32_by_member_index(
675            &return_info.out_members,
676            crate::BuiltIn::VertexCount,
677            return_info.vertex_info.max_length_constant,
678        );
679        let prim_count_id = load_u32_by_member_index(
680            &return_info.out_members,
681            crate::BuiltIn::PrimitiveCount,
682            return_info.primitive_info.max_length_constant,
683        );
684
685        // Get pointers to the arrays of data to extract
686        let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| {
687            let id = self.id_gen.next();
688            block.body.push(Instruction::access_chain(
689                self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup),
690                id,
691                return_info.out_variable_id,
692                &[self.get_constant_scalar(crate::Literal::U32(
693                    return_info
694                        .out_members
695                        .iter()
696                        .position(|a| a.binding == crate::Binding::BuiltIn(bi))
697                        .unwrap() as u32,
698                ))],
699            ));
700            id
701        };
702        let vert_array_ptr = get_array_ptr(
703            crate::BuiltIn::Vertices,
704            return_info.vertex_info.array_type_id,
705        );
706        let prim_array_ptr = get_array_ptr(
707            crate::BuiltIn::Primitives,
708            return_info.primitive_info.array_type_id,
709        );
710
711        // This must be called exactly once before any other mesh outputs are written
712        {
713            let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT);
714            ins.add_operand(vert_count_id);
715            ins.add_operand(prim_count_id);
716            block.body.push(ins);
717        }
718
719        // This is iterating over every returned vertex and splitting
720        // it out into the multiple per-output arrays.
721        let vertex_loop_header = self.id_gen.next();
722        let prim_loop_header = self.id_gen.next();
723        let in_between_loops = self.id_gen.next();
724        let func_end = self.id_gen.next();
725
726        block.body.push(Instruction::store(
727            loop_counter_vertices,
728            local_invocation_index_id,
729            None,
730        ));
731        block.body.push(Instruction::branch(vertex_loop_header));
732
733        let vertex_copy_body = self.write_mesh_copy_body(
734            false,
735            return_info,
736            loop_counter_vertices,
737            vert_array_ptr,
738            prim_array_ptr,
739        );
740        // Write vertex copy loop
741        self.write_mesh_copy_loop(
742            &mut block.body,
743            vertex_copy_body,
744            vertex_loop_header,
745            in_between_loops,
746            vert_count_id,
747            loop_counter_vertices,
748            return_info,
749        );
750
751        // In between loops, reset the initial index
752        {
753            block.body.push(Instruction::label(in_between_loops));
754
755            block.body.push(Instruction::store(
756                loop_counter_primitives,
757                local_invocation_index_id,
758                None,
759            ));
760
761            block.body.push(Instruction::branch(prim_loop_header));
762        }
763        let primitive_copy_body = self.write_mesh_copy_body(
764            true,
765            return_info,
766            loop_counter_primitives,
767            vert_array_ptr,
768            prim_array_ptr,
769        );
770        // Write primitive copy loop
771        self.write_mesh_copy_loop(
772            &mut block.body,
773            primitive_copy_body,
774            prim_loop_header,
775            func_end,
776            prim_count_id,
777            loop_counter_primitives,
778            return_info,
779        );
780
781        block.body.push(Instruction::label(func_end));
782        Ok(())
783    }
784
785    pub(super) fn write_mesh_shader_wrapper(
786        &mut self,
787        return_info: &MeshReturnInfo,
788        inner_id: u32,
789    ) -> Result<u32, Error> {
790        let out_id = self.id_gen.next();
791        let mut function = super::Function::default();
792        let lookup_function_type = super::LookupFunctionType {
793            parameter_type_ids: alloc::vec![],
794            return_type_id: self.void_type,
795        };
796        let function_type = self.get_function_type(lookup_function_type);
797        function.signature = Some(Instruction::function(
798            self.void_type,
799            out_id,
800            spirv::FunctionControl::empty(),
801            function_type,
802        ));
803        let u32_id = self.get_u32_type_id();
804        {
805            let mut block = Block::new(self.id_gen.next());
806            // A general function variable that we guarantee to allow in the final return. It must be
807            // declared at the top of the function. Currently it is used in the memcpy part to keep
808            // track of the current index to copy.
809            let loop_counter_vertices = self.id_gen.next();
810            let loop_counter_primitives = self.id_gen.next();
811            block.body.insert(
812                0,
813                Instruction::variable(
814                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
815                    loop_counter_vertices,
816                    spirv::StorageClass::Function,
817                    None,
818                ),
819            );
820            block.body.insert(
821                1,
822                Instruction::variable(
823                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
824                    loop_counter_primitives,
825                    spirv::StorageClass::Function,
826                    None,
827                ),
828            );
829            let local_invocation_index_id = self.id_gen.next();
830            block.body.push(Instruction::load(
831                u32_id,
832                local_invocation_index_id,
833                return_info.local_invocation_index_var_id,
834                None,
835            ));
836            block.body.push(Instruction::function_call(
837                self.void_type,
838                self.id_gen.next(),
839                inner_id,
840                &[],
841            ));
842            self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
843            self.write_mesh_shader_return(
844                return_info,
845                &mut block,
846                loop_counter_vertices,
847                loop_counter_primitives,
848                local_invocation_index_id,
849            )?;
850            function.consume(block, Instruction::return_void());
851        }
852        function.to_words(&mut self.logical_layout.function_definitions);
853        Ok(out_id)
854    }
855
856    pub(super) fn write_task_shader_wrapper(
857        &mut self,
858        task_payload: Word,
859        inner_id: u32,
860    ) -> Result<u32, Error> {
861        let out_id = self.id_gen.next();
862        let mut function = super::Function::default();
863        let lookup_function_type = super::LookupFunctionType {
864            parameter_type_ids: alloc::vec![],
865            return_type_id: self.void_type,
866        };
867        let function_type = self.get_function_type(lookup_function_type);
868        function.signature = Some(Instruction::function(
869            self.void_type,
870            out_id,
871            spirv::FunctionControl::empty(),
872            function_type,
873        ));
874
875        {
876            let mut block = Block::new(self.id_gen.next());
877            let result = self.id_gen.next();
878            block.body.push(Instruction::function_call(
879                self.get_vec3u_type_id(),
880                result,
881                inner_id,
882                &[],
883            ));
884            self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
885            let final_value = if let Some(task_limits) = self.task_dispatch_limits {
886                let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0));
887                let max_per_dim = self.get_constant_scalar(crate::Literal::U32(
888                    task_limits.max_mesh_workgroups_per_dim,
889                ));
890                let max_total = self.get_constant_scalar(crate::Literal::U32(
891                    task_limits.max_mesh_workgroups_total,
892                ));
893                let combined_struct_type = self.get_tuple_of_u32s_ty_id();
894                let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
895                for (i, value) in values.into_iter().enumerate() {
896                    block.body.push(Instruction::composite_extract(
897                        self.get_u32_type_id(),
898                        value,
899                        result,
900                        &[i as u32],
901                    ));
902                }
903                let prod_1 = self.id_gen.next();
904                let overflows = [self.id_gen.next(), self.id_gen.next()];
905                {
906                    let struct_out = self.id_gen.next();
907                    block.body.push(Instruction::binary(
908                        spirv::Op::UMulExtended,
909                        combined_struct_type,
910                        struct_out,
911                        values[0],
912                        values[1],
913                    ));
914                    block.body.push(Instruction::composite_extract(
915                        self.get_u32_type_id(),
916                        prod_1,
917                        struct_out,
918                        &[0],
919                    ));
920                    block.body.push(Instruction::composite_extract(
921                        self.get_u32_type_id(),
922                        overflows[0],
923                        struct_out,
924                        &[1],
925                    ));
926                }
927                let prod_final = self.id_gen.next();
928                {
929                    let struct_out = self.id_gen.next();
930                    block.body.push(Instruction::binary(
931                        spirv::Op::UMulExtended,
932                        combined_struct_type,
933                        struct_out,
934                        prod_1,
935                        values[2],
936                    ));
937                    block.body.push(Instruction::composite_extract(
938                        self.get_u32_type_id(),
939                        prod_final,
940                        struct_out,
941                        &[0],
942                    ));
943                    block.body.push(Instruction::composite_extract(
944                        self.get_u32_type_id(),
945                        overflows[1],
946                        struct_out,
947                        &[1],
948                    ));
949                }
950                let total_too_large = self.id_gen.next();
951                block.body.push(Instruction::binary(
952                    spirv::Op::UGreaterThan,
953                    self.get_bool_type_id(),
954                    total_too_large,
955                    prod_final,
956                    max_total,
957                ));
958
959                let too_large = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
960                for (i, value) in values.into_iter().enumerate() {
961                    block.body.push(Instruction::binary(
962                        spirv::Op::UGreaterThan,
963                        self.get_bool_type_id(),
964                        too_large[i],
965                        value,
966                        max_per_dim,
967                    ));
968                }
969                let overflow_happens = [self.id_gen.next(), self.id_gen.next()];
970                for (i, value) in overflows.into_iter().enumerate() {
971                    block.body.push(Instruction::binary(
972                        spirv::Op::INotEqual,
973                        self.get_bool_type_id(),
974                        overflow_happens[i],
975                        value,
976                        zero_u32,
977                    ));
978                }
979                let mut current_violates_limits = total_too_large;
980                for is_too_large in too_large {
981                    let new = self.id_gen.next();
982                    block.body.push(Instruction::binary(
983                        spirv::Op::LogicalOr,
984                        self.get_bool_type_id(),
985                        new,
986                        current_violates_limits,
987                        is_too_large,
988                    ));
989                    current_violates_limits = new;
990                }
991                for overflow_happens in overflow_happens {
992                    let new = self.id_gen.next();
993                    block.body.push(Instruction::binary(
994                        spirv::Op::LogicalOr,
995                        self.get_bool_type_id(),
996                        new,
997                        current_violates_limits,
998                        overflow_happens,
999                    ));
1000                    current_violates_limits = new;
1001                }
1002                let zero_vec3 = self.id_gen.next();
1003                block.body.push(Instruction::composite_construct(
1004                    self.get_vec3u_type_id(),
1005                    zero_vec3,
1006                    &[zero_u32, zero_u32, zero_u32],
1007                ));
1008                let final_result = self.id_gen.next();
1009                block.body.push(Instruction::select(
1010                    self.get_vec3u_type_id(),
1011                    final_result,
1012                    current_violates_limits,
1013                    zero_vec3,
1014                    result,
1015                ));
1016                final_result
1017            } else {
1018                result
1019            };
1020            let ins =
1021                self.write_entry_point_task_return(final_value, &mut block.body, task_payload)?;
1022            function.consume(block, ins);
1023        }
1024        function.to_words(&mut self.logical_layout.function_definitions);
1025        Ok(out_id)
1026    }
1027}