1use alloc::vec::Vec;
2use spirv::Word;
3
4use crate::{
5 back::spv::{
6 helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error,
7 Instruction, WriterFlags,
8 },
9 non_max_u32::NonMaxU32,
10 Handle,
11};
12
13#[derive(Clone)]
14pub struct MeshReturnMember {
15 pub ty_id: u32,
16 pub binding: crate::Binding,
17}
18
19struct PerOutputTypeMeshReturnInfo {
20 max_length_constant: Word,
21 array_type_id: Word,
22 struct_members: Vec<MeshReturnMember>,
23
24 builtin_block: Option<Word>,
30 bindings: Vec<Word>,
31}
32
33pub struct MeshReturnInfo {
34 out_variable_id: Word,
36 out_members: Vec<MeshReturnMember>,
38 local_invocation_index_var_id: Word,
40 workgroup_size: u32,
42
43 vertex_info: PerOutputTypeMeshReturnInfo,
45 primitive_info: PerOutputTypeMeshReturnInfo,
47 primitive_indices: Option<Word>,
49}
50
51impl super::Writer {
52 pub(super) fn write_mesh_return_global_variable(
54 &mut self,
55 ty: u32,
56 array_size_id: u32,
57 ) -> Result<Word, Error> {
58 let array_ty = self.id_gen.next();
59 Instruction::type_array(array_ty, ty, array_size_id)
60 .to_words(&mut self.logical_layout.declarations);
61 let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output);
62 let var_id = self.id_gen.next();
63 Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None)
64 .to_words(&mut self.logical_layout.declarations);
65 Ok(var_id)
66 }
67
68 pub(super) fn write_entry_point_mesh_shader_info(
71 &mut self,
72 iface: &mut FunctionInterface,
73 local_invocation_index_id: Option<Word>,
74 ir_module: &crate::Module,
75 ep_context: &mut EntryPointContext,
76 ) -> Result<(), Error> {
77 let Some(ref mesh_info) = iface.mesh_info else {
78 return Ok(());
79 };
80 let out_members: Vec<MeshReturnMember> =
82 match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] {
83 &crate::Type {
84 inner: crate::TypeInner::Struct { ref members, .. },
85 ..
86 } => members
87 .iter()
88 .map(|a| MeshReturnMember {
89 ty_id: self.get_handle_type_id(a.ty),
90 binding: a.binding.clone().unwrap(),
91 })
92 .collect(),
93 _ => unreachable!(),
94 };
95 let vertex_array_type_id = out_members
96 .iter()
97 .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices))
98 .unwrap()
99 .ty_id;
100 let primitive_array_type_id = out_members
101 .iter()
102 .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
103 .unwrap()
104 .ty_id;
105 let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] {
106 &crate::Type {
107 inner: crate::TypeInner::Struct { ref members, .. },
108 ..
109 } => members
110 .iter()
111 .map(|a| MeshReturnMember {
112 ty_id: self.get_handle_type_id(a.ty),
113 binding: a.binding.clone().unwrap(),
114 })
115 .collect(),
116 _ => unreachable!(),
117 };
118 let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] {
119 &crate::Type {
120 inner: crate::TypeInner::Struct { ref members, .. },
121 ..
122 } => members
123 .iter()
124 .map(|a| MeshReturnMember {
125 ty_id: self.get_handle_type_id(a.ty),
126 binding: a.binding.clone().unwrap(),
127 })
128 .collect(),
129 _ => unreachable!(),
130 };
131 let local_invocation_index_var_id = match local_invocation_index_id {
133 Some(a) => a,
134 None => {
135 let u32_id = self.get_u32_type_id();
136 let var = self.id_gen.next();
137 Instruction::variable(
138 self.get_pointer_type_id(u32_id, spirv::StorageClass::Input),
139 var,
140 spirv::StorageClass::Input,
141 None,
142 )
143 .to_words(&mut self.logical_layout.declarations);
144 Instruction::decorate(
145 var,
146 spirv::Decoration::BuiltIn,
147 &[spirv::BuiltIn::LocalInvocationIndex as u32],
148 )
149 .to_words(&mut self.logical_layout.annotations);
150 iface.varying_ids.push(var);
151
152 var
153 }
154 };
155 let mut mesh_return_info = MeshReturnInfo {
158 out_variable_id: self.global_variables[mesh_info.output_variable].var_id,
159 out_members,
160 local_invocation_index_var_id,
161 workgroup_size: self
162 .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())),
163
164 vertex_info: PerOutputTypeMeshReturnInfo {
165 array_type_id: vertex_array_type_id,
166 struct_members: vertex_members,
167 max_length_constant: self
168 .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)),
169 bindings: Vec::new(),
170 builtin_block: None,
171 },
172 primitive_info: PerOutputTypeMeshReturnInfo {
173 array_type_id: primitive_array_type_id,
174 struct_members: primitive_members,
175 max_length_constant: self
176 .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)),
177 bindings: Vec::new(),
178 builtin_block: None,
179 },
180 primitive_indices: None,
181 };
182 let vert_array_size_id =
183 self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices));
184 let prim_array_size_id =
185 self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives));
186
187 if mesh_return_info
196 .vertex_info
197 .struct_members
198 .iter()
199 .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..)))
200 {
201 let builtin_block_ty_id = self.id_gen.next();
202 let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
203 let mut bi_index = 0;
204 let mut decorations = Vec::new();
205 for member in &mesh_return_info.vertex_info.struct_members {
206 if let crate::Binding::BuiltIn(_) = member.binding {
207 ins.add_operand(member.ty_id);
208 let binding = self.map_binding(
209 ir_module,
210 iface.stage,
211 spirv::StorageClass::Output,
212 Handle::new(NonMaxU32::new(0).unwrap()),
214 &member.binding,
215 )?;
216 match binding {
217 BindingDecorations::BuiltIn(bi, others) => {
218 decorations.push(Instruction::member_decorate(
219 builtin_block_ty_id,
220 bi_index,
221 spirv::Decoration::BuiltIn,
222 &[bi as Word],
223 ));
224 for other in others {
225 decorations.push(Instruction::member_decorate(
226 builtin_block_ty_id,
227 bi_index,
228 other,
229 &[],
230 ));
231 }
232 }
233 _ => unreachable!(),
234 }
235 bi_index += 1;
236 }
237 }
238 ins.to_words(&mut self.logical_layout.declarations);
239 decorations.push(Instruction::decorate(
240 builtin_block_ty_id,
241 spirv::Decoration::Block,
242 &[],
243 ));
244 for dec in decorations {
245 dec.to_words(&mut self.logical_layout.annotations);
246 }
247 let v =
248 self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?;
249 iface.varying_ids.push(v);
250 if self.flags.contains(WriterFlags::DEBUG) {
251 self.debugs
252 .push(Instruction::name(v, "naga_vertex_builtin_outputs"));
253 }
254 mesh_return_info.vertex_info.builtin_block = Some(v);
255 }
256 if mesh_return_info
258 .primitive_info
259 .struct_members
260 .iter()
261 .any(|a| {
262 !matches!(
263 a.binding,
264 crate::Binding::BuiltIn(
265 crate::BuiltIn::PointIndex
266 | crate::BuiltIn::LineIndices
267 | crate::BuiltIn::TriangleIndices
268 ) | crate::Binding::Location { .. }
269 )
270 })
271 {
272 let builtin_block_ty_id = self.id_gen.next();
273 let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
274 let mut bi_index = 0;
275 let mut decorations = Vec::new();
276 for member in &mesh_return_info.primitive_info.struct_members {
277 if let crate::Binding::BuiltIn(bi) = member.binding {
278 if matches!(
280 bi,
281 crate::BuiltIn::PointIndex
282 | crate::BuiltIn::LineIndices
283 | crate::BuiltIn::TriangleIndices,
284 ) {
285 continue;
286 }
287 ins.add_operand(member.ty_id);
288 let binding = self.map_binding(
289 ir_module,
290 iface.stage,
291 spirv::StorageClass::Output,
292 Handle::new(NonMaxU32::new(0).unwrap()),
294 &member.binding,
295 )?;
296 match binding {
297 BindingDecorations::BuiltIn(bi, others) => {
298 decorations.push(Instruction::member_decorate(
299 builtin_block_ty_id,
300 bi_index,
301 spirv::Decoration::BuiltIn,
302 &[bi as Word],
303 ));
304 for other in others {
305 decorations.push(Instruction::member_decorate(
306 builtin_block_ty_id,
307 bi_index,
308 other,
309 &[],
310 ));
311 }
312 }
313 _ => unreachable!(),
314 }
315 bi_index += 1;
316 }
317 }
318 ins.to_words(&mut self.logical_layout.declarations);
319 decorations.push(Instruction::decorate(
320 builtin_block_ty_id,
321 spirv::Decoration::Block,
322 &[],
323 ));
324 for dec in decorations {
325 dec.to_words(&mut self.logical_layout.annotations);
326 }
327 let v =
328 self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?;
329 Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
330 .to_words(&mut self.logical_layout.annotations);
331 iface.varying_ids.push(v);
332 if self.flags.contains(WriterFlags::DEBUG) {
333 self.debugs
334 .push(Instruction::name(v, "naga_primitive_builtin_outputs"));
335 }
336 mesh_return_info.primitive_info.builtin_block = Some(v);
337 }
338
339 for member in &mesh_return_info.vertex_info.struct_members {
341 match member.binding {
342 crate::Binding::Location { location, .. } => {
343 let v =
345 self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?;
346 Instruction::decorate(v, spirv::Decoration::Location, &[location])
348 .to_words(&mut self.logical_layout.annotations);
349 iface.varying_ids.push(v);
350 mesh_return_info.vertex_info.bindings.push(v);
351 }
352 crate::Binding::BuiltIn(_) => (),
353 }
354 }
355 for member in &mesh_return_info.primitive_info.struct_members {
358 match member.binding {
359 crate::Binding::BuiltIn(
360 crate::BuiltIn::PointIndex
361 | crate::BuiltIn::LineIndices
362 | crate::BuiltIn::TriangleIndices,
363 ) => {
364 let v =
366 self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
367 Instruction::decorate(
369 v,
370 spirv::Decoration::BuiltIn,
371 &[match member.binding.to_built_in().unwrap() {
372 crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT,
373 crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT,
374 crate::BuiltIn::TriangleIndices => {
375 spirv::BuiltIn::PrimitiveTriangleIndicesEXT
376 }
377 _ => unreachable!(),
378 } as Word],
379 )
380 .to_words(&mut self.logical_layout.annotations);
381 iface.varying_ids.push(v);
382 if self.flags.contains(WriterFlags::DEBUG) {
383 self.debugs
384 .push(Instruction::name(v, "naga_primitive_indices_outputs"));
385 }
386 mesh_return_info.primitive_indices = Some(v);
387 }
388 crate::Binding::Location { location, .. } => {
389 let v =
391 self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
392 Instruction::decorate(v, spirv::Decoration::Location, &[location])
394 .to_words(&mut self.logical_layout.annotations);
395 Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
397 .to_words(&mut self.logical_layout.annotations);
398 iface.varying_ids.push(v);
399
400 mesh_return_info.primitive_info.bindings.push(v);
401 }
402 crate::Binding::BuiltIn(_) => (),
403 }
404 }
405
406 ep_context.mesh_state = Some(mesh_return_info);
408
409 Ok(())
410 }
411
412 pub(super) fn write_entry_point_task_return(
413 &mut self,
414 value_id: Word,
415 body: &mut Vec<Instruction>,
416 task_payload: Word,
417 ) -> Result<Instruction, Error> {
418 let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
423 for (i, &value) in values.iter().enumerate() {
424 let instruction = Instruction::composite_extract(
425 self.get_u32_type_id(),
426 value,
427 value_id,
428 &[i as Word],
429 );
430 body.push(instruction);
431 }
432 let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT);
433 for id in values {
434 instruction.add_operand(id);
435 }
436 instruction.add_operand(task_payload);
438 Ok(instruction)
439 }
440
441 #[allow(clippy::too_many_arguments)]
443 fn write_mesh_copy_loop(
444 &mut self,
445 body: &mut Vec<Instruction>,
446 mut loop_body_block: Vec<Instruction>,
447 loop_header: u32,
448 loop_merge: u32,
449 count_id: u32,
450 index_var: u32,
451 return_info: &MeshReturnInfo,
452 ) {
453 let u32_id = self.get_u32_type_id();
454 let condition_check = self.id_gen.next();
455 let loop_continue = self.id_gen.next();
456 let loop_body = self.id_gen.next();
457
458 {
460 body.push(Instruction::label(loop_header));
461 body.push(Instruction::loop_merge(
462 loop_merge,
463 loop_continue,
464 spirv::SelectionControl::empty(),
465 ));
466 body.push(Instruction::branch(condition_check));
467 }
468 {
470 body.push(Instruction::label(condition_check));
471
472 let val_i = self.id_gen.next();
473 body.push(Instruction::load(u32_id, val_i, index_var, None));
474
475 let cond = self.id_gen.next();
476 body.push(Instruction::binary(
477 spirv::Op::ULessThan,
478 self.get_bool_type_id(),
479 cond,
480 val_i,
481 count_id,
482 ));
483 body.push(Instruction::branch_conditional(cond, loop_body, loop_merge));
484 }
485 {
487 body.push(Instruction::label(loop_body));
488 body.append(&mut loop_body_block);
489 body.push(Instruction::branch(loop_continue));
490 }
491 {
493 body.push(Instruction::label(loop_continue));
494
495 let prev_val_i = self.id_gen.next();
496 body.push(Instruction::load(u32_id, prev_val_i, index_var, None));
497 let new_val_i = self.id_gen.next();
498 body.push(Instruction::binary(
499 spirv::Op::IAdd,
500 u32_id,
501 new_val_i,
502 prev_val_i,
503 return_info.workgroup_size,
504 ));
505 body.push(Instruction::store(index_var, new_val_i, None));
506
507 body.push(Instruction::branch(loop_header));
508 }
509 }
510
511 fn write_mesh_copy_body(
514 &mut self,
515 is_primitive: bool,
516 return_info: &MeshReturnInfo,
517 index_var: u32,
518 vert_array_ptr: u32,
519 prim_array_ptr: u32,
520 ) -> Vec<Instruction> {
521 let u32_type_id = self.get_u32_type_id();
522 let mut body = Vec::new();
523 let val_i = self.id_gen.next();
525 body.push(Instruction::load(u32_type_id, val_i, index_var, None));
526
527 let info = if is_primitive {
528 &return_info.primitive_info
529 } else {
530 &return_info.vertex_info
531 };
532 let array_ptr = if is_primitive {
533 prim_array_ptr
534 } else {
535 vert_array_ptr
536 };
537
538 let mut builtin_index = 0;
539 let mut binding_index = 0;
540 for (member_id, member) in info.struct_members.iter().enumerate() {
542 let val_to_copy_ptr = self.id_gen.next();
543 body.push(Instruction::access_chain(
544 self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Workgroup),
545 val_to_copy_ptr,
546 array_ptr,
547 &[
548 val_i,
549 self.get_constant_scalar(crate::Literal::U32(member_id as u32)),
550 ],
551 ));
552 let val_to_copy = self.id_gen.next();
553 body.push(Instruction::load(
554 member.ty_id,
555 val_to_copy,
556 val_to_copy_ptr,
557 None,
558 ));
559 let mut needs_y_flip = false;
560 let ptr_to_copy_to = self.id_gen.next();
561 match member.binding {
563 crate::Binding::BuiltIn(
564 crate::BuiltIn::PointIndex
565 | crate::BuiltIn::LineIndices
566 | crate::BuiltIn::TriangleIndices,
567 ) => {
568 body.push(Instruction::access_chain(
569 self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
570 ptr_to_copy_to,
571 return_info.primitive_indices.unwrap(),
572 &[val_i],
573 ));
574 }
575 crate::Binding::BuiltIn(bi) => {
576 body.push(Instruction::access_chain(
577 self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
578 ptr_to_copy_to,
579 info.builtin_block.unwrap(),
580 &[
581 val_i,
582 self.get_constant_scalar(crate::Literal::U32(builtin_index)),
583 ],
584 ));
585 needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. })
586 && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE);
587 builtin_index += 1;
588 }
589 crate::Binding::Location { .. } => {
590 body.push(Instruction::access_chain(
591 self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
592 ptr_to_copy_to,
593 info.bindings[binding_index],
594 &[val_i],
595 ));
596 binding_index += 1;
597 }
598 }
599 body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None));
600 if needs_y_flip {
603 let prev_y = self.id_gen.next();
604 body.push(Instruction::composite_extract(
605 self.get_f32_type_id(),
606 prev_y,
607 val_to_copy,
608 &[1],
609 ));
610 let new_y = self.id_gen.next();
611 body.push(Instruction::unary(
612 spirv::Op::FNegate,
613 self.get_f32_type_id(),
614 new_y,
615 prev_y,
616 ));
617 let new_ptr_to_copy_to = self.id_gen.next();
618 body.push(Instruction::access_chain(
619 self.get_f32_pointer_type_id(spirv::StorageClass::Output),
620 new_ptr_to_copy_to,
621 ptr_to_copy_to,
622 &[self.get_constant_scalar(crate::Literal::U32(1))],
623 ));
624 body.push(Instruction::store(new_ptr_to_copy_to, new_y, None));
625 }
626 }
627 body
628 }
629
630 pub(super) fn write_mesh_shader_return(
633 &mut self,
634 return_info: &MeshReturnInfo,
635 block: &mut Block,
636 loop_counter_vertices: u32,
637 loop_counter_primitives: u32,
638 local_invocation_index_id: Word,
639 ) -> Result<(), Error> {
640 let u32_id = self.get_u32_type_id();
641
642 let mut load_u32_by_member_index =
644 |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| {
645 let member_index = members
646 .iter()
647 .position(|a| a.binding == crate::Binding::BuiltIn(bi))
648 .unwrap() as u32;
649 let ptr_id = self.id_gen.next();
650 block.body.push(Instruction::access_chain(
651 self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup),
652 ptr_id,
653 return_info.out_variable_id,
654 &[self.get_constant_scalar(crate::Literal::U32(member_index))],
655 ));
656 let before_min_id = self.id_gen.next();
657 block
658 .body
659 .push(Instruction::load(u32_id, before_min_id, ptr_id, None));
660
661 let id = self.id_gen.next();
663 block.body.push(Instruction::ext_inst_gl_op(
664 self.gl450_ext_inst_id,
665 spirv::GLOp::UMin,
666 u32_id,
667 id,
668 &[before_min_id, max],
669 ));
670 id
671 };
672 let vert_count_id = load_u32_by_member_index(
673 &return_info.out_members,
674 crate::BuiltIn::VertexCount,
675 return_info.vertex_info.max_length_constant,
676 );
677 let prim_count_id = load_u32_by_member_index(
678 &return_info.out_members,
679 crate::BuiltIn::PrimitiveCount,
680 return_info.primitive_info.max_length_constant,
681 );
682
683 let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| {
685 let id = self.id_gen.next();
686 block.body.push(Instruction::access_chain(
687 self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup),
688 id,
689 return_info.out_variable_id,
690 &[self.get_constant_scalar(crate::Literal::U32(
691 return_info
692 .out_members
693 .iter()
694 .position(|a| a.binding == crate::Binding::BuiltIn(bi))
695 .unwrap() as u32,
696 ))],
697 ));
698 id
699 };
700 let vert_array_ptr = get_array_ptr(
701 crate::BuiltIn::Vertices,
702 return_info.vertex_info.array_type_id,
703 );
704 let prim_array_ptr = get_array_ptr(
705 crate::BuiltIn::Primitives,
706 return_info.primitive_info.array_type_id,
707 );
708
709 {
711 let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT);
712 ins.add_operand(vert_count_id);
713 ins.add_operand(prim_count_id);
714 block.body.push(ins);
715 }
716
717 let vertex_loop_header = self.id_gen.next();
720 let prim_loop_header = self.id_gen.next();
721 let in_between_loops = self.id_gen.next();
722 let func_end = self.id_gen.next();
723
724 block.body.push(Instruction::store(
725 loop_counter_vertices,
726 local_invocation_index_id,
727 None,
728 ));
729 block.body.push(Instruction::branch(vertex_loop_header));
730
731 let vertex_copy_body = self.write_mesh_copy_body(
732 false,
733 return_info,
734 loop_counter_vertices,
735 vert_array_ptr,
736 prim_array_ptr,
737 );
738 self.write_mesh_copy_loop(
740 &mut block.body,
741 vertex_copy_body,
742 vertex_loop_header,
743 in_between_loops,
744 vert_count_id,
745 loop_counter_vertices,
746 return_info,
747 );
748
749 {
751 block.body.push(Instruction::label(in_between_loops));
752
753 block.body.push(Instruction::store(
754 loop_counter_primitives,
755 local_invocation_index_id,
756 None,
757 ));
758
759 block.body.push(Instruction::branch(prim_loop_header));
760 }
761 let primitive_copy_body = self.write_mesh_copy_body(
762 true,
763 return_info,
764 loop_counter_primitives,
765 vert_array_ptr,
766 prim_array_ptr,
767 );
768 self.write_mesh_copy_loop(
770 &mut block.body,
771 primitive_copy_body,
772 prim_loop_header,
773 func_end,
774 prim_count_id,
775 loop_counter_primitives,
776 return_info,
777 );
778
779 block.body.push(Instruction::label(func_end));
780 Ok(())
781 }
782
783 pub(super) fn write_mesh_shader_wrapper(
784 &mut self,
785 return_info: &MeshReturnInfo,
786 inner_id: u32,
787 ) -> Result<u32, Error> {
788 let out_id = self.id_gen.next();
789 let mut function = super::Function::default();
790 let lookup_function_type = super::LookupFunctionType {
791 parameter_type_ids: alloc::vec![],
792 return_type_id: self.void_type,
793 };
794 let function_type = self.get_function_type(lookup_function_type);
795 function.signature = Some(Instruction::function(
796 self.void_type,
797 out_id,
798 spirv::FunctionControl::empty(),
799 function_type,
800 ));
801 let u32_id = self.get_u32_type_id();
802 {
803 let mut block = Block::new(self.id_gen.next());
804 let loop_counter_vertices = self.id_gen.next();
808 let loop_counter_primitives = self.id_gen.next();
809 block.body.insert(
810 0,
811 Instruction::variable(
812 self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
813 loop_counter_vertices,
814 spirv::StorageClass::Function,
815 None,
816 ),
817 );
818 block.body.insert(
819 1,
820 Instruction::variable(
821 self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
822 loop_counter_primitives,
823 spirv::StorageClass::Function,
824 None,
825 ),
826 );
827 let local_invocation_index_id = self.id_gen.next();
828 block.body.push(Instruction::load(
829 u32_id,
830 local_invocation_index_id,
831 return_info.local_invocation_index_var_id,
832 None,
833 ));
834 block.body.push(Instruction::function_call(
835 self.void_type,
836 self.id_gen.next(),
837 inner_id,
838 &[],
839 ));
840 self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
841 self.write_mesh_shader_return(
842 return_info,
843 &mut block,
844 loop_counter_vertices,
845 loop_counter_primitives,
846 local_invocation_index_id,
847 )?;
848 function.consume(block, Instruction::return_void());
849 }
850 function.to_words(&mut self.logical_layout.function_definitions);
851 Ok(out_id)
852 }
853
854 pub(super) fn write_task_shader_wrapper(
855 &mut self,
856 task_payload: Word,
857 inner_id: u32,
858 ) -> Result<u32, Error> {
859 let out_id = self.id_gen.next();
860 let mut function = super::Function::default();
861 let lookup_function_type = super::LookupFunctionType {
862 parameter_type_ids: alloc::vec![],
863 return_type_id: self.void_type,
864 };
865 let function_type = self.get_function_type(lookup_function_type);
866 function.signature = Some(Instruction::function(
867 self.void_type,
868 out_id,
869 spirv::FunctionControl::empty(),
870 function_type,
871 ));
872
873 {
874 let mut block = Block::new(self.id_gen.next());
875 let result = self.id_gen.next();
876 block.body.push(Instruction::function_call(
877 self.get_vec3u_type_id(),
878 result,
879 inner_id,
880 &[],
881 ));
882 self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
883 let final_value = if let Some(task_limits) = self.task_dispatch_limits {
884 let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0));
885 let max_per_dim = self.get_constant_scalar(crate::Literal::U32(
886 task_limits.max_mesh_workgroups_per_dim,
887 ));
888 let max_total = self.get_constant_scalar(crate::Literal::U32(
889 task_limits.max_mesh_workgroups_total,
890 ));
891 let combined_struct_type = self.get_tuple_of_u32s_ty_id();
892 let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
893 for (i, value) in values.into_iter().enumerate() {
894 block.body.push(Instruction::composite_extract(
895 self.get_u32_type_id(),
896 value,
897 result,
898 &[i as u32],
899 ));
900 }
901 let prod_1 = self.id_gen.next();
902 let overflows = [self.id_gen.next(), self.id_gen.next()];
903 {
904 let struct_out = self.id_gen.next();
905 block.body.push(Instruction::binary(
906 spirv::Op::UMulExtended,
907 combined_struct_type,
908 struct_out,
909 values[0],
910 values[1],
911 ));
912 block.body.push(Instruction::composite_extract(
913 self.get_u32_type_id(),
914 prod_1,
915 struct_out,
916 &[0],
917 ));
918 block.body.push(Instruction::composite_extract(
919 self.get_u32_type_id(),
920 overflows[0],
921 struct_out,
922 &[1],
923 ));
924 }
925 let prod_final = self.id_gen.next();
926 {
927 let struct_out = self.id_gen.next();
928 block.body.push(Instruction::binary(
929 spirv::Op::UMulExtended,
930 combined_struct_type,
931 struct_out,
932 prod_1,
933 values[2],
934 ));
935 block.body.push(Instruction::composite_extract(
936 self.get_u32_type_id(),
937 prod_final,
938 struct_out,
939 &[0],
940 ));
941 block.body.push(Instruction::composite_extract(
942 self.get_u32_type_id(),
943 overflows[1],
944 struct_out,
945 &[1],
946 ));
947 }
948 let total_too_large = self.id_gen.next();
949 block.body.push(Instruction::binary(
950 spirv::Op::UGreaterThan,
951 self.get_bool_type_id(),
952 total_too_large,
953 prod_final,
954 max_total,
955 ));
956
957 let too_large = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
958 for (i, value) in values.into_iter().enumerate() {
959 block.body.push(Instruction::binary(
960 spirv::Op::UGreaterThan,
961 self.get_bool_type_id(),
962 too_large[i],
963 value,
964 max_per_dim,
965 ));
966 }
967 let overflow_happens = [self.id_gen.next(), self.id_gen.next()];
968 for (i, value) in overflows.into_iter().enumerate() {
969 block.body.push(Instruction::binary(
970 spirv::Op::INotEqual,
971 self.get_bool_type_id(),
972 overflow_happens[i],
973 value,
974 zero_u32,
975 ));
976 }
977 let mut current_violates_limits = total_too_large;
978 for is_too_large in too_large {
979 let new = self.id_gen.next();
980 block.body.push(Instruction::binary(
981 spirv::Op::LogicalOr,
982 self.get_bool_type_id(),
983 new,
984 current_violates_limits,
985 is_too_large,
986 ));
987 current_violates_limits = new;
988 }
989 for overflow_happens in overflow_happens {
990 let new = self.id_gen.next();
991 block.body.push(Instruction::binary(
992 spirv::Op::LogicalOr,
993 self.get_bool_type_id(),
994 new,
995 current_violates_limits,
996 overflow_happens,
997 ));
998 current_violates_limits = new;
999 }
1000 let zero_vec3 = self.id_gen.next();
1001 block.body.push(Instruction::composite_construct(
1002 self.get_vec3u_type_id(),
1003 zero_vec3,
1004 &[zero_u32, zero_u32, zero_u32],
1005 ));
1006 let final_result = self.id_gen.next();
1007 block.body.push(Instruction::select(
1008 self.get_vec3u_type_id(),
1009 final_result,
1010 current_violates_limits,
1011 zero_vec3,
1012 result,
1013 ));
1014 final_result
1015 } else {
1016 result
1017 };
1018 let ins =
1019 self.write_entry_point_task_return(final_value, &mut block.body, task_payload)?;
1020 function.consume(block, ins);
1021 }
1022 function.to_words(&mut self.logical_layout.function_definitions);
1023 Ok(out_id)
1024 }
1025}