1use alloc::vec::Vec;
2use spirv::Word;
3
4use crate::{
5 back::spv::{
6 helpers::BindingDecorations, writer::FunctionInterface, Block, EntryPointContext, Error,
7 Instruction, WriterFlags,
8 },
9 non_max_u32::NonMaxU32,
10 Handle,
11};
12
/// One member of a mesh shader output struct, as tracked by the SPIR-V writer.
///
/// Built from an IR struct member: `ty_id` is the SPIR-V type id obtained from
/// `get_handle_type_id` for the member's IR type, and `binding` is a clone of
/// the member's IR binding (the writer unwraps it, so every member of these
/// structs is expected to carry one).
#[derive(Clone, Debug)]
pub struct MeshReturnMember {
    /// SPIR-V type id of the member's type.
    pub ty_id: u32,
    /// The member's IR binding: either a builtin or a user `Location`.
    pub binding: crate::Binding,
}
18
/// Bookkeeping for one kind of mesh output: either the per-vertex outputs or
/// the per-primitive outputs.
#[derive(Debug)]
struct PerOutputTypeMeshReturnInfo {
    /// Id of a `u32` constant holding the maximum element count
    /// (`max_vertices` or `max_primitives`). It is used both as the output
    /// array length and to clamp the user-written count with `UMin`.
    max_length_constant: Word,
    /// SPIR-V type id of the `vertices`/`primitives` array member of the
    /// top-level output struct (the workgroup-side array the shader writes).
    array_type_id: Word,
    /// Members of the per-vertex or per-primitive struct, in declaration order.
    struct_members: Vec<MeshReturnMember>,

    /// `Output` array variable of the `Block`-decorated struct that holds this
    /// kind's builtin members, if any such member exists.
    builtin_block: Option<Word>,
    /// `Output` array variables for the `Location` members, in the same
    /// order those members appear in `struct_members`.
    bindings: Vec<Word>,
}
33
/// State gathered while declaring a mesh shader entry point's interface
/// (`write_entry_point_mesh_shader_info`). It is later consumed when writing
/// the wrapper that copies the workgroup-resident outputs into the SPIR-V
/// output variables.
#[derive(Debug)]
pub struct MeshReturnInfo {
    /// Id of the IR mesh output global variable; it is accessed through
    /// `Workgroup` pointers when copying.
    out_variable_id: Word,
    /// Members of the top-level output struct: the vertex/primitive counts
    /// and the `vertices`/`primitives` arrays.
    out_members: Vec<MeshReturnMember>,
    /// Variable that the wrapper loads the local invocation index from. It is
    /// either supplied by the caller or a freshly declared `Input` variable
    /// decorated `LocalInvocationIndex`.
    local_invocation_index_var_id: Word,
    /// NOTE: despite the name, this is the id of a `u32` constant holding the
    /// product of the workgroup size dimensions. The copy loops use it as
    /// their stride.
    workgroup_size: u32,

    /// Per-vertex output bookkeeping.
    vertex_info: PerOutputTypeMeshReturnInfo,
    /// Per-primitive output bookkeeping.
    primitive_info: PerOutputTypeMeshReturnInfo,
    /// `Output` array variable decorated with a `Primitive*IndicesEXT`
    /// builtin, if the primitive struct has an index member.
    primitive_indices: Option<Word>,
}
52
53impl super::Writer {
54 pub(super) fn write_mesh_return_global_variable(
56 &mut self,
57 ty: u32,
58 array_size_id: u32,
59 ) -> Result<Word, Error> {
60 let array_ty = self.id_gen.next();
61 Instruction::type_array(array_ty, ty, array_size_id)
62 .to_words(&mut self.logical_layout.declarations);
63 let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output);
64 let var_id = self.id_gen.next();
65 Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None)
66 .to_words(&mut self.logical_layout.declarations);
67 Ok(var_id)
68 }
69
70 pub(super) fn write_entry_point_mesh_shader_info(
73 &mut self,
74 iface: &mut FunctionInterface,
75 local_invocation_index_id: Option<Word>,
76 ir_module: &crate::Module,
77 ep_context: &mut EntryPointContext,
78 ) -> Result<(), Error> {
79 let Some(ref mesh_info) = iface.mesh_info else {
80 return Ok(());
81 };
82 let out_members: Vec<MeshReturnMember> =
84 match &ir_module.types[ir_module.global_variables[mesh_info.output_variable].ty] {
85 &crate::Type {
86 inner: crate::TypeInner::Struct { ref members, .. },
87 ..
88 } => members
89 .iter()
90 .map(|a| MeshReturnMember {
91 ty_id: self.get_handle_type_id(a.ty),
92 binding: a.binding.clone().unwrap(),
93 })
94 .collect(),
95 _ => unreachable!(),
96 };
97 let vertex_array_type_id = out_members
98 .iter()
99 .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Vertices))
100 .unwrap()
101 .ty_id;
102 let primitive_array_type_id = out_members
103 .iter()
104 .find(|a| a.binding == crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
105 .unwrap()
106 .ty_id;
107 let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] {
108 &crate::Type {
109 inner: crate::TypeInner::Struct { ref members, .. },
110 ..
111 } => members
112 .iter()
113 .map(|a| MeshReturnMember {
114 ty_id: self.get_handle_type_id(a.ty),
115 binding: a.binding.clone().unwrap(),
116 })
117 .collect(),
118 _ => unreachable!(),
119 };
120 let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] {
121 &crate::Type {
122 inner: crate::TypeInner::Struct { ref members, .. },
123 ..
124 } => members
125 .iter()
126 .map(|a| MeshReturnMember {
127 ty_id: self.get_handle_type_id(a.ty),
128 binding: a.binding.clone().unwrap(),
129 })
130 .collect(),
131 _ => unreachable!(),
132 };
133 let local_invocation_index_var_id = match local_invocation_index_id {
135 Some(a) => a,
136 None => {
137 let u32_id = self.get_u32_type_id();
138 let var = self.id_gen.next();
139 Instruction::variable(
140 self.get_pointer_type_id(u32_id, spirv::StorageClass::Input),
141 var,
142 spirv::StorageClass::Input,
143 None,
144 )
145 .to_words(&mut self.logical_layout.declarations);
146 Instruction::decorate(
147 var,
148 spirv::Decoration::BuiltIn,
149 &[spirv::BuiltIn::LocalInvocationIndex as u32],
150 )
151 .to_words(&mut self.logical_layout.annotations);
152 iface.varying_ids.push(var);
153
154 var
155 }
156 };
157 let mut mesh_return_info = MeshReturnInfo {
160 out_variable_id: self.global_variables[mesh_info.output_variable].var_id,
161 out_members,
162 local_invocation_index_var_id,
163 workgroup_size: self
164 .get_constant_scalar(crate::Literal::U32(iface.workgroup_size.iter().product())),
165
166 vertex_info: PerOutputTypeMeshReturnInfo {
167 array_type_id: vertex_array_type_id,
168 struct_members: vertex_members,
169 max_length_constant: self
170 .get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices)),
171 bindings: Vec::new(),
172 builtin_block: None,
173 },
174 primitive_info: PerOutputTypeMeshReturnInfo {
175 array_type_id: primitive_array_type_id,
176 struct_members: primitive_members,
177 max_length_constant: self
178 .get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives)),
179 bindings: Vec::new(),
180 builtin_block: None,
181 },
182 primitive_indices: None,
183 };
184 let vert_array_size_id =
185 self.get_constant_scalar(crate::Literal::U32(mesh_info.max_vertices));
186 let prim_array_size_id =
187 self.get_constant_scalar(crate::Literal::U32(mesh_info.max_primitives));
188
189 if mesh_return_info
198 .vertex_info
199 .struct_members
200 .iter()
201 .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..)))
202 {
203 let builtin_block_ty_id = self.id_gen.next();
204 let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
205 let mut bi_index = 0;
206 let mut decorations = Vec::new();
207 for member in &mesh_return_info.vertex_info.struct_members {
208 if let crate::Binding::BuiltIn(_) = member.binding {
209 ins.add_operand(member.ty_id);
210 let binding = self.map_binding(
211 ir_module,
212 iface.stage,
213 spirv::StorageClass::Output,
214 Handle::new(NonMaxU32::new(0).unwrap()),
216 &member.binding,
217 )?;
218 match binding {
219 BindingDecorations::BuiltIn(bi, others) => {
220 decorations.push(Instruction::member_decorate(
221 builtin_block_ty_id,
222 bi_index,
223 spirv::Decoration::BuiltIn,
224 &[bi as Word],
225 ));
226 for other in others {
227 decorations.push(Instruction::member_decorate(
228 builtin_block_ty_id,
229 bi_index,
230 other,
231 &[],
232 ));
233 }
234 }
235 _ => unreachable!(),
236 }
237 bi_index += 1;
238 }
239 }
240 ins.to_words(&mut self.logical_layout.declarations);
241 decorations.push(Instruction::decorate(
242 builtin_block_ty_id,
243 spirv::Decoration::Block,
244 &[],
245 ));
246 for dec in decorations {
247 dec.to_words(&mut self.logical_layout.annotations);
248 }
249 let v =
250 self.write_mesh_return_global_variable(builtin_block_ty_id, vert_array_size_id)?;
251 iface.varying_ids.push(v);
252 if self.flags.contains(WriterFlags::DEBUG) {
253 self.debugs
254 .push(Instruction::name(v, "naga_vertex_builtin_outputs"));
255 }
256 mesh_return_info.vertex_info.builtin_block = Some(v);
257 }
258 if mesh_return_info
260 .primitive_info
261 .struct_members
262 .iter()
263 .any(|a| {
264 !matches!(
265 a.binding,
266 crate::Binding::BuiltIn(
267 crate::BuiltIn::PointIndex
268 | crate::BuiltIn::LineIndices
269 | crate::BuiltIn::TriangleIndices
270 ) | crate::Binding::Location { .. }
271 )
272 })
273 {
274 let builtin_block_ty_id = self.id_gen.next();
275 let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]);
276 let mut bi_index = 0;
277 let mut decorations = Vec::new();
278 for member in &mesh_return_info.primitive_info.struct_members {
279 if let crate::Binding::BuiltIn(bi) = member.binding {
280 if matches!(
282 bi,
283 crate::BuiltIn::PointIndex
284 | crate::BuiltIn::LineIndices
285 | crate::BuiltIn::TriangleIndices,
286 ) {
287 continue;
288 }
289 ins.add_operand(member.ty_id);
290 let binding = self.map_binding(
291 ir_module,
292 iface.stage,
293 spirv::StorageClass::Output,
294 Handle::new(NonMaxU32::new(0).unwrap()),
296 &member.binding,
297 )?;
298 match binding {
299 BindingDecorations::BuiltIn(bi, others) => {
300 decorations.push(Instruction::member_decorate(
301 builtin_block_ty_id,
302 bi_index,
303 spirv::Decoration::BuiltIn,
304 &[bi as Word],
305 ));
306 for other in others {
307 decorations.push(Instruction::member_decorate(
308 builtin_block_ty_id,
309 bi_index,
310 other,
311 &[],
312 ));
313 }
314 }
315 _ => unreachable!(),
316 }
317 bi_index += 1;
318 }
319 }
320 ins.to_words(&mut self.logical_layout.declarations);
321 decorations.push(Instruction::decorate(
322 builtin_block_ty_id,
323 spirv::Decoration::Block,
324 &[],
325 ));
326 for dec in decorations {
327 dec.to_words(&mut self.logical_layout.annotations);
328 }
329 let v =
330 self.write_mesh_return_global_variable(builtin_block_ty_id, prim_array_size_id)?;
331 Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
332 .to_words(&mut self.logical_layout.annotations);
333 iface.varying_ids.push(v);
334 if self.flags.contains(WriterFlags::DEBUG) {
335 self.debugs
336 .push(Instruction::name(v, "naga_primitive_builtin_outputs"));
337 }
338 mesh_return_info.primitive_info.builtin_block = Some(v);
339 }
340
341 for member in &mesh_return_info.vertex_info.struct_members {
343 match member.binding {
344 crate::Binding::Location { location, .. } => {
345 let v =
347 self.write_mesh_return_global_variable(member.ty_id, vert_array_size_id)?;
348 Instruction::decorate(v, spirv::Decoration::Location, &[location])
350 .to_words(&mut self.logical_layout.annotations);
351 iface.varying_ids.push(v);
352 mesh_return_info.vertex_info.bindings.push(v);
353 }
354 crate::Binding::BuiltIn(_) => (),
355 }
356 }
357 for member in &mesh_return_info.primitive_info.struct_members {
360 match member.binding {
361 crate::Binding::BuiltIn(
362 crate::BuiltIn::PointIndex
363 | crate::BuiltIn::LineIndices
364 | crate::BuiltIn::TriangleIndices,
365 ) => {
366 let v =
368 self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
369 Instruction::decorate(
371 v,
372 spirv::Decoration::BuiltIn,
373 &[match member.binding.to_built_in().unwrap() {
374 crate::BuiltIn::PointIndex => spirv::BuiltIn::PrimitivePointIndicesEXT,
375 crate::BuiltIn::LineIndices => spirv::BuiltIn::PrimitiveLineIndicesEXT,
376 crate::BuiltIn::TriangleIndices => {
377 spirv::BuiltIn::PrimitiveTriangleIndicesEXT
378 }
379 _ => unreachable!(),
380 } as Word],
381 )
382 .to_words(&mut self.logical_layout.annotations);
383 iface.varying_ids.push(v);
384 if self.flags.contains(WriterFlags::DEBUG) {
385 self.debugs
386 .push(Instruction::name(v, "naga_primitive_indices_outputs"));
387 }
388 mesh_return_info.primitive_indices = Some(v);
389 }
390 crate::Binding::Location { location, .. } => {
391 let v =
393 self.write_mesh_return_global_variable(member.ty_id, prim_array_size_id)?;
394 Instruction::decorate(v, spirv::Decoration::Location, &[location])
396 .to_words(&mut self.logical_layout.annotations);
397 Instruction::decorate(v, spirv::Decoration::PerPrimitiveEXT, &[])
399 .to_words(&mut self.logical_layout.annotations);
400 iface.varying_ids.push(v);
401
402 mesh_return_info.primitive_info.bindings.push(v);
403 }
404 crate::Binding::BuiltIn(_) => (),
405 }
406 }
407
408 ep_context.mesh_state = Some(mesh_return_info);
410
411 Ok(())
412 }
413
414 pub(super) fn write_entry_point_task_return(
415 &mut self,
416 value_id: Word,
417 body: &mut Vec<Instruction>,
418 task_payload: Word,
419 ) -> Result<Instruction, Error> {
420 let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
425 for (i, &value) in values.iter().enumerate() {
426 let instruction = Instruction::composite_extract(
427 self.get_u32_type_id(),
428 value,
429 value_id,
430 &[i as Word],
431 );
432 body.push(instruction);
433 }
434 let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT);
435 for id in values {
436 instruction.add_operand(id);
437 }
438 instruction.add_operand(task_payload);
440 Ok(instruction)
441 }
442
443 #[allow(clippy::too_many_arguments)]
445 fn write_mesh_copy_loop(
446 &mut self,
447 body: &mut Vec<Instruction>,
448 mut loop_body_block: Vec<Instruction>,
449 loop_header: u32,
450 loop_merge: u32,
451 count_id: u32,
452 index_var: u32,
453 return_info: &MeshReturnInfo,
454 ) {
455 let u32_id = self.get_u32_type_id();
456 let condition_check = self.id_gen.next();
457 let loop_continue = self.id_gen.next();
458 let loop_body = self.id_gen.next();
459
460 {
462 body.push(Instruction::label(loop_header));
463 body.push(Instruction::loop_merge(
464 loop_merge,
465 loop_continue,
466 spirv::SelectionControl::empty(),
467 ));
468 body.push(Instruction::branch(condition_check));
469 }
470 {
472 body.push(Instruction::label(condition_check));
473
474 let val_i = self.id_gen.next();
475 body.push(Instruction::load(u32_id, val_i, index_var, None));
476
477 let cond = self.id_gen.next();
478 body.push(Instruction::binary(
479 spirv::Op::ULessThan,
480 self.get_bool_type_id(),
481 cond,
482 val_i,
483 count_id,
484 ));
485 body.push(Instruction::branch_conditional(cond, loop_body, loop_merge));
486 }
487 {
489 body.push(Instruction::label(loop_body));
490 body.append(&mut loop_body_block);
491 body.push(Instruction::branch(loop_continue));
492 }
493 {
495 body.push(Instruction::label(loop_continue));
496
497 let prev_val_i = self.id_gen.next();
498 body.push(Instruction::load(u32_id, prev_val_i, index_var, None));
499 let new_val_i = self.id_gen.next();
500 body.push(Instruction::binary(
501 spirv::Op::IAdd,
502 u32_id,
503 new_val_i,
504 prev_val_i,
505 return_info.workgroup_size,
506 ));
507 body.push(Instruction::store(index_var, new_val_i, None));
508
509 body.push(Instruction::branch(loop_header));
510 }
511 }
512
    /// Builds (but does not emit) the instructions for one iteration of a mesh
    /// output copy loop.
    ///
    /// It loads the current index `i` from `index_var`. Then, for every member
    /// of the per-vertex struct (or the per-primitive struct when
    /// `is_primitive`), it loads element `i`'s member from the workgroup array
    /// (`vert_array_ptr` / `prim_array_ptr`) and stores it to the matching
    /// SPIR-V output:
    /// - primitive index builtins go to `return_info.primitive_indices[i]`;
    /// - other builtins go to the builtin block, at `[i][builtin_index]`;
    /// - `Location` members go to the next variable in `info.bindings`, at `[i]`.
    ///
    /// With `ADJUST_COORDINATE_SPACE`, a stored `Position` builtin also gets its
    /// y component negated.
    fn write_mesh_copy_body(
        &mut self,
        is_primitive: bool,
        return_info: &MeshReturnInfo,
        index_var: u32,
        vert_array_ptr: u32,
        prim_array_ptr: u32,
    ) -> Vec<Instruction> {
        let u32_type_id = self.get_u32_type_id();
        let mut body = Vec::new();
        let val_i = self.id_gen.next();
        // Current element index for this invocation's iteration.
        body.push(Instruction::load(u32_type_id, val_i, index_var, None));

        let info = if is_primitive {
            &return_info.primitive_info
        } else {
            &return_info.vertex_info
        };
        let array_ptr = if is_primitive {
            prim_array_ptr
        } else {
            vert_array_ptr
        };

        // Separate cursors into the builtin block and into `info.bindings`.
        // They rely on both being laid out in `struct_members` order.
        let mut builtin_index = 0;
        let mut binding_index = 0;
        for (member_id, member) in info.struct_members.iter().enumerate() {
            // Source: `array[i].member` in workgroup memory.
            let val_to_copy_ptr = self.id_gen.next();
            body.push(Instruction::access_chain(
                self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Workgroup),
                val_to_copy_ptr,
                array_ptr,
                &[
                    val_i,
                    self.get_constant_scalar(crate::Literal::U32(member_id as u32)),
                ],
            ));
            let val_to_copy = self.id_gen.next();
            body.push(Instruction::load(
                member.ty_id,
                val_to_copy,
                val_to_copy_ptr,
                None,
            ));
            let mut needs_y_flip = false;
            let ptr_to_copy_to = self.id_gen.next();
            match member.binding {
                // Primitive indices live in their own builtin output array.
                crate::Binding::BuiltIn(
                    crate::BuiltIn::PointIndex
                    | crate::BuiltIn::LineIndices
                    | crate::BuiltIn::TriangleIndices,
                ) => {
                    body.push(Instruction::access_chain(
                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
                        ptr_to_copy_to,
                        return_info.primitive_indices.unwrap(),
                        &[val_i],
                    ));
                }
                // Other builtins: member `builtin_index` of block element `i`.
                crate::Binding::BuiltIn(bi) => {
                    body.push(Instruction::access_chain(
                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
                        ptr_to_copy_to,
                        info.builtin_block.unwrap(),
                        &[
                            val_i,
                            self.get_constant_scalar(crate::Literal::U32(builtin_index)),
                        ],
                    ));
                    needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. })
                        && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE);
                    builtin_index += 1;
                }
                // User locations: element `i` of that location's own output array.
                crate::Binding::Location { .. } => {
                    body.push(Instruction::access_chain(
                        self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output),
                        ptr_to_copy_to,
                        info.bindings[binding_index],
                        &[val_i],
                    ));
                    binding_index += 1;
                }
            }
            body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None));
            if needs_y_flip {
                // Overwrite component 1 (y) of the stored position with its
                // negation, as requested by ADJUST_COORDINATE_SPACE.
                let prev_y = self.id_gen.next();
                body.push(Instruction::composite_extract(
                    self.get_f32_type_id(),
                    prev_y,
                    val_to_copy,
                    &[1],
                ));
                let new_y = self.id_gen.next();
                body.push(Instruction::unary(
                    spirv::Op::FNegate,
                    self.get_f32_type_id(),
                    new_y,
                    prev_y,
                ));
                let new_ptr_to_copy_to = self.id_gen.next();
                body.push(Instruction::access_chain(
                    self.get_f32_pointer_type_id(spirv::StorageClass::Output),
                    new_ptr_to_copy_to,
                    ptr_to_copy_to,
                    &[self.get_constant_scalar(crate::Literal::U32(1))],
                ));
                body.push(Instruction::store(new_ptr_to_copy_to, new_y, None));
            }
        }
        body
    }
631
    /// Appends to `block` the code that publishes a mesh shader's workgroup
    /// outputs through the SPIR-V mesh output interface.
    ///
    /// Steps:
    /// 1. Load `VertexCount` / `PrimitiveCount` from the output struct and
    ///    clamp each with `UMin` against the matching max-length constant.
    /// 2. Take pointers to the `vertices` / `primitives` arrays.
    /// 3. Emit `OpSetMeshOutputsEXT` with the clamped counts.
    /// 4. Run two copy loops, vertices then primitives. Each starts at
    ///    `local_invocation_index_id` and uses the Function-storage counters
    ///    `loop_counter_vertices` / `loop_counter_primitives`.
    ///
    /// The emitted code spans several SPIR-V blocks. `block` is left
    /// continuing after the final `func_end` label, so the caller must supply
    /// its terminator.
    pub(super) fn write_mesh_shader_return(
        &mut self,
        return_info: &MeshReturnInfo,
        block: &mut Block,
        loop_counter_vertices: u32,
        loop_counter_primitives: u32,
        local_invocation_index_id: Word,
    ) -> Result<(), Error> {
        let u32_id = self.get_u32_type_id();

        // Loads the `u32` output-struct member bound to `bi` and returns
        // `min(value, max)`.
        let mut load_u32_by_member_index =
            |members: &[MeshReturnMember], bi: crate::BuiltIn, max: u32| {
                let member_index = members
                    .iter()
                    .position(|a| a.binding == crate::Binding::BuiltIn(bi))
                    .unwrap() as u32;
                let ptr_id = self.id_gen.next();
                block.body.push(Instruction::access_chain(
                    self.get_pointer_type_id(u32_id, spirv::StorageClass::Workgroup),
                    ptr_id,
                    return_info.out_variable_id,
                    &[self.get_constant_scalar(crate::Literal::U32(member_index))],
                ));
                let before_min_id = self.id_gen.next();
                block
                    .body
                    .push(Instruction::load(u32_id, before_min_id, ptr_id, None));

                // Clamp to the declared maximum.
                let id = self.id_gen.next();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.gl450_ext_inst_id,
                    spirv::GlslStd450Op::UMin,
                    u32_id,
                    id,
                    &[before_min_id, max],
                ));
                id
            };
        let vert_count_id = load_u32_by_member_index(
            &return_info.out_members,
            crate::BuiltIn::VertexCount,
            return_info.vertex_info.max_length_constant,
        );
        let prim_count_id = load_u32_by_member_index(
            &return_info.out_members,
            crate::BuiltIn::PrimitiveCount,
            return_info.primitive_info.max_length_constant,
        );

        // Returns a Workgroup pointer to the output-struct array member bound to `bi`.
        let mut get_array_ptr = |bi: crate::BuiltIn, array_type_id: u32| {
            let id = self.id_gen.next();
            block.body.push(Instruction::access_chain(
                self.get_pointer_type_id(array_type_id, spirv::StorageClass::Workgroup),
                id,
                return_info.out_variable_id,
                &[self.get_constant_scalar(crate::Literal::U32(
                    return_info
                        .out_members
                        .iter()
                        .position(|a| a.binding == crate::Binding::BuiltIn(bi))
                        .unwrap() as u32,
                ))],
            ));
            id
        };
        let vert_array_ptr = get_array_ptr(
            crate::BuiltIn::Vertices,
            return_info.vertex_info.array_type_id,
        );
        let prim_array_ptr = get_array_ptr(
            crate::BuiltIn::Primitives,
            return_info.primitive_info.array_type_id,
        );

        {
            // Declare how many vertices and primitives this workgroup emits.
            let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT);
            ins.add_operand(vert_count_id);
            ins.add_operand(prim_count_id);
            block.body.push(ins);
        }

        // Labels for the two loops, the block between them, and the tail.
        let vertex_loop_header = self.id_gen.next();
        let prim_loop_header = self.id_gen.next();
        let in_between_loops = self.id_gen.next();
        let func_end = self.id_gen.next();

        // Vertex loop: starts at this invocation's local index.
        block.body.push(Instruction::store(
            loop_counter_vertices,
            local_invocation_index_id,
            None,
        ));
        block.body.push(Instruction::branch(vertex_loop_header));

        let vertex_copy_body = self.write_mesh_copy_body(
            false,
            return_info,
            loop_counter_vertices,
            vert_array_ptr,
            prim_array_ptr,
        );
        self.write_mesh_copy_loop(
            &mut block.body,
            vertex_copy_body,
            vertex_loop_header,
            in_between_loops,
            vert_count_id,
            loop_counter_vertices,
            return_info,
        );

        {
            // Merge block of the vertex loop: set up the primitive loop.
            block.body.push(Instruction::label(in_between_loops));

            block.body.push(Instruction::store(
                loop_counter_primitives,
                local_invocation_index_id,
                None,
            ));

            block.body.push(Instruction::branch(prim_loop_header));
        }
        let primitive_copy_body = self.write_mesh_copy_body(
            true,
            return_info,
            loop_counter_primitives,
            vert_array_ptr,
            prim_array_ptr,
        );
        self.write_mesh_copy_loop(
            &mut block.body,
            primitive_copy_body,
            prim_loop_header,
            func_end,
            prim_count_id,
            loop_counter_primitives,
            return_info,
        );

        // Merge block of the primitive loop; the caller terminates it.
        block.body.push(Instruction::label(func_end));
        Ok(())
    }
784
785 pub(super) fn write_mesh_shader_wrapper(
786 &mut self,
787 return_info: &MeshReturnInfo,
788 inner_id: u32,
789 ) -> Result<u32, Error> {
790 let out_id = self.id_gen.next();
791 let mut function = super::Function::default();
792 let lookup_function_type = super::LookupFunctionType {
793 parameter_type_ids: alloc::vec![],
794 return_type_id: self.void_type,
795 };
796 let function_type = self.get_function_type(lookup_function_type);
797 function.signature = Some(Instruction::function(
798 self.void_type,
799 out_id,
800 spirv::FunctionControl::empty(),
801 function_type,
802 ));
803 let u32_id = self.get_u32_type_id();
804 {
805 let mut block = Block::new(self.id_gen.next());
806 let loop_counter_vertices = self.id_gen.next();
810 let loop_counter_primitives = self.id_gen.next();
811 block.body.insert(
812 0,
813 Instruction::variable(
814 self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
815 loop_counter_vertices,
816 spirv::StorageClass::Function,
817 None,
818 ),
819 );
820 block.body.insert(
821 1,
822 Instruction::variable(
823 self.get_pointer_type_id(u32_id, spirv::StorageClass::Function),
824 loop_counter_primitives,
825 spirv::StorageClass::Function,
826 None,
827 ),
828 );
829 let local_invocation_index_id = self.id_gen.next();
830 block.body.push(Instruction::load(
831 u32_id,
832 local_invocation_index_id,
833 return_info.local_invocation_index_var_id,
834 None,
835 ));
836 block.body.push(Instruction::function_call(
837 self.void_type,
838 self.id_gen.next(),
839 inner_id,
840 &[],
841 ));
842 self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
843 self.write_mesh_shader_return(
844 return_info,
845 &mut block,
846 loop_counter_vertices,
847 loop_counter_primitives,
848 local_invocation_index_id,
849 )?;
850 function.consume(block, Instruction::return_void());
851 }
852 function.to_words(&mut self.logical_layout.function_definitions);
853 Ok(out_id)
854 }
855
    /// Writes the void, parameterless function that serves as the task shader
    /// entry point, and returns its id.
    ///
    /// The wrapper calls the user's function `inner_id` (which returns a
    /// `vec3<u32>` of mesh workgroup counts) and issues a workgroup control
    /// barrier. It then terminates with `OpEmitMeshTasksEXT` using
    /// `task_payload`.
    ///
    /// When `self.task_dispatch_limits` is set, the counts are replaced with
    /// `(0, 0, 0)` if any of these holds:
    /// - a dimension exceeds `max_mesh_workgroups_per_dim`;
    /// - the product x*y*z exceeds `max_mesh_workgroups_total`;
    /// - either partial product overflows 32 bits. This is detected as a
    ///   non-zero high word from `OpUMulExtended`.
    pub(super) fn write_task_shader_wrapper(
        &mut self,
        task_payload: Word,
        inner_id: u32,
    ) -> Result<u32, Error> {
        let out_id = self.id_gen.next();
        let mut function = super::Function::default();
        let lookup_function_type = super::LookupFunctionType {
            parameter_type_ids: alloc::vec![],
            return_type_id: self.void_type,
        };
        let function_type = self.get_function_type(lookup_function_type);
        function.signature = Some(Instruction::function(
            self.void_type,
            out_id,
            spirv::FunctionControl::empty(),
            function_type,
        ));

        {
            let mut block = Block::new(self.id_gen.next());
            // `result` is the vec3<u32> of workgroup counts returned by the user function.
            let result = self.id_gen.next();
            block.body.push(Instruction::function_call(
                self.get_vec3u_type_id(),
                result,
                inner_id,
                &[],
            ));
            self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
            let final_value = if let Some(task_limits) = self.task_dispatch_limits {
                let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0));
                let max_per_dim = self.get_constant_scalar(crate::Literal::U32(
                    task_limits.max_mesh_workgroups_per_dim,
                ));
                let max_total = self.get_constant_scalar(crate::Literal::U32(
                    task_limits.max_mesh_workgroups_total,
                ));
                // Result type for OpUMulExtended; presumably a struct of two u32
                // members (low word, high word) — confirm in get_tuple_of_u32s_ty_id.
                let combined_struct_type = self.get_tuple_of_u32s_ty_id();
                // x, y, z components of the requested counts.
                let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
                for (i, value) in values.into_iter().enumerate() {
                    block.body.push(Instruction::composite_extract(
                        self.get_u32_type_id(),
                        value,
                        result,
                        &[i as u32],
                    ));
                }
                // prod_1 = x * y (low word); overflows[0] = its high word.
                let prod_1 = self.id_gen.next();
                let overflows = [self.id_gen.next(), self.id_gen.next()];
                {
                    let struct_out = self.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::UMulExtended,
                        combined_struct_type,
                        struct_out,
                        values[0],
                        values[1],
                    ));
                    block.body.push(Instruction::composite_extract(
                        self.get_u32_type_id(),
                        prod_1,
                        struct_out,
                        &[0],
                    ));
                    block.body.push(Instruction::composite_extract(
                        self.get_u32_type_id(),
                        overflows[0],
                        struct_out,
                        &[1],
                    ));
                }
                // prod_final = prod_1 * z (low word); overflows[1] = its high word.
                let prod_final = self.id_gen.next();
                {
                    let struct_out = self.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::UMulExtended,
                        combined_struct_type,
                        struct_out,
                        prod_1,
                        values[2],
                    ));
                    block.body.push(Instruction::composite_extract(
                        self.get_u32_type_id(),
                        prod_final,
                        struct_out,
                        &[0],
                    ));
                    block.body.push(Instruction::composite_extract(
                        self.get_u32_type_id(),
                        overflows[1],
                        struct_out,
                        &[1],
                    ));
                }
                let total_too_large = self.id_gen.next();
                block.body.push(Instruction::binary(
                    spirv::Op::UGreaterThan,
                    self.get_bool_type_id(),
                    total_too_large,
                    prod_final,
                    max_total,
                ));

                // Per-dimension limit checks.
                let too_large = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()];
                for (i, value) in values.into_iter().enumerate() {
                    block.body.push(Instruction::binary(
                        spirv::Op::UGreaterThan,
                        self.get_bool_type_id(),
                        too_large[i],
                        value,
                        max_per_dim,
                    ));
                }
                // A non-zero high word means the 32-bit product overflowed.
                let overflow_happens = [self.id_gen.next(), self.id_gen.next()];
                for (i, value) in overflows.into_iter().enumerate() {
                    block.body.push(Instruction::binary(
                        spirv::Op::INotEqual,
                        self.get_bool_type_id(),
                        overflow_happens[i],
                        value,
                        zero_u32,
                    ));
                }
                // OR all violation flags together.
                let mut current_violates_limits = total_too_large;
                for is_too_large in too_large {
                    let new = self.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::LogicalOr,
                        self.get_bool_type_id(),
                        new,
                        current_violates_limits,
                        is_too_large,
                    ));
                    current_violates_limits = new;
                }
                for overflow_happens in overflow_happens {
                    let new = self.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::LogicalOr,
                        self.get_bool_type_id(),
                        new,
                        current_violates_limits,
                        overflow_happens,
                    ));
                    current_violates_limits = new;
                }
                // On any violation, dispatch no mesh workgroups at all.
                // OpSelect with a scalar condition and vector operands needs
                // SPIR-V 1.4, which SPV_EXT_mesh_shader also requires.
                let zero_vec3 = self.id_gen.next();
                block.body.push(Instruction::composite_construct(
                    self.get_vec3u_type_id(),
                    zero_vec3,
                    &[zero_u32, zero_u32, zero_u32],
                ));
                let final_result = self.id_gen.next();
                block.body.push(Instruction::select(
                    self.get_vec3u_type_id(),
                    final_result,
                    current_violates_limits,
                    zero_vec3,
                    result,
                ));
                final_result
            } else {
                result
            };
            // OpEmitMeshTasksEXT terminates the block.
            let ins =
                self.write_entry_point_task_return(final_value, &mut block.body, task_payload)?;
            function.consume(block, ins);
        }
        function.to_words(&mut self.logical_layout.function_definitions);
        Ok(out_id)
    }
1027}