naga/back/spv/
mesh_shader.rs

1use alloc::vec::Vec;
2use spirv::Word;
3
4use crate::{
5    back::spv::{
6        helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error,
7        Instruction, WriterFlags,
8    },
9    non_max_u32::NonMaxU32,
10    Handle,
11};
12
13#[derive(Clone)]
14pub struct MeshReturnMember {
15    pub ty_id: u32,
16    pub binding: crate::Binding,
17}
18
19struct PerOutputTypeMeshReturnInfo {
20    max_length_constant: Word,
21    array_type_id: Word,
22    struct_members: Vec<MeshReturnMember>,
23
24    // * Most builtins must be in the same block.
25    // * All bindings must be in their own unique block.
26    // * The primitive indices builtin family needs its own block.
27    // * Cull primitive doesn't care about having its own block, but
28    //   some older validation layers didn't respect this.
29    builtin_block: Option<Word>,
30    bindings: Vec<Word>,
31}
32
33pub struct MeshReturnInfo {
34    /// Id of the workgroup variable containing the data to be output
35    out_variable_id: Word,
36    /// All members of the output variable struct type
37    out_members: Vec<MeshReturnMember>,
38    /// Id of the input variable for local invocation id
39    local_invocation_index_var_id: Word,
40    /// Total workgroup size (product)
41    workgroup_size: u32,
42
43    /// Vertex-specific info
44    vertex_info: PerOutputTypeMeshReturnInfo,
45    /// Primitive-specific info
46    primitive_info: PerOutputTypeMeshReturnInfo,
47    /// Array variable for the primitive indices builtin
48    primitive_indices: Option<Word>,
49}
50
51impl super::Writer {
52    /// Sets up an output variable that will handle part of the mesh shader output
53    pub(super) fn write_mesh_return_global_variable(
54        &mut self,
55        ty: u32,
56        array_size_id: u32,
57    ) -> Result<Word, Error> {
58        let array_ty = self.id_gen.next();
59        Instruction::type_array(array_ty, ty, array_size_id)
60            .to_words(&mut self.logical_layout.declarations);
61        let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output);
62        let var_id = self.id_gen.next();
63        Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None)
64            .to_words(&mut self.logical_layout.declarations);
65        Ok(var_id)
66    }
67
68    /// This does various setup things to allow mesh shader entry points
69    /// to be properly written, such as creating the output variables
70    pub(super) fn write_entry_point_mesh_shader_info(
71        &mut self,
72        iface: &mut FunctionInterface,
73        local_invocation_index_id: Option<Word>,
74        ir_module: &crate::Module,
75        ep_context: &mut EntryPointContext,
76    ) -> Result<(), Error> {
77        let Some(ref mesh_info) = iface.mesh_info else {
78            return Ok(());
79        };
80        // Collect the members in the output structs
81        let out_members: Vec<MeshReturnMember> =
82            match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] {
83                &crate::Type {
84                    inner: crate::TypeInner::Struct { ref members, .. },
85                    ..
86                } => members
87                    .iter()
88                    .map(|a| MeshReturnMember {
89                        ty_id: self.get_handle_type_id(a.ty),
90                        binding: a.binding.clone().unwrap(),
91                    })
92                    .collect(),
93                _ => unreachable!(),
94            };
95        let vertex_array_type_id = out_members
96            .iter()
97            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices))
98            .unwrap()
99            .ty_id;
100        let primitive_array_type_id = out_members
101            .iter()
102            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
103            .unwrap()
104            .ty_id;
105        let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] {
106            &crate::Type {
107                inner: crate::TypeInner::Struct { ref members, .. },
108                ..
109            } => members
110                .iter()
111                .map(|a| MeshReturnMember {
112                    ty_id: self.get_handle_type_id(a.ty),
113                    binding: a.binding.clone().unwrap(),
114                })
115                .collect(),
116            _ => unreachable!(),
117        };
118        let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] {
119            &crate::Type {
120                inner: crate::TypeInner::Struct { ref members, .. },
121                ..
122            } => members
123                .iter()
124                .map(|a| MeshReturnMember {
125                    ty_id: self.get_handle_type_id(a.ty),
126                    binding: a.binding.clone().unwrap(),
127                })
128                .collect(),
129            _ => unreachable!(),
130        };
131        // In the final return, we do a giant memcpy, for which this is helpful
132        let local_invocation_index_var_id = match local_invocation_index_id {
133            Some(a) => a,
134            None => {
135                let u32_id = self.get_u32_type_id();
136                let var = self.id_gen.next();
137                Instruction::variable(
138                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Input),
139                    var,
140                    spirv::StorageClass::Input,
141                    None,
142                )
143                .to_words(&mut self.logical_layout.declarations);
144                Instruction::decorate(
145                    var,
146                    spirv::Decoration::BuiltIn,
147                    &[spirv::BuiltIn::LocalInvocationIndex as u32],
148                )
149                .to_words(&mut self.logical_layout.annotations);
150                iface.varying_ids.push(var);
151
152                var
153            }
154        };
155        // This is the information that is passed to the function writer
156        // so that it can write the final return logic
157        let mut mesh_return_info = MeshReturnInfo {
158            out_variable_id: self.global_variables[mesh_info.output_variable].var_id,
159            out_members,
160            local_invocation_index_var_id,
161            workgroup_size: self
162                .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())),
163
164            vertex_info: PerOutputTypeMeshReturnInfo {
165                array_type_id: vertex_array_type_id,
166                struct_members: vertex_members,
167                max_length_constant: self
168                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)),
169                bindings: Vec::new(),
170                builtin_block: None,
171            },
172            primitive_info: PerOutputTypeMeshReturnInfo {
173                array_type_id: primitive_array_type_id,
174                struct_members: primitive_members,
175                max_length_constant: self
176                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)),
177                bindings: Vec::new(),
178                builtin_block: None,
179            },
180            primitive_indices: None,
181        };
182        let vert_array_size_id =
183            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices));
184        let prim_array_size_id =
185            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives));
186
187        // Create the actual output variables and types.
188        // According to SPIR-V,
189        // * All builtins must be in the same output `Block` (except builtins for different output types like vertex/primitive)
190        // * Each member with `location` must be in its own `Block` decorated `struct`
191        // * Some builtins like CullPrimitiveEXT don't care as much (older validation layers don't know this! Wonderful!)
192        // * Some builtins like the indices ones need to be in their own output variable without a struct wrapper
193
194        // Write vertex builtin block
195        if mesh_return_info
196            .vertex_info
197            .struct_members
198            .iter()
199            .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..)))
200        {
201            let builtin_block_ty_id = self.id_gen.next();
202            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
203            let mut bi_index = 0;
204            let mut decorations = Vec::new();
205            for member in &mesh_return_info.vertex_info.struct_members {
206                if let crate::Binding::BuiltIn(_) = member.binding {
207                    ins.add_operand(member.ty_id);
208                    let binding = self.map_binding(
209                        ir_module,
210                        iface.stage,
211                        spirv::StorageClass::Output,
212                        // Unused except in fragment shaders with other conditions, so we can pass null
213                        Handle::new(NonMaxU32::new(0).unwrap()),
214                        &member.binding,
215                    )?;
216                    match binding {
217                        BindingDecorations::BuiltIn(bi, others) => {
218                            decorations.push(Instruction::member_decorate(
219                                builtin_block_ty_id,
220                                bi_index,
221                                spirv::Decoration::BuiltIn,
222                                &[bi as Word],
223                            ));
224                            for other in others {
225                                decorations.push(Instruction::member_decorate(
226                                    builtin_block_ty_id,
227                                    bi_index,
228                                    other,
229                                    &[],
230                                ));
231                            }
232                        }
233                        _ => unreachable!(),
234                    }
235                    bi_index += 1;
236                }
237            }
238            ins.to_words(&mut self.logical_layout.declarations);
239            decorations.push(Instruction::decorate(
240                builtin_block_ty_id,
241                spirv::Decoration::Block,
242                &[],
243            ));
244            for dec in decorations {
245                dec.to_words(&mut self.logical_layout.annotations);
246            }
247            let v =
248                self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?;
249            iface.varying_ids.push(v);
250            if self.flags.contains(WriterFlags::DEBUG) {
251                self.debugs
252                    .push(Instruction::name(v, "naga_vertex_builtin_outputs"));
253            }
254            mesh_return_info.vertex_info.builtin_block = Some(v);
255        }
256        // Write primitive builtin block
257        if mesh_return_info
258            .primitive_info
259            .struct_members
260            .iter()
261            .any(|a| {
262                !matches!(
263                    a.binding,
264                    crate::Binding::BuiltIn(
265                        crate::BuiltIn::PointIndex
266                            | crate::BuiltIn::LineIndices
267                            | crate::BuiltIn::TriangleIndices
268                    ) | crate::Binding::Location { .. }
269                )
270            })
271        {
272            let builtin_block_ty_id = self.id_gen.next();
273            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
274            let mut bi_index = 0;
275            let mut decorations = Vec::new();
276            for member in &mesh_return_info.primitive_info.struct_members {
277                if let crate::Binding::BuiltIn(bi) = member.binding {
278                    // These need to be in their own block, unlike other builtins
279                    if matches!(
280                        bi,
281                        crate::BuiltIn::PointIndex
282                            | crate::BuiltIn::LineIndices
283                            | crate::BuiltIn::TriangleIndices,
284                    ) {
285                        continue;
286                    }
287                    ins.add_operand(member.ty_id);
288                    let binding = self.map_binding(
289                        ir_module,
290                        iface.stage,
291                        spirv::StorageClass::Output,
292                        // Unused except in fragment shaders with other conditions, so we can pass null
293                        Handle::new(NonMaxU32::new(0).unwrap()),
294                        &member.binding,
295                    )?;
296                    match binding {
297                        BindingDecorations::BuiltIn(bi, others) => {
298                            decorations.push(Instruction::member_decorate(
299                                builtin_block_ty_id,
300                                bi_index,
301                                spirv::Decoration::BuiltIn,
302                                &[bi as Word],
303                            ));
304                            for other in others {
305                                decorations.push(Instruction::member_decorate(
306                                    builtin_block_ty_id,
307                                    bi_index,
308                                    other,
309                                    &[],
310                                ));
311                            }
312                        }
313                        _ => unreachable!(),
314                    }
315                    bi_index += 1;
316                }
317            }
318            ins.to_words(&mut self.logical_layout.declarations);
319            decorations.push(Instruction::decorate(
320                builtin_block_ty_id,
321                spirv::Decoration::Block,
322                &[],
323            ));
324            for dec in decorations {
325                dec.to_words(&mut self.logical_layout.annotations);
326            }
327            let v =
328                self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?;
329            Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
330                .to_words(&mut self.logical_layout.annotations);
331            iface.varying_ids.push(v);
332            if self.flags.contains(WriterFlags::DEBUG) {
333                self.debugs
334                    .push(Instruction::name(v, "naga_primitive_builtin_outputs"));
335            }
336            mesh_return_info.primitive_info.builtin_block = Some(v);
337        }
338
339        // Write vertex binding output blocks (1 array per output struct member)
340        for member in &mesh_return_info.vertex_info.struct_members {
341            match member.binding {
342                crate::Binding::Location { location, .. } => {
343                    // Create variable
344                    let v =
345                        self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?;
346                    // Decorate the variable with Location
347                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
348                        .to_words(&mut self.logical_layout.annotations);
349                    iface.varying_ids.push(v);
350                    mesh_return_info.vertex_info.bindings.push(v);
351                }
352                crate::Binding::BuiltIn(_) => (),
353            }
354        }
355        // Write primitive binding output blocks (1 array per output struct member)
356        // Also write indices output block
357        for member in &mesh_return_info.primitive_info.struct_members {
358            match member.binding {
359                crate::Binding::BuiltIn(
360                    crate::BuiltIn::PointIndex
361                    | crate::BuiltIn::LineIndices
362                    | crate::BuiltIn::TriangleIndices,
363                ) => {
364                    // This is written here instead of as part of the builtin block
365                    let v =
366                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
367                    // This shouldn't be marked as PerPrimitiveEXT
368                    Instruction::decorate(
369                        v,
370                        spirv::Decoration::BuiltIn,
371                        &[match member.binding.to_built_in().unwrap() {
372                            crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT,
373                            crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT,
374                            crate::BuiltIn::TriangleIndices => {
375                                spirv::BuiltIn::PrimitiveTriangleIndicesEXT
376                            }
377                            _ => unreachable!(),
378                        } as Word],
379                    )
380                    .to_words(&mut self.logical_layout.annotations);
381                    iface.varying_ids.push(v);
382                    if self.flags.contains(WriterFlags::DEBUG) {
383                        self.debugs
384                            .push(Instruction::name(v, "naga_primitive_indices_outputs"));
385                    }
386                    mesh_return_info.primitive_indices = Some(v);
387                }
388                crate::Binding::Location { location, .. } => {
389                    // Create variable
390                    let v =
391                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
392                    // Decorate the variable with Location
393                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
394                        .to_words(&mut self.logical_layout.annotations);
395                    // Decorate it with PerPrimitiveEXT
396                    Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
397                        .to_words(&mut self.logical_layout.annotations);
398                    iface.varying_ids.push(v);
399
400                    mesh_return_info.primitive_info.bindings.push(v);
401                }
402                crate::Binding::BuiltIn(_) => (),
403            }
404        }
405
406        // Store this where it can be read later during function write
407        ep_context.mesh_state = Some(mesh_return_info);
408
409        Ok(())
410    }
411
412    pub(super) fn write_entry_point_task_return(
413        &mut self,
414        value_id: Word,
415        body: &mut Vec<Instruction>,
416        task_payload: Word,
417    ) -> Result<Instruction, Error> {
418        // OpEmitMeshTasksEXT must be called right before exiting (after setting other
419        // output variables if there are any)
420
421        // Extract the vec3<u32> into 3 u32's
422        let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
423        for (i, &value) in values.iter().enumerate() {
424            let instruction = Instruction::composite_extract(
425                self.get_u32_type_id(),
426                value,
427                value_id,
428                &[i as Word],
429            );
430            body.push(instruction);
431        }
432        let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT);
433        for id in values {
434            instruction.add_operand(id);
435        }
436        // We have to include the task payload in our call
437        instruction.add_operand(task_payload);
438        Ok(instruction)
439    }
440
441    /// This writes the actual loop
442    #[allow(clippy::too_many_arguments)]
443    fn write_mesh_copy_loop(
444        &mut self,
445        body: &mut Vec<Instruction>,
446        mut loop_body_block: Vec<Instruction>,
447        loop_header: u32,
448        loop_merge: u32,
449        count_id: u32,
450        index_var: u32,
451        return_info: &MeshReturnInfo,
452    ) {
453        let u32_id = self.get_u32_type_id();
454        let condition_check = self.id_gen.next();
455        let loop_continue = self.id_gen.next();
456        let loop_body = self.id_gen.next();
457
458        // Loop header
459        {
460            body.push(Instruction::label(loop_header));
461            body.push(Instruction::loop_merge(
462                loop_merge,
463                loop_continue,
464                spirv::SelectionControl::empty(),
465            ));
466            body.push(Instruction::branch(condition_check));
467        }
468        // Condition check - check if i is less than num vertices to copy
469        {
470            body.push(Instruction::label(condition_check));
471
472            let val_i = self.id_gen.next();
473            body.push(Instruction::load(u32_id, val_i, index_var, None));
474
475            let cond = self.id_gen.next();
476            body.push(Instruction::binary(
477                spirv::Op::ULessThan,
478                self.get_bool_type_id(),
479                cond,
480                val_i,
481                count_id,
482            ));
483            body.push(Instruction::branch_conditional(cond, loop_body, loop_merge));
484        }
485        // Loop body
486        {
487            body.push(Instruction::label(loop_body));
488            body.append(&mut loop_body_block);
489            body.push(Instruction::branch(loop_continue));
490        }
491        // Loop continue - increment i
492        {
493            body.push(Instruction::label(loop_continue));
494
495            let prev_val_i = self.id_gen.next();
496            body.push(Instruction::load(u32_id, prev_val_i, index_var, None));
497            let new_val_i = self.id_gen.next();
498            body.push(Instruction::binary(
499                spirv::Op::IAdd,
500                u32_id,
501                new_val_i,
502                prev_val_i,
503                return_info.workgroup_size,
504            ));
505            body.push(Instruction::store(index_var, new_val_i, None));
506
507            body.push(Instruction::branch(loop_header));
508        }
509    }
510
511    /// This generates the instructions used to copy all parts of a single output vertex/primitive
512    /// to their individual output locations
513    fn write_mesh_copy_body(
514        &mut self,
515        is_primitive: bool,
516        return_info: &MeshReturnInfo,
517        index_var: u32,
518        vert_array_ptr: u32,
519        prim_array_ptr: u32,
520    ) -> Vec<Instruction> {
521        let u32_type_id = self.get_u32_type_id();
522        let mut body = Vec::new();
523        // Current index to copy
524        let val_i = self.id_gen.next();
525        body.push(Instruction::load(u32_type_id, val_i, index_var, None));
526
527        let info = if is_primitive {
528            &return_info.primitive_info
529        } else {
530            &return_info.vertex_info
531        };
532        let array_ptr = if is_primitive {
533            prim_array_ptr
534        } else {
535            vert_array_ptr
536        };
537
538        let mut builtin_index = 0;
539        let mut binding_index = 0;
540        // Write individual members of the vertex
541        for (member_id, member) in info.struct_members.iter().enumerate() {
542            let val_to_copy_ptr = self.id_gen.next();
543            body.push(Instruction::access_chain(
544                self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Workgroup),
545                val_to_copy_ptr,
546                array_ptr,
547                &[
548                    val_i,
549                    self.get_constant_scalar(crate::Literal::U32(member_id as u32)),
550                ],
551            ));
552            let val_to_copy = self.id_gen.next();
553            body.push(Instruction::load(
554                member.ty_id,
555                val_to_copy,
556                val_to_copy_ptr,
557                None,
558            ));
559            let mut needs_y_flip = false;
560            let ptr_to_copy_to = self.id_gen.next();
561            // Get a pointer to the struct member to copy
562            match member.binding {
563                crate::Binding::BuiltIn(
564                    crate::BuiltIn::PointIndex
565                    | crate::BuiltIn::LineIndices
566                    | crate::BuiltIn::TriangleIndices,
567                ) => {
568                    body.push(Instruction::access_chain(
569                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
570                        ptr_to_copy_to,
571                        return_info.primitive_indices.unwrap(),
572                        &[val_i],
573                    ));
574                }
575                crate::Binding::BuiltIn(bi) => {
576                    body.push(Instruction::access_chain(
577                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
578                        ptr_to_copy_to,
579                        info.builtin_block.unwrap(),
580                        &[
581                            val_i,
582                            self.get_constant_scalar(crate::Literal::U32(builtin_index)),
583                        ],
584                    ));
585                    needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. })
586                        && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE);
587                    builtin_index += 1;
588                }
589                crate::Binding::Location { .. } => {
590                    body.push(Instruction::access_chain(
591                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
592                        ptr_to_copy_to,
593                        info.bindings[binding_index],
594                        &[val_i],
595                    ));
596                    binding_index += 1;
597                }
598            }
599            body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None));
600            // Flip the vertex position y coordinate in some cases
601            // Can't use epilogue flip because can't read from this storage class
602            if needs_y_flip {
603                let prev_y = self.id_gen.next();
604                body.push(Instruction::composite_extract(
605                    self.get_f32_type_id(),
606                    prev_y,
607                    val_to_copy,
608                    &[1],
609                ));
610                let new_y = self.id_gen.next();
611                body.push(Instruction::unary(
612                    spirv::Op::FNegate,
613                    self.get_f32_type_id(),
614                    new_y,
615                    prev_y,
616                ));
617                let new_ptr_to_copy_to = self.id_gen.next();
618                body.push(Instruction::access_chain(
619                    self.get_f32_pointer_type_id(spirv::StorageClass::Output),
620                    new_ptr_to_copy_to,
621                    ptr_to_copy_to,
622                    &[self.get_constant_scalar(crate::Literal::U32(1))],
623                ));
624                body.push(Instruction::store(new_ptr_to_copy_to, new_y, None));
625            }
626        }
627        body
628    }
629
630    /// Writes the return call for a mesh shader, which involves copying previously
631    /// written vertices/primitives into the actual output location.
632    pub(super) fn write_mesh_shader_return(
633        &mut self,
634        return_info: &MeshReturnInfo,
635        block: &mut Block,
636        loop_counter_vertices: u32,
637        loop_counter_primitives: u32,
638        local_invocation_index_id: Word,
639    ) -> Result<(), Error> {
640        let u32_id = self.get_u32_type_id();
641
642        // Load the actual vertex and primitive counts
643        let mut load_u32_by_member_index =
644            |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| {
645                let member_index = members
646                    .iter()
647                    .position(|a| a.binding == crate::Binding::BuiltIn(bi))
648                    .unwrap() as u32;
649                let ptr_id = self.id_gen.next();
650                block.body.push(Instruction::access_chain(
651                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup),
652                    ptr_id,
653                    return_info.out_variable_id,
654                    &[self.get_constant_scalar(crate::Literal::U32(member_index))],
655                ));
656                let before_min_id = self.id_gen.next();
657                block
658                    .body
659                    .push(Instruction::load(u32_id, before_min_id, ptr_id, None));
660
661                // Clamp the values
662                let id = self.id_gen.next();
663                block.body.push(Instruction::ext_inst_gl_op(
664                    self.gl450_ext_inst_id,
665                    spirv::GLOp::UMin,
666                    u32_id,
667                    id,
668                    &[before_min_id, max],
669                ));
670                id
671            };
672        let vert_count_id = load_u32_by_member_index(
673            &return_info.out_members,
674            crate::BuiltIn::VertexCount,
675            return_info.vertex_info.max_length_constant,
676        );
677        let prim_count_id = load_u32_by_member_index(
678            &return_info.out_members,
679            crate::BuiltIn::PrimitiveCount,
680            return_info.primitive_info.max_length_constant,
681        );
682
683        // Get pointers to the arrays of data to extract
684        let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| {
685            let id = self.id_gen.next();
686            block.body.push(Instruction::access_chain(
687                self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup),
688                id,
689                return_info.out_variable_id,
690                &[self.get_constant_scalar(crate::Literal::U32(
691                    return_info
692                        .out_members
693                        .iter()
694                        .position(|a| a.binding == crate::Binding::BuiltIn(bi))
695                        .unwrap() as u32,
696                ))],
697            ));
698            id
699        };
700        let vert_array_ptr = get_array_ptr(
701            crate::BuiltIn::Vertices,
702            return_info.vertex_info.array_type_id,
703        );
704        let prim_array_ptr = get_array_ptr(
705            crate::BuiltIn::Primitives,
706            return_info.primitive_info.array_type_id,
707        );
708
709        // This must be called exactly once before any other mesh outputs are written
710        {
711            let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT);
712            ins.add_operand(vert_count_id);
713            ins.add_operand(prim_count_id);
714            block.body.push(ins);
715        }
716
717        // This is iterating over every returned vertex and splitting
718        // it out into the multiple per-output arrays.
719        let vertex_loop_header = self.id_gen.next();
720        let prim_loop_header = self.id_gen.next();
721        let in_between_loops = self.id_gen.next();
722        let func_end = self.id_gen.next();
723
724        block.body.push(Instruction::store(
725            loop_counter_vertices,
726            local_invocation_index_id,
727            None,
728        ));
729        block.body.push(Instruction::branch(vertex_loop_header));
730
731        let vertex_copy_body = self.write_mesh_copy_body(
732            false,
733            return_info,
734            loop_counter_vertices,
735            vert_array_ptr,
736            prim_array_ptr,
737        );
738        // Write vertex copy loop
739        self.write_mesh_copy_loop(
740            &mut block.body,
741            vertex_copy_body,
742            vertex_loop_header,
743            in_between_loops,
744            vert_count_id,
745            loop_counter_vertices,
746            return_info,
747        );
748
749        // In between loops, reset the initial index
750        {
751            block.body.push(Instruction::label(in_between_loops));
752
753            block.body.push(Instruction::store(
754                loop_counter_primitives,
755                local_invocation_index_id,
756                None,
757            ));
758
759            block.body.push(Instruction::branch(prim_loop_header));
760        }
761        let primitive_copy_body = self.write_mesh_copy_body(
762            true,
763            return_info,
764            loop_counter_primitives,
765            vert_array_ptr,
766            prim_array_ptr,
767        );
768        // Write primitive copy loop
769        self.write_mesh_copy_loop(
770            &mut block.body,
771            primitive_copy_body,
772            prim_loop_header,
773            func_end,
774            prim_count_id,
775            loop_counter_primitives,
776            return_info,
777        );
778
779        block.body.push(Instruction::label(func_end));
780        Ok(())
781    }
782
783    pub(super) fn write_mesh_shader_wrapper(
784        &mut self,
785        return_info: &MeshReturnInfo,
786        inner_id: u32,
787    ) -> Result<u32, Error> {
788        let out_id = self.id_gen.next();
789        let mut function = super::Function::default();
790        let lookup_function_type = super::LookupFunctionType {
791            parameter_type_ids: alloc::vec![],
792            return_type_id: self.void_type,
793        };
794        let function_type = self.get_function_type(lookup_function_type);
795        function.signature = Some(Instruction::function(
796            self.void_type,
797            out_id,
798            spirv::FunctionControl::empty(),
799            function_type,
800        ));
801        let u32_id = self.get_u32_type_id();
802        {
803            let mut block = Block::new(self.id_gen.next());
804            // A general function variable that we guarantee to allow in the final return. It must be
805            // declared at the top of the function. Currently it is used in the memcpy part to keep
806            // track of the current index to copy.
807            let loop_counter_vertices = self.id_gen.next();
808            let loop_counter_primitives = self.id_gen.next();
809            block.body.insert(
810                0,
811                Instruction::variable(
812                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
813                    loop_counter_vertices,
814                    spirv::StorageClass::Function,
815                    None,
816                ),
817            );
818            block.body.insert(
819                1,
820                Instruction::variable(
821                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
822                    loop_counter_primitives,
823                    spirv::StorageClass::Function,
824                    None,
825                ),
826            );
827            let local_invocation_index_id = self.id_gen.next();
828            block.body.push(Instruction::load(
829                u32_id,
830                local_invocation_index_id,
831                return_info.local_invocation_index_var_id,
832                None,
833            ));
834            block.body.push(Instruction::function_call(
835                self.void_type,
836                self.id_gen.next(),
837                inner_id,
838                &[],
839            ));
840            self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
841            self.write_mesh_shader_return(
842                return_info,
843                &mut block,
844                loop_counter_vertices,
845                loop_counter_primitives,
846                local_invocation_index_id,
847            )?;
848            function.consume(block, Instruction::return_void());
849        }
850        function.to_words(&mut self.logical_layout.function_definitions);
851        Ok(out_id)
852    }
853
854    pub(super) fn write_task_shader_wrapper(
855        &mut self,
856        task_payload: Word,
857        inner_id: u32,
858    ) -> Result<u32, Error> {
859        let out_id = self.id_gen.next();
860        let mut function = super::Function::default();
861        let lookup_function_type = super::LookupFunctionType {
862            parameter_type_ids: alloc::vec![],
863            return_type_id: self.void_type,
864        };
865        let function_type = self.get_function_type(lookup_function_type);
866        function.signature = Some(Instruction::function(
867            self.void_type,
868            out_id,
869            spirv::FunctionControl::empty(),
870            function_type,
871        ));
872
873        {
874            let mut block = Block::new(self.id_gen.next());
875            let result = self.id_gen.next();
876            block.body.push(Instruction::function_call(
877                self.get_vec3u_type_id(),
878                result,
879                inner_id,
880                &[],
881            ));
882            self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
883            let final_value = if let Some(task_limits) = self.task_dispatch_limits {
884                let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0));
885                let max_per_dim = self.get_constant_scalar(crate::Literal::U32(
886                    task_limits.max_mesh_workgroups_per_dim,
887                ));
888                let max_total = self.get_constant_scalar(crate::Literal::U32(
889                    task_limits.max_mesh_workgroups_total,
890                ));
891                let combined_struct_type = self.get_tuple_of_u32s_ty_id();
892                let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
893                for (i, value) in values.into_iter().enumerate() {
894                    block.body.push(Instruction::composite_extract(
895                        self.get_u32_type_id(),
896                        value,
897                        result,
898                        &[i as u32],
899                    ));
900                }
901                let prod_1 = self.id_gen.next();
902                let overflows = [self.id_gen.next(), self.id_gen.next()];
903                {
904                    let struct_out = self.id_gen.next();
905                    block.body.push(Instruction::binary(
906                        spirv::Op::UMulExtended,
907                        combined_struct_type,
908                        struct_out,
909                        values[0],
910                        values[1],
911                    ));
912                    block.body.push(Instruction::composite_extract(
913                        self.get_u32_type_id(),
914                        prod_1,
915                        struct_out,
916                        &[0],
917                    ));
918                    block.body.push(Instruction::composite_extract(
919                        self.get_u32_type_id(),
920                        overflows[0],
921                        struct_out,
922                        &[1],
923                    ));
924                }
925                let prod_final = self.id_gen.next();
926                {
927                    let struct_out = self.id_gen.next();
928                    block.body.push(Instruction::binary(
929                        spirv::Op::UMulExtended,
930                        combined_struct_type,
931                        struct_out,
932                        prod_1,
933                        values[2],
934                    ));
935                    block.body.push(Instruction::composite_extract(
936                        self.get_u32_type_id(),
937                        prod_final,
938                        struct_out,
939                        &[0],
940                    ));
941                    block.body.push(Instruction::composite_extract(
942                        self.get_u32_type_id(),
943                        overflows[1],
944                        struct_out,
945                        &[1],
946                    ));
947                }
948                let total_too_large = self.id_gen.next();
949                block.body.push(Instruction::binary(
950                    spirv::Op::UGreaterThan,
951                    self.get_bool_type_id(),
952                    total_too_large,
953                    prod_final,
954                    max_total,
955                ));
956
957                let too_large = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
958                for (i, value) in values.into_iter().enumerate() {
959                    block.body.push(Instruction::binary(
960                        spirv::Op::UGreaterThan,
961                        self.get_bool_type_id(),
962                        too_large[i],
963                        value,
964                        max_per_dim,
965                    ));
966                }
967                let overflow_happens = [self.id_gen.next(), self.id_gen.next()];
968                for (i, value) in overflows.into_iter().enumerate() {
969                    block.body.push(Instruction::binary(
970                        spirv::Op::INotEqual,
971                        self.get_bool_type_id(),
972                        overflow_happens[i],
973                        value,
974                        zero_u32,
975                    ));
976                }
977                let mut current_violates_limits = total_too_large;
978                for is_too_large in too_large {
979                    let new = self.id_gen.next();
980                    block.body.push(Instruction::binary(
981                        spirv::Op::LogicalOr,
982                        self.get_bool_type_id(),
983                        new,
984                        current_violates_limits,
985                        is_too_large,
986                    ));
987                    current_violates_limits = new;
988                }
989                for overflow_happens in overflow_happens {
990                    let new = self.id_gen.next();
991                    block.body.push(Instruction::binary(
992                        spirv::Op::LogicalOr,
993                        self.get_bool_type_id(),
994                        new,
995                        current_violates_limits,
996                        overflow_happens,
997                    ));
998                    current_violates_limits = new;
999                }
1000                let zero_vec3 = self.id_gen.next();
1001                block.body.push(Instruction::composite_construct(
1002                    self.get_vec3u_type_id(),
1003                    zero_vec3,
1004                    &[zero_u32, zero_u32, zero_u32],
1005                ));
1006                let final_result = self.id_gen.next();
1007                block.body.push(Instruction::select(
1008                    self.get_vec3u_type_id(),
1009                    final_result,
1010                    current_violates_limits,
1011                    zero_vec3,
1012                    result,
1013                ));
1014                final_result
1015            } else {
1016                result
1017            };
1018            let ins =
1019                self.write_entry_point_task_return(final_value, &mut block.body, task_payload)?;
1020            function.consume(block, ins);
1021        }
1022        function.to_words(&mut self.logical_layout.function_definitions);
1023        Ok(out_id)
1024    }
1025}