1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6
7use crate::{
8 back::{
9 self,
10 msl::{
11 writer::{TypeContext, TypedGlobalVariable},
12 BackendResult, EntryPointArgument, Error, NAMESPACE, WRAPPED_ARRAY_FIELD,
13 },
14 },
15 proc::NameKey,
16};
17
18pub(super) struct MeshOutputInfo {
19 out_vertex_ty_name: String,
20 out_primitive_ty_name: String,
21 out_vertex_member_names: Vec<Option<String>>,
22 out_primitive_member_names: Vec<Option<String>>,
23}
24
25pub(super) struct NestedFunctionInfo<'a> {
26 pub(super) options: &'a super::Options,
27 pub(super) ep: &'a crate::EntryPoint,
28 pub(super) module: &'a crate::Module,
29 pub(super) mod_info: &'a crate::valid::ModuleInfo,
30 pub(super) fun_info: &'a crate::valid::FunctionInfo,
31 pub(super) args: Vec<EntryPointArgument>,
32 pub(super) local_invocation_index: Option<&'a NameKey>,
33 pub(super) nested_name: &'a str,
34 pub(super) outer_name: &'a str,
35 pub(super) out_mesh_info: Option<MeshOutputInfo>,
36}
37
38impl<W: core::fmt::Write> super::Writer<W> {
39 pub(super) fn write_mesh_output_types(
41 &mut self,
42 mesh_info: &crate::MeshStageInfo,
43 fun_name: &str,
44 module: &crate::Module,
45 allow_and_force_point_size: bool,
47 options: &super::Options,
48 ) -> Result<MeshOutputInfo, Error> {
49 let mut vertex_member_names = Vec::new();
50 let mut primitive_member_names = Vec::new();
51 let vertex_out_name = self.namer.call(&format!("{fun_name}VertexOutput"));
52 let primitive_out_name = self.namer.call(&format!("{fun_name}PrimitiveOutput"));
53 let mut existing_names = Vec::new();
54 for (out_name, struct_ty, is_primitive, member_names) in [
55 (
56 &vertex_out_name,
57 mesh_info.vertex_output_type,
58 false,
59 &mut vertex_member_names,
60 ),
61 (
62 &primitive_out_name,
63 mesh_info.primitive_output_type,
64 true,
65 &mut primitive_member_names,
66 ),
67 ] {
68 writeln!(self.out, "struct {out_name} {{")?;
69 let crate::TypeInner::Struct { ref members, .. } = module.types[struct_ty].inner else {
71 unreachable!()
72 };
73 let mut has_point_size = false;
74 for (index, member) in members.iter().enumerate() {
75 member_names.push(None);
76 let ty_name = TypeContext {
77 handle: member.ty,
78 gctx: module.to_ctx(),
79 names: &self.names,
80 access: crate::StorageAccess::empty(),
81 first_time: true,
82 };
83 let binding = member
84 .binding
85 .clone()
86 .ok_or_else(|| Error::GenericValidation("Expected binding, got None".into()))?;
87
88 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = binding {
89 has_point_size = true;
90 if !allow_and_force_point_size {
91 continue;
92 }
93 }
94 if let crate::Binding::BuiltIn(
95 crate::BuiltIn::PointIndex
96 | crate::BuiltIn::LineIndices
97 | crate::BuiltIn::TriangleIndices,
98 ) = binding
99 {
100 continue;
101 }
102
103 let mut name = self.names[&NameKey::StructMember(struct_ty, index as u32)].clone();
106 if existing_names.contains(&name) {
107 name = self.namer.call(&name);
108 } else {
109 let _ = self.namer.call(&name);
111 }
112
113 let array_len = match module.types[member.ty].inner {
114 crate::TypeInner::Array {
115 size: crate::ArraySize::Constant(size),
116 ..
117 } => Some(size),
118 _ => None,
119 };
120 let resolved =
121 options.resolve_local_binding(&binding, back::msl::LocationMode::MeshOutput)?;
122 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
123 if let Some(array_len) = array_len {
124 write!(self.out, " [{array_len}]")?;
125 }
126 resolved.try_fmt(&mut self.out)?;
127 writeln!(self.out, ";")?;
128 *member_names.last_mut().unwrap() = Some(name.clone());
129 existing_names.push(name);
130 }
131 if allow_and_force_point_size && !has_point_size && !is_primitive {
132 writeln!(
134 self.out,
135 "{}float _point_size [[point_size]];",
136 back::INDENT
137 )?;
138 }
139 writeln!(self.out, "}};")?;
140 }
141 Ok(MeshOutputInfo {
142 out_vertex_ty_name: vertex_out_name,
143 out_primitive_ty_name: primitive_out_name,
144 out_vertex_member_names: vertex_member_names,
145 out_primitive_member_names: primitive_member_names,
146 })
147 }
148
149 pub(super) fn write_wrapper_function(&mut self, info: NestedFunctionInfo<'_>) -> BackendResult {
150 let NestedFunctionInfo {
151 options,
152 ep,
153 module,
154 mod_info,
155 fun_info,
156 args,
157 local_invocation_index: local_invocation_index_key,
158 nested_name,
159 outer_name,
160 out_mesh_info,
161 } = info;
162 let indent = back::INDENT;
163
164 let em_str = match ep.stage {
165 crate::ShaderStage::Mesh => "[[mesh]]",
166 crate::ShaderStage::Task => "[[object]]",
167 _ => unreachable!(),
168 };
169 writeln!(self.out, "{em_str} void {outer_name}(")?;
170
171 let mut mesh_out_name: Option<String> = None;
174 let mut mesh_variable_name = None;
175 let mut task_grid_name = None;
176 if let Some(ref info) = ep.mesh_info {
177 let mesh_out = out_mesh_info.as_ref().unwrap();
178 let mesh_name = self.namer.call("meshOutput");
179 let topology_name = match info.topology {
180 crate::MeshOutputTopology::Points => "point",
181 crate::MeshOutputTopology::Lines => "line",
182 crate::MeshOutputTopology::Triangles => "triangle",
183 };
184 let num_verts = info.max_vertices;
185 let num_prims = info.max_primitives;
186 writeln!(self.out,
187 " {NAMESPACE}::mesh<{}, {}, {num_verts}, {num_prims}, metal::topology::{topology_name}> {mesh_name}",
188 mesh_out.out_vertex_ty_name,
189 mesh_out.out_primitive_ty_name,
190 )?;
191 mesh_out_name = Some(mesh_name);
192 mesh_variable_name = Some(
193 self.names
194 [&NameKey::GlobalVariable(ep.mesh_info.as_ref().unwrap().output_variable)]
195 .clone(),
196 );
197 } else if ep.stage == crate::ShaderStage::Task {
198 let grid_name = self.namer.call("nagaMeshGrid");
199 writeln!(self.out, " {NAMESPACE}::mesh_grid_properties {grid_name}")?;
200 task_grid_name = Some(grid_name);
201 }
202 let local_invocation_index = if let Some(key) = local_invocation_index_key {
203 self.names[key].clone()
204 } else {
205 "__local_invocation_index".to_string()
206 };
207
208 for arg in &args {
209 write!(self.out, ", {} {}{}", arg.ty_name, arg.name, arg.binding)?;
210 if let Some(init) = arg.init {
211 write!(self.out, " = ")?;
212 self.put_const_expression(init, module, mod_info, &module.global_expressions)?;
213 }
214 writeln!(self.out)?;
215 }
216
217 writeln!(self.out, ") {{")?;
218
219 if ep.stage == crate::ShaderStage::Mesh {
221 for (handle, var) in module.global_variables.iter() {
222 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
223 continue;
224 }
225 let tyvar = TypedGlobalVariable {
226 module,
227 names: &self.names,
228 handle,
229 usage: crate::valid::GlobalUse::WRITE | crate::valid::GlobalUse::READ,
230 reference: false,
231 };
232 write!(self.out, "{}", back::INDENT)?;
233 tyvar.try_fmt(&mut self.out)?;
234 writeln!(self.out, ";")?;
235 }
236 }
237 write!(self.out, "{indent}")?;
238 let result_name = if ep.stage == crate::ShaderStage::Task {
239 let name = self.namer.call("nagaGridSize");
240 write!(self.out, "uint3 {} = ", name)?;
241 Some(name)
242 } else {
243 None
244 };
245 write!(self.out, "{nested_name}(")?;
246 {
247 let mut is_first = true;
248 for arg in &args {
249 if !is_first {
250 write!(self.out, ", ")?;
251 }
252 is_first = false;
253 write!(self.out, "{}", arg.name)?;
254 }
255 if ep.stage == crate::ShaderStage::Mesh {
256 for (handle, var) in module.global_variables.iter() {
257 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
258 continue;
259 }
260 if !is_first {
261 write!(self.out, ", ")?;
262 }
263 let name = &self.names[&NameKey::GlobalVariable(handle)];
264 write!(self.out, "{name}")?;
265 }
266 }
267 }
268 writeln!(self.out, ");")?;
269 self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
270
271 if let Some(grid_name) = task_grid_name {
272 let result_name = result_name.unwrap();
273 writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?;
274 {
275 let level2 = back::Level(2);
276 if let Some(limits) = options.task_dispatch_limits {
277 let level3 = back::Level(3);
278 let max_per_dim = limits.max_mesh_workgroups_per_dim;
279 let max_total = limits.max_mesh_workgroups_total;
280 writeln!(self.out, "{level2}if (")?;
281
282 writeln!(self.out, "{level3}{result_name}.x > {max_per_dim}u ||")?;
283 writeln!(self.out, "{level3}{result_name}.y > {max_per_dim}u ||")?;
284 writeln!(self.out, "{level3}{result_name}.z > {max_per_dim}u ||")?;
285 writeln!(
286 self.out,
287 "{level3}{NAMESPACE}::mulhi({result_name}.x, {result_name}.y) != 0u ||"
288 )?;
289 writeln!(
290 self.out,
291 "{level3}{NAMESPACE}::mulhi({result_name}.x * {result_name}.y, {result_name}.z) != 0u ||"
292 )?;
293 writeln!(self.out, "{level3}({result_name}.x * {result_name}.y * {result_name}.z) > {max_total}u")?;
294
295 writeln!(self.out, "{level2}) {{")?;
296 writeln!(self.out, "{level3}{result_name} = {NAMESPACE}::uint3(0u);")?;
297 writeln!(self.out, "{level2}}}")?;
298 }
299 writeln!(
300 self.out,
301 "{level2}{grid_name}.set_threadgroups_per_grid({result_name});"
302 )?;
303 }
304 writeln!(self.out, "{indent}}}")?;
305 writeln!(self.out, "{indent}return;")?;
306 } else if let Some(ref info) = ep.mesh_info {
307 let mesh_out = out_mesh_info.as_ref().unwrap();
308 let out_ty = module.global_variables[info.output_variable].ty;
309 let mesh_out_name = mesh_out_name.unwrap();
310 let mesh_variable_name = mesh_variable_name.unwrap();
311 let crate::TypeInner::Struct { ref members, .. } = module.types[out_ty].inner else {
313 unreachable!();
314 };
315 let get_out_value = |bi| {
316 let member_idx = members
317 .iter()
318 .position(|a| a.binding == Some(crate::Binding::BuiltIn(bi)))
319 .unwrap() as u32;
320 format!(
321 "{}.{}",
322 mesh_variable_name,
323 self.names[&NameKey::StructMember(out_ty, member_idx)]
324 )
325 };
326 let vert_count = format!(
327 "{NAMESPACE}::min({}, {}u)",
328 get_out_value(crate::BuiltIn::VertexCount),
329 info.max_vertices
330 );
331 let prim_count = format!(
332 "{NAMESPACE}::min({}, {}u)",
333 get_out_value(crate::BuiltIn::PrimitiveCount),
334 info.max_primitives
335 );
336 let workgroup_size: u32 = ep.workgroup_size.iter().product();
337 {
338 let vert_index = self.namer.call("vertexIndex");
339 let in_array = get_out_value(crate::BuiltIn::Vertices);
340 writeln!(
341 self.out,
342 "{indent}for(uint {vert_index} = {local_invocation_index}; {vert_index} < {vert_count}; {vert_index} += {workgroup_size}) {{"
343 )?;
344 let out_vert = self.namer.call("vertex");
345 writeln!(
346 self.out,
347 "{indent}{indent}{} {out_vert};",
348 mesh_out.out_vertex_ty_name,
349 )?;
350 for (member_idx, new_name) in mesh_out.out_vertex_member_names.iter().enumerate() {
351 let in_value = format!(
352 "{in_array}.{WRAPPED_ARRAY_FIELD}[{vert_index}].{}",
353 self.names
354 [&NameKey::StructMember(info.vertex_output_type, member_idx as u32)]
355 );
356 let out_value = format!("{out_vert}.{}", new_name.as_ref().unwrap());
357 writeln!(self.out, "{indent}{indent}{out_value} = {in_value};")?;
358 }
359 writeln!(
360 self.out,
361 "{indent}{indent}{}.set_vertex({vert_index}, {out_vert});",
362 mesh_out_name
363 )?;
364 writeln!(self.out, "{indent}}}")?;
365 }
366 {
367 let prim_index = self.namer.call("primitiveIndex");
368 let in_array = get_out_value(crate::BuiltIn::Primitives);
369 writeln!(
370 self.out,
371 "{indent}for(uint {prim_index} = {local_invocation_index}; {prim_index} < {prim_count}; {prim_index} += {workgroup_size}) {{"
372 )?;
373 let out_prim = self.namer.call("primitive");
374 writeln!(
375 self.out,
376 "{indent}{indent}{} {out_prim};",
377 mesh_out.out_primitive_ty_name
378 )?;
379 for (member_idx, new_name) in mesh_out.out_primitive_member_names.iter().enumerate()
380 {
381 let in_value = format!(
382 "{in_array}.{WRAPPED_ARRAY_FIELD}[{prim_index}].{}",
383 self.names
384 [&NameKey::StructMember(info.primitive_output_type, member_idx as u32)]
385 );
386 if let Some(new_name) = new_name.as_ref() {
387 let out_value = format!("{out_prim}.{new_name}");
388 writeln!(
389 self.out,
390 "{indent}{}{out_value} = {in_value};",
391 back::INDENT
392 )?;
393 } else {
394 let num_indices = match info.topology {
395 crate::MeshOutputTopology::Points => 1,
396 crate::MeshOutputTopology::Lines => 2,
397 crate::MeshOutputTopology::Triangles => 3,
398 };
399 for i in 0..num_indices {
400 let component = if num_indices == 1 {
401 "".to_string()
402 } else {
403 format!(".{}", back::COMPONENTS[i])
404 };
405 writeln!(
406 self.out,
407 "{indent}{}{}.set_index({prim_index} * {num_indices} + {i}, {in_value}{component});",
408 back::INDENT,
409 mesh_out_name,
410 )?;
411 }
412 }
413 }
414 writeln!(
415 self.out,
416 "{indent}{}{}.set_primitive({prim_index}, {out_prim});",
417 back::INDENT,
418 mesh_out_name
419 )?;
420 writeln!(self.out, "{indent}}}")?;
421 }
422
423 writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?;
424 writeln!(
425 self.out,
426 "{indent}{indent}{}.set_primitive_count({prim_count});",
427 mesh_out_name,
428 )?;
429 writeln!(self.out, "{indent}}}")?;
430 } else {
431 unreachable!()
433 }
434
435 writeln!(self.out, "}}")?;
436 Ok(())
437 }
438}