naga/back/spv/
mesh_shader.rs

1use alloc::vec::Vec;
2use spirv::Word;
3
4use crate::{
5    back::spv::{
6        helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error,
7        Instruction, ResultMember, WriterFlags,
8    },
9    non_max_u32::NonMaxU32,
10    Handle,
11};
12
13#[derive(Clone)]
14pub struct MeshReturnMember {
15    pub ty_id: u32,
16    pub binding: crate::Binding,
17}
18
19struct PerOutputTypeMeshReturnInfo {
20    max_length_constant: Word,
21    array_type_id: Word,
22    struct_members: Vec<MeshReturnMember>,
23
24    // * Most builtins must be in the same block.
25    // * All bindings must be in their own unique block.
26    // * The primitive indices builtin family needs its own block.
27    // * Cull primitive doesn't care about having its own block, but
28    //   some older validation layers didn't respect this.
29    builtin_block: Option<Word>,
30    bindings: Vec<Word>,
31}
32
33pub struct MeshReturnInfo {
34    /// Id of the workgroup variable containing the data to be output
35    out_variable_id: Word,
36    /// All members of the output variable struct type
37    out_members: Vec<MeshReturnMember>,
38    /// Id of the input variable for local invocation id
39    local_invocation_index_id: Word,
40    /// Total workgroup size (product)
41    workgroup_size: u32,
42    /// Variable to be used later when saving the output as a loop index
43    loop_counter_vertices: Word,
44    /// Variable to be used later when saving the output as a loop index
45    loop_counter_primitives: Word,
46    /// The id of the label to jump to when `return` is called
47    pub entry_point_epilogue_id: Word,
48
49    /// Vertex-specific info
50    vertex_info: PerOutputTypeMeshReturnInfo,
51    /// Primitive-specific info
52    primitive_info: PerOutputTypeMeshReturnInfo,
53    /// Array variable for the primitive indices builtin
54    primitive_indices: Option<Word>,
55}
56
57impl super::Writer {
58    pub(super) fn require_mesh_shaders(&mut self) -> Result<(), Error> {
59        self.use_extension("SPV_EXT_mesh_shader");
60        self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?;
61        let lang_version = self.lang_version();
62        if lang_version.0 <= 1 && lang_version.1 < 4 {
63            return Err(Error::SpirvVersionTooLow(1, 4));
64        }
65        Ok(())
66    }
67
68    /// Sets up an output variable that will handle part of the mesh shader output
69    pub(super) fn write_mesh_return_global_variable(
70        &mut self,
71        ty: u32,
72        array_size_id: u32,
73    ) -> Result<Word, Error> {
74        let array_ty = self.id_gen.next();
75        Instruction::type_array(array_ty, ty, array_size_id)
76            .to_words(&mut self.logical_layout.declarations);
77        let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output);
78        let var_id = self.id_gen.next();
79        Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None)
80            .to_words(&mut self.logical_layout.declarations);
81        Ok(var_id)
82    }
83
84    /// This does various setup things to allow mesh shader entry points
85    /// to be properly written, such as creating the output variables
86    pub(super) fn write_entry_point_mesh_shader_info(
87        &mut self,
88        iface: &mut FunctionInterface,
89        local_invocation_index_id: Option<Word>,
90        ir_module: &crate::Module,
91        prelude: &mut Block,
92        ep_context: &mut EntryPointContext,
93    ) -> Result<(), Error> {
94        let Some(ref mesh_info) = iface.mesh_info else {
95            return Ok(());
96        };
97        // Collect the members in the output structs
98        let out_members: Vec<MeshReturnMember> =
99            match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] {
100                &crate::Type {
101                    inner: crate::TypeInner::Struct { ref members, .. },
102                    ..
103                } => members
104                    .iter()
105                    .map(|a| MeshReturnMember {
106                        ty_id: self.get_handle_type_id(a.ty),
107                        binding: a.binding.clone().unwrap(),
108                    })
109                    .collect(),
110                _ => unreachable!(),
111            };
112        let vertex_array_type_id = out_members
113            .iter()
114            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices))
115            .unwrap()
116            .ty_id;
117        let primitive_array_type_id = out_members
118            .iter()
119            .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
120            .unwrap()
121            .ty_id;
122        let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] {
123            &crate::Type {
124                inner: crate::TypeInner::Struct { ref members, .. },
125                ..
126            } => members
127                .iter()
128                .map(|a| MeshReturnMember {
129                    ty_id: self.get_handle_type_id(a.ty),
130                    binding: a.binding.clone().unwrap(),
131                })
132                .collect(),
133            _ => unreachable!(),
134        };
135        let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] {
136            &crate::Type {
137                inner: crate::TypeInner::Struct { ref members, .. },
138                ..
139            } => members
140                .iter()
141                .map(|a| MeshReturnMember {
142                    ty_id: self.get_handle_type_id(a.ty),
143                    binding: a.binding.clone().unwrap(),
144                })
145                .collect(),
146            _ => unreachable!(),
147        };
148        // In the final return, we do a giant memcpy, for which this is helpful
149        let local_invocation_index_id = match local_invocation_index_id {
150            Some(a) => a,
151            None => {
152                let u32_id = self.get_u32_type_id();
153                let var = self.id_gen.next();
154                Instruction::variable(
155                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Input),
156                    var,
157                    spirv::StorageClass::Input,
158                    None,
159                )
160                .to_words(&mut self.logical_layout.declarations);
161                Instruction::decorate(
162                    var,
163                    spirv::Decoration::BuiltIn,
164                    &[spirv::BuiltIn::LocalInvocationIndex as u32],
165                )
166                .to_words(&mut self.logical_layout.annotations);
167                iface.varying_ids.push(var);
168
169                let loaded_value = self.id_gen.next();
170                prelude
171                    .body
172                    .push(Instruction::load(u32_id, loaded_value, var, None));
173                loaded_value
174            }
175        };
176        let u32_id = self.get_u32_type_id();
177        // A general function variable that we guarantee to allow in the final return. It must be
178        // declared at the top of the function. Currently it is used in the memcpy part to keep
179        // track of the current index to copy.
180        let loop_counter_1 = self.id_gen.next();
181        let loop_counter_2 = self.id_gen.next();
182        prelude.body.insert(
183            0,
184            Instruction::variable(
185                self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
186                loop_counter_1,
187                spirv::StorageClass::Function,
188                None,
189            ),
190        );
191        prelude.body.insert(
192            1,
193            Instruction::variable(
194                self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
195                loop_counter_2,
196                spirv::StorageClass::Function,
197                None,
198            ),
199        );
200        // This is the information that is passed to the function writer
201        // so that it can write the final return logic
202        let mut mesh_return_info = MeshReturnInfo {
203            out_variable_id: self.global_variables[mesh_info.output_variable].var_id,
204            out_members,
205            local_invocation_index_id,
206            workgroup_size: self
207                .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())),
208            loop_counter_vertices: loop_counter_1,
209            loop_counter_primitives: loop_counter_2,
210            entry_point_epilogue_id: self.id_gen.next(),
211
212            vertex_info: PerOutputTypeMeshReturnInfo {
213                array_type_id: vertex_array_type_id,
214                struct_members: vertex_members,
215                max_length_constant: self
216                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)),
217                bindings: Vec::new(),
218                builtin_block: None,
219            },
220            primitive_info: PerOutputTypeMeshReturnInfo {
221                array_type_id: primitive_array_type_id,
222                struct_members: primitive_members,
223                max_length_constant: self
224                    .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)),
225                bindings: Vec::new(),
226                builtin_block: None,
227            },
228            primitive_indices: None,
229        };
230        let vert_array_size_id =
231            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices));
232        let prim_array_size_id =
233            self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives));
234
235        // Create the actual output variables and types.
236        // According to SPIR-V,
237        // * All builtins must be in the same output `Block` (except builtins for different output types like vertex/primitive)
238        // * Each member with `location` must be in its own `Block` decorated `struct`
239        // * Some builtins like CullPrimitiveEXT don't care as much (older validation layers don't know this! Wonderful!)
240        // * Some builtins like the indices ones need to be in their own output variable without a struct wrapper
241
242        // Write vertex builtin block
243        if mesh_return_info
244            .vertex_info
245            .struct_members
246            .iter()
247            .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..)))
248        {
249            let builtin_block_ty_id = self.id_gen.next();
250            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
251            let mut bi_index = 0;
252            let mut decorations = Vec::new();
253            for member in &mesh_return_info.vertex_info.struct_members {
254                if let crate::Binding::BuiltIn(_) = member.binding {
255                    ins.add_operand(member.ty_id);
256                    let binding = self.map_binding(
257                        ir_module,
258                        iface.stage,
259                        spirv::StorageClass::Output,
260                        // Unused except in fragment shaders with other conditions, so we can pass null
261                        Handle::new(NonMaxU32::new(0).unwrap()),
262                        &member.binding,
263                    )?;
264                    match binding {
265                        BindingDecorations::BuiltIn(bi, others) => {
266                            decorations.push(Instruction::member_decorate(
267                                builtin_block_ty_id,
268                                bi_index,
269                                spirv::Decoration::BuiltIn,
270                                &[bi as Word],
271                            ));
272                            for other in others {
273                                decorations.push(Instruction::member_decorate(
274                                    builtin_block_ty_id,
275                                    bi_index,
276                                    other,
277                                    &[],
278                                ));
279                            }
280                        }
281                        _ => unreachable!(),
282                    }
283                    bi_index += 1;
284                }
285            }
286            ins.to_words(&mut self.logical_layout.declarations);
287            decorations.push(Instruction::decorate(
288                builtin_block_ty_id,
289                spirv::Decoration::Block,
290                &[],
291            ));
292            for dec in decorations {
293                dec.to_words(&mut self.logical_layout.annotations);
294            }
295            let v =
296                self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?;
297            iface.varying_ids.push(v);
298            if self.flags.contains(WriterFlags::DEBUG) {
299                self.debugs
300                    .push(Instruction::name(v, "naga_vertex_builtin_outputs"));
301            }
302            mesh_return_info.vertex_info.builtin_block = Some(v);
303        }
304        // Write primitive builtin block
305        if mesh_return_info
306            .primitive_info
307            .struct_members
308            .iter()
309            .any(|a| {
310                !matches!(
311                    a.binding,
312                    crate::Binding::BuiltIn(
313                        crate::BuiltIn::PointIndex
314                            | crate::BuiltIn::LineIndices
315                            | crate::BuiltIn::TriangleIndices
316                    ) | crate::Binding::Location { .. }
317                )
318            })
319        {
320            let builtin_block_ty_id = self.id_gen.next();
321            let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
322            let mut bi_index = 0;
323            let mut decorations = Vec::new();
324            for member in &mesh_return_info.primitive_info.struct_members {
325                if let crate::Binding::BuiltIn(bi) = member.binding {
326                    // These need to be in their own block, unlike other builtins
327                    if matches!(
328                        bi,
329                        crate::BuiltIn::PointIndex
330                            | crate::BuiltIn::LineIndices
331                            | crate::BuiltIn::TriangleIndices,
332                    ) {
333                        continue;
334                    }
335                    ins.add_operand(member.ty_id);
336                    let binding = self.map_binding(
337                        ir_module,
338                        iface.stage,
339                        spirv::StorageClass::Output,
340                        // Unused except in fragment shaders with other conditions, so we can pass null
341                        Handle::new(NonMaxU32::new(0).unwrap()),
342                        &member.binding,
343                    )?;
344                    match binding {
345                        BindingDecorations::BuiltIn(bi, others) => {
346                            decorations.push(Instruction::member_decorate(
347                                builtin_block_ty_id,
348                                bi_index,
349                                spirv::Decoration::BuiltIn,
350                                &[bi as Word],
351                            ));
352                            for other in others {
353                                decorations.push(Instruction::member_decorate(
354                                    builtin_block_ty_id,
355                                    bi_index,
356                                    other,
357                                    &[],
358                                ));
359                            }
360                        }
361                        _ => unreachable!(),
362                    }
363                    bi_index += 1;
364                }
365            }
366            ins.to_words(&mut self.logical_layout.declarations);
367            decorations.push(Instruction::decorate(
368                builtin_block_ty_id,
369                spirv::Decoration::Block,
370                &[],
371            ));
372            for dec in decorations {
373                dec.to_words(&mut self.logical_layout.annotations);
374            }
375            let v =
376                self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?;
377            Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
378                .to_words(&mut self.logical_layout.annotations);
379            iface.varying_ids.push(v);
380            if self.flags.contains(WriterFlags::DEBUG) {
381                self.debugs
382                    .push(Instruction::name(v, "naga_primitive_builtin_outputs"));
383            }
384            mesh_return_info.primitive_info.builtin_block = Some(v);
385        }
386
387        // Write vertex binding output blocks (1 array per output struct member)
388        for member in &mesh_return_info.vertex_info.struct_members {
389            match member.binding {
390                crate::Binding::Location { location, .. } => {
391                    // Create variable
392                    let v =
393                        self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?;
394                    // Decorate the variable with Location
395                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
396                        .to_words(&mut self.logical_layout.annotations);
397                    iface.varying_ids.push(v);
398                    mesh_return_info.vertex_info.bindings.push(v);
399                }
400                crate::Binding::BuiltIn(_) => (),
401            }
402        }
403        // Write primitive binding output blocks (1 array per output struct member)
404        // Also write indices output block
405        for member in &mesh_return_info.primitive_info.struct_members {
406            match member.binding {
407                crate::Binding::BuiltIn(
408                    crate::BuiltIn::PointIndex
409                    | crate::BuiltIn::LineIndices
410                    | crate::BuiltIn::TriangleIndices,
411                ) => {
412                    // This is written here instead of as part of the builtin block
413                    let v =
414                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
415                    // This shouldn't be marked as PerPrimitiveEXT
416                    Instruction::decorate(
417                        v,
418                        spirv::Decoration::BuiltIn,
419                        &[match member.binding.to_built_in().unwrap() {
420                            crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT,
421                            crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT,
422                            crate::BuiltIn::TriangleIndices => {
423                                spirv::BuiltIn::PrimitiveTriangleIndicesEXT
424                            }
425                            _ => unreachable!(),
426                        } as Word],
427                    )
428                    .to_words(&mut self.logical_layout.annotations);
429                    iface.varying_ids.push(v);
430                    if self.flags.contains(WriterFlags::DEBUG) {
431                        self.debugs
432                            .push(Instruction::name(v, "naga_primitive_indices_outputs"));
433                    }
434                    mesh_return_info.primitive_indices = Some(v);
435                }
436                crate::Binding::Location { location, .. } => {
437                    // Create variable
438                    let v =
439                        self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
440                    // Decorate the variable with Location
441                    Instruction::decorate(v, spirv::Decoration::Location, &[location])
442                        .to_words(&mut self.logical_layout.annotations);
443                    // Decorate it with PerPrimitiveEXT
444                    Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
445                        .to_words(&mut self.logical_layout.annotations);
446                    iface.varying_ids.push(v);
447
448                    mesh_return_info.primitive_info.bindings.push(v);
449                }
450                crate::Binding::BuiltIn(_) => (),
451            }
452        }
453
454        // Store this where it can be read later during function write
455        ep_context.mesh_state = Some(mesh_return_info);
456
457        Ok(())
458    }
459
460    pub(super) fn try_write_entry_point_task_return(
461        &mut self,
462        value_id: Word,
463        ir_result: &crate::FunctionResult,
464        result_members: &[ResultMember],
465        body: &mut Vec<Instruction>,
466        task_payload: Option<Word>,
467    ) -> Result<Instruction, Error> {
468        // OpEmitMeshTasksEXT must be called right before exiting (after setting other
469        // output variables if there are any)
470        for (index, res_member) in result_members.iter().enumerate() {
471            if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
472                self.write_control_barrier(crate::Barrier::WORK_GROUP, body);
473                // If its a function like `fn a() -> @builtin(...) vec3<u32> ...`
474                // then just use the output value. If it's a struct, extract the
475                // value from the struct.
476                let member_value_id = match ir_result.binding {
477                    Some(_) => value_id,
478                    None => {
479                        let member_value_id = self.id_gen.next();
480                        body.push(Instruction::composite_extract(
481                            res_member.type_id,
482                            member_value_id,
483                            value_id,
484                            &[index as Word],
485                        ));
486                        member_value_id
487                    }
488                };
489
490                // Extract the vec3<u32> into 3 u32's
491                let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
492                for (i, &value) in values.iter().enumerate() {
493                    let instruction = Instruction::composite_extract(
494                        self.get_u32_type_id(),
495                        value,
496                        member_value_id,
497                        &[i as Word],
498                    );
499                    body.push(instruction);
500                }
501                // TODO: make this guaranteed to be uniform
502                let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT);
503                for id in values {
504                    instruction.add_operand(id);
505                }
506                // We have to include the task payload in our call
507                if let Some(task_payload) = task_payload {
508                    instruction.add_operand(task_payload);
509                }
510                return Ok(instruction);
511            }
512        }
513        Ok(Instruction::return_void())
514    }
515
516    /// This writes the actual loop
517    #[allow(clippy::too_many_arguments)]
518    fn write_mesh_copy_loop(
519        &mut self,
520        body: &mut Vec<Instruction>,
521        mut loop_body_block: Vec<Instruction>,
522        loop_header: u32,
523        loop_merge: u32,
524        count_id: u32,
525        index_var: u32,
526        return_info: &MeshReturnInfo,
527    ) {
528        let u32_id = self.get_u32_type_id();
529        let condition_check = self.id_gen.next();
530        let loop_continue = self.id_gen.next();
531        let loop_body = self.id_gen.next();
532
533        // Loop header
534        {
535            body.push(Instruction::label(loop_header));
536            body.push(Instruction::loop_merge(
537                loop_merge,
538                loop_continue,
539                spirv::SelectionControl::empty(),
540            ));
541            body.push(Instruction::branch(condition_check));
542        }
543        // Condition check - check if i is less than num vertices to copy
544        {
545            body.push(Instruction::label(condition_check));
546
547            let val_i = self.id_gen.next();
548            body.push(Instruction::load(u32_id, val_i, index_var, None));
549
550            let cond = self.id_gen.next();
551            body.push(Instruction::binary(
552                spirv::Op::ULessThan,
553                self.get_bool_type_id(),
554                cond,
555                val_i,
556                count_id,
557            ));
558            body.push(Instruction::branch_conditional(cond, loop_body, loop_merge));
559        }
560        // Loop body
561        {
562            body.push(Instruction::label(loop_body));
563            body.append(&mut loop_body_block);
564            body.push(Instruction::branch(loop_continue));
565        }
566        // Loop continue - increment i
567        {
568            body.push(Instruction::label(loop_continue));
569
570            let prev_val_i = self.id_gen.next();
571            body.push(Instruction::load(u32_id, prev_val_i, index_var, None));
572            let new_val_i = self.id_gen.next();
573            body.push(Instruction::binary(
574                spirv::Op::IAdd,
575                u32_id,
576                new_val_i,
577                prev_val_i,
578                return_info.workgroup_size,
579            ));
580            body.push(Instruction::store(index_var, new_val_i, None));
581
582            body.push(Instruction::branch(loop_header));
583        }
584    }
585
586    /// This generates the instructions used to copy all parts of a single output vertex/primitive
587    /// to their individual output locations
588    fn write_mesh_copy_body(
589        &mut self,
590        is_primitive: bool,
591        return_info: &MeshReturnInfo,
592        index_var: u32,
593        vert_array_ptr: u32,
594        prim_array_ptr: u32,
595    ) -> Vec<Instruction> {
596        let u32_type_id = self.get_u32_type_id();
597        let mut body = Vec::new();
598        // Current index to copy
599        let val_i = self.id_gen.next();
600        body.push(Instruction::load(u32_type_id, val_i, index_var, None));
601
602        let info = if is_primitive {
603            &return_info.primitive_info
604        } else {
605            &return_info.vertex_info
606        };
607        let array_ptr = if is_primitive {
608            prim_array_ptr
609        } else {
610            vert_array_ptr
611        };
612
613        let mut builtin_index = 0;
614        let mut binding_index = 0;
615        // Write individual members of the vertex
616        for (member_id, member) in info.struct_members.iter().enumerate() {
617            let val_to_copy_ptr = self.id_gen.next();
618            body.push(Instruction::access_chain(
619                self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Workgroup),
620                val_to_copy_ptr,
621                array_ptr,
622                &[
623                    val_i,
624                    self.get_constant_scalar(crate::Literal::U32(member_id as u32)),
625                ],
626            ));
627            let val_to_copy = self.id_gen.next();
628            body.push(Instruction::load(
629                member.ty_id,
630                val_to_copy,
631                val_to_copy_ptr,
632                None,
633            ));
634            let mut needs_y_flip = false;
635            let ptr_to_copy_to = self.id_gen.next();
636            // Get a pointer to the struct member to copy
637            match member.binding {
638                crate::Binding::BuiltIn(
639                    crate::BuiltIn::PointIndex
640                    | crate::BuiltIn::LineIndices
641                    | crate::BuiltIn::TriangleIndices,
642                ) => {
643                    body.push(Instruction::access_chain(
644                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
645                        ptr_to_copy_to,
646                        return_info.primitive_indices.unwrap(),
647                        &[val_i],
648                    ));
649                }
650                crate::Binding::BuiltIn(bi) => {
651                    body.push(Instruction::access_chain(
652                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
653                        ptr_to_copy_to,
654                        info.builtin_block.unwrap(),
655                        &[
656                            val_i,
657                            self.get_constant_scalar(crate::Literal::U32(builtin_index)),
658                        ],
659                    ));
660                    needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. })
661                        && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE);
662                    builtin_index += 1;
663                }
664                crate::Binding::Location { .. } => {
665                    body.push(Instruction::access_chain(
666                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
667                        ptr_to_copy_to,
668                        info.bindings[binding_index],
669                        &[val_i],
670                    ));
671                    binding_index += 1;
672                }
673            }
674            body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None));
675            // Flip the vertex position y coordinate in some cases
676            // Can't use epilogue flip because can't read from this storage class
677            if needs_y_flip {
678                let prev_y = self.id_gen.next();
679                body.push(Instruction::composite_extract(
680                    self.get_f32_type_id(),
681                    prev_y,
682                    val_to_copy,
683                    &[1],
684                ));
685                let new_y = self.id_gen.next();
686                body.push(Instruction::unary(
687                    spirv::Op::FNegate,
688                    self.get_f32_type_id(),
689                    new_y,
690                    prev_y,
691                ));
692                let new_ptr_to_copy_to = self.id_gen.next();
693                body.push(Instruction::access_chain(
694                    self.get_f32_pointer_type_id(spirv::StorageClass::Output),
695                    new_ptr_to_copy_to,
696                    ptr_to_copy_to,
697                    &[self.get_constant_scalar(crate::Literal::U32(1))],
698                ));
699                body.push(Instruction::store(new_ptr_to_copy_to, new_y, None));
700            }
701        }
702        body
703    }
704
705    /// Writes the return call for a mesh shader, which involves copying previously
706    /// written vertices/primitives into the actual output location.
707    pub(super) fn write_mesh_shader_return(
708        &mut self,
709        return_info: &MeshReturnInfo,
710        block: &mut Block,
711    ) -> Result<(), Error> {
712        // Start with a control barrier so that everything that follows is guaranteed to see the same variables
713        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
714        let u32_id = self.get_u32_type_id();
715
716        // Load the actual vertex and primitive counts
717        let mut load_u32_by_member_index =
718            |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| {
719                let member_index = members
720                    .iter()
721                    .position(|a| a.binding == crate::Binding::BuiltIn(bi))
722                    .unwrap() as u32;
723                let ptr_id = self.id_gen.next();
724                block.body.push(Instruction::access_chain(
725                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup),
726                    ptr_id,
727                    return_info.out_variable_id,
728                    &[self.get_constant_scalar(crate::Literal::U32(member_index))],
729                ));
730                let before_min_id = self.id_gen.next();
731                block
732                    .body
733                    .push(Instruction::load(u32_id, before_min_id, ptr_id, None));
734
735                // Clamp the values
736                let id = self.id_gen.next();
737                block.body.push(Instruction::ext_inst_gl_op(
738                    self.gl450_ext_inst_id,
739                    spirv::GLOp::UMin,
740                    u32_id,
741                    id,
742                    &[before_min_id, max],
743                ));
744                id
745            };
746        let vert_count_id = load_u32_by_member_index(
747            &return_info.out_members,
748            crate::BuiltIn::VertexCount,
749            return_info.vertex_info.max_length_constant,
750        );
751        let prim_count_id = load_u32_by_member_index(
752            &return_info.out_members,
753            crate::BuiltIn::PrimitiveCount,
754            return_info.primitive_info.max_length_constant,
755        );
756
757        // Get pointers to the arrays of data to extract
758        let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| {
759            let id = self.id_gen.next();
760            block.body.push(Instruction::access_chain(
761                self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup),
762                id,
763                return_info.out_variable_id,
764                &[self.get_constant_scalar(crate::Literal::U32(
765                    return_info
766                        .out_members
767                        .iter()
768                        .position(|a| a.binding == crate::Binding::BuiltIn(bi))
769                        .unwrap() as u32,
770                ))],
771            ));
772            id
773        };
774        let vert_array_ptr = get_array_ptr(
775            crate::BuiltIn::Vertices,
776            return_info.vertex_info.array_type_id,
777        );
778        let prim_array_ptr = get_array_ptr(
779            crate::BuiltIn::Primitives,
780            return_info.primitive_info.array_type_id,
781        );
782
783        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
784
785        // This must be called exactly once before any other mesh outputs are written
786        {
787            let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT);
788            ins.add_operand(vert_count_id);
789            ins.add_operand(prim_count_id);
790            block.body.push(ins);
791        }
792
793        // This is iterating over every returned vertex and splitting
794        // it out into the multiple per-output arrays.
795        let vertex_loop_header = self.id_gen.next();
796        let prim_loop_header = self.id_gen.next();
797        let in_between_loops = self.id_gen.next();
798        let func_end = self.id_gen.next();
799
800        block.body.push(Instruction::store(
801            return_info.loop_counter_vertices,
802            return_info.local_invocation_index_id,
803            None,
804        ));
805        block.body.push(Instruction::branch(vertex_loop_header));
806
807        let vertex_copy_body = self.write_mesh_copy_body(
808            false,
809            return_info,
810            return_info.loop_counter_vertices,
811            vert_array_ptr,
812            prim_array_ptr,
813        );
814        // Write vertex copy loop
815        self.write_mesh_copy_loop(
816            &mut block.body,
817            vertex_copy_body,
818            vertex_loop_header,
819            in_between_loops,
820            vert_count_id,
821            return_info.loop_counter_vertices,
822            return_info,
823        );
824
825        // In between loops, reset the initial index
826        {
827            block.body.push(Instruction::label(in_between_loops));
828
829            block.body.push(Instruction::store(
830                return_info.loop_counter_primitives,
831                return_info.local_invocation_index_id,
832                None,
833            ));
834
835            block.body.push(Instruction::branch(prim_loop_header));
836        }
837        let primitive_copy_body = self.write_mesh_copy_body(
838            true,
839            return_info,
840            return_info.loop_counter_primitives,
841            vert_array_ptr,
842            prim_array_ptr,
843        );
844        // Write primitive copy loop
845        self.write_mesh_copy_loop(
846            &mut block.body,
847            primitive_copy_body,
848            prim_loop_header,
849            func_end,
850            prim_count_id,
851            return_info.loop_counter_primitives,
852            return_info,
853        );
854
855        block.body.push(Instruction::label(func_end));
856        Ok(())
857    }
858}