// naga/back/msl/mesh_shader.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6
7use crate::{
8    back::{
9        self,
10        msl::{
11            writer::{TypeContext, TypedGlobalVariable},
12            BackendResult, EntryPointArgument, Error, NAMESPACE, WRAPPED_ARRAY_FIELD,
13        },
14    },
15    proc::NameKey,
16};
17
18pub(super) struct MeshOutputInfo {
19    out_vertex_ty_name: String,
20    out_primitive_ty_name: String,
21    out_vertex_member_names: Vec<Option<String>>,
22    out_primitive_member_names: Vec<Option<String>>,
23}
24
25pub(super) struct NestedFunctionInfo<'a> {
26    pub(super) options: &'a super::Options,
27    pub(super) ep: &'a crate::EntryPoint,
28    pub(super) module: &'a crate::Module,
29    pub(super) mod_info: &'a crate::valid::ModuleInfo,
30    pub(super) fun_info: &'a crate::valid::FunctionInfo,
31    pub(super) args: Vec<EntryPointArgument>,
32    pub(super) local_invocation_index: Option<&'a NameKey>,
33    pub(super) nested_name: &'a str,
34    pub(super) outer_name: &'a str,
35    pub(super) out_mesh_info: Option<MeshOutputInfo>,
36}
37
38impl<W: core::fmt::Write> super::Writer<W> {
39    /// This writes the output vertex and primitive structs given the reflection information about them.
40    pub(super) fn write_mesh_output_types(
41        &mut self,
42        mesh_info: &crate::MeshStageInfo,
43        fun_name: &str,
44        module: &crate::Module,
45        // See `PipelineOptions::allow_and_force_point_size`
46        allow_and_force_point_size: bool,
47        options: &super::Options,
48    ) -> Result<MeshOutputInfo, Error> {
49        let mut vertex_member_names = Vec::new();
50        let mut primitive_member_names = Vec::new();
51        let vertex_out_name = self.namer.call(&format!("{fun_name}VertexOutput"));
52        let primitive_out_name = self.namer.call(&format!("{fun_name}PrimitiveOutput"));
53        let mut existing_names = Vec::new();
54        for (out_name, struct_ty, is_primitive, member_names) in [
55            (
56                &vertex_out_name,
57                mesh_info.vertex_output_type,
58                false,
59                &mut vertex_member_names,
60            ),
61            (
62                &primitive_out_name,
63                mesh_info.primitive_output_type,
64                true,
65                &mut primitive_member_names,
66            ),
67        ] {
68            writeln!(self.out, "struct {out_name} {{")?;
69            // Mesh output types are guaranteed to be user defined structs. This is validated by naga.
70            let crate::TypeInner::Struct { ref members, .. } = module.types[struct_ty].inner else {
71                unreachable!()
72            };
73            let mut has_point_size = false;
74            for (index, member) in members.iter().enumerate() {
75                member_names.push(None);
76                let ty_name = TypeContext {
77                    handle: member.ty,
78                    gctx: module.to_ctx(),
79                    names: &self.names,
80                    access: crate::StorageAccess::empty(),
81                    first_time: true,
82                };
83                let binding = member
84                    .binding
85                    .clone()
86                    .ok_or_else(|| Error::GenericValidation("Expected binding, got None".into()))?;
87
88                if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = binding {
89                    has_point_size = true;
90                    if !allow_and_force_point_size {
91                        continue;
92                    }
93                }
94                if let crate::Binding::BuiltIn(
95                    crate::BuiltIn::PointIndex
96                    | crate::BuiltIn::LineIndices
97                    | crate::BuiltIn::TriangleIndices,
98                ) = binding
99                {
100                    continue;
101                }
102
103                // Names of struct members must be unique across vertex and primitive output.
104                // Therefore, when writing the primitive output struct, we might need to rename some fields.
105                let mut name = self.names[&NameKey::StructMember(struct_ty, index as u32)].clone();
106                if existing_names.contains(&name) {
107                    name = self.namer.call(&name);
108                } else {
109                    // Let the namer know this is illegal to use again
110                    let _ = self.namer.call(&name);
111                }
112
113                let array_len = match module.types[member.ty].inner {
114                    crate::TypeInner::Array {
115                        size: crate::ArraySize::Constant(size),
116                        ..
117                    } => Some(size),
118                    _ => None,
119                };
120                let resolved =
121                    options.resolve_local_binding(&binding, back::msl::LocationMode::MeshOutput)?;
122                write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
123                if let Some(array_len) = array_len {
124                    write!(self.out, " [{array_len}]")?;
125                }
126                resolved.try_fmt(&mut self.out)?;
127                writeln!(self.out, ";")?;
128                *member_names.last_mut().unwrap() = Some(name.clone());
129                existing_names.push(name);
130            }
131            if allow_and_force_point_size && !has_point_size && !is_primitive {
132                // inject the point size output last
133                writeln!(
134                    self.out,
135                    "{}float _point_size [[point_size]];",
136                    back::INDENT
137                )?;
138            }
139            writeln!(self.out, "}};")?;
140        }
141        Ok(MeshOutputInfo {
142            out_vertex_ty_name: vertex_out_name,
143            out_primitive_ty_name: primitive_out_name,
144            out_vertex_member_names: vertex_member_names,
145            out_primitive_member_names: primitive_member_names,
146        })
147    }
148
149    pub(super) fn write_wrapper_function(&mut self, info: NestedFunctionInfo<'_>) -> BackendResult {
150        let NestedFunctionInfo {
151            options,
152            ep,
153            module,
154            mod_info,
155            fun_info,
156            args,
157            local_invocation_index: local_invocation_index_key,
158            nested_name,
159            outer_name,
160            out_mesh_info,
161        } = info;
162        let indent = back::INDENT;
163
164        let em_str = match ep.stage {
165            crate::ShaderStage::Mesh => "[[mesh]]",
166            crate::ShaderStage::Task => "[[object]]",
167            _ => unreachable!(),
168        };
169        writeln!(self.out, "{em_str} void {outer_name}(")?;
170
171        // Arguments
172
173        let mut mesh_out_name: Option<String> = None;
174        let mut mesh_variable_name = None;
175        let mut task_grid_name = None;
176        if let Some(ref info) = ep.mesh_info {
177            let mesh_out = out_mesh_info.as_ref().unwrap();
178            let mesh_name = self.namer.call("meshOutput");
179            let topology_name = match info.topology {
180                crate::MeshOutputTopology::Points => "point",
181                crate::MeshOutputTopology::Lines => "line",
182                crate::MeshOutputTopology::Triangles => "triangle",
183            };
184            let num_verts = info.max_vertices;
185            let num_prims = info.max_primitives;
186            writeln!(self.out,
187                "  {NAMESPACE}::mesh<{}, {}, {num_verts}, {num_prims}, metal::topology::{topology_name}> {mesh_name}",
188                mesh_out.out_vertex_ty_name,
189                mesh_out.out_primitive_ty_name,
190            )?;
191            mesh_out_name = Some(mesh_name);
192            mesh_variable_name = Some(
193                self.names
194                    [&NameKey::GlobalVariable(ep.mesh_info.as_ref().unwrap().output_variable)]
195                    .clone(),
196            );
197        } else if ep.stage == crate::ShaderStage::Task {
198            let grid_name = self.namer.call("nagaMeshGrid");
199            writeln!(self.out, "  {NAMESPACE}::mesh_grid_properties {grid_name}")?;
200            task_grid_name = Some(grid_name);
201        }
202        let local_invocation_index = if let Some(key) = local_invocation_index_key {
203            self.names[key].clone()
204        } else {
205            "__local_invocation_index".to_string()
206        };
207
208        for arg in &args {
209            write!(self.out, ", {} {}{}", arg.ty_name, arg.name, arg.binding)?;
210            if let Some(init) = arg.init {
211                write!(self.out, " = ")?;
212                self.put_const_expression(init, module, mod_info, &module.global_expressions)?;
213            }
214            writeln!(self.out)?;
215        }
216
217        writeln!(self.out, ") {{")?;
218
219        // Function body
220        if ep.stage == crate::ShaderStage::Mesh {
221            for (handle, var) in module.global_variables.iter() {
222                if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
223                    continue;
224                }
225                let tyvar = TypedGlobalVariable {
226                    module,
227                    names: &self.names,
228                    handle,
229                    usage: crate::valid::GlobalUse::WRITE | crate::valid::GlobalUse::READ,
230                    reference: false,
231                };
232                write!(self.out, "{}", back::INDENT)?;
233                tyvar.try_fmt(&mut self.out)?;
234                writeln!(self.out, ";")?;
235            }
236        }
237        write!(self.out, "{indent}")?;
238        let result_name = if ep.stage == crate::ShaderStage::Task {
239            let name = self.namer.call("nagaGridSize");
240            write!(self.out, "uint3 {} = ", name)?;
241            Some(name)
242        } else {
243            None
244        };
245        write!(self.out, "{nested_name}(")?;
246        {
247            let mut is_first = true;
248            for arg in &args {
249                if !is_first {
250                    write!(self.out, ", ")?;
251                }
252                is_first = false;
253                write!(self.out, "{}", arg.name)?;
254            }
255            if ep.stage == crate::ShaderStage::Mesh {
256                for (handle, var) in module.global_variables.iter() {
257                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
258                        continue;
259                    }
260                    if !is_first {
261                        write!(self.out, ", ")?;
262                    }
263                    let name = &self.names[&NameKey::GlobalVariable(handle)];
264                    write!(self.out, "{name}")?;
265                }
266            }
267        }
268        writeln!(self.out, ");")?;
269        self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
270
271        if let Some(grid_name) = task_grid_name {
272            let result_name = result_name.unwrap();
273            writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?;
274            {
275                let level2 = back::Level(2);
276                if let Some(limits) = options.task_dispatch_limits {
277                    let level3 = back::Level(3);
278                    let max_per_dim = limits.max_mesh_workgroups_per_dim;
279                    let max_total = limits.max_mesh_workgroups_total;
280                    writeln!(self.out, "{level2}if (")?;
281
282                    writeln!(self.out, "{level3}{result_name}.x > {max_per_dim}u ||")?;
283                    writeln!(self.out, "{level3}{result_name}.y > {max_per_dim}u ||")?;
284                    writeln!(self.out, "{level3}{result_name}.z > {max_per_dim}u ||")?;
285                    writeln!(
286                        self.out,
287                        "{level3}{NAMESPACE}::mulhi({result_name}.x, {result_name}.y) != 0u ||"
288                    )?;
289                    writeln!(
290                        self.out,
291                        "{level3}{NAMESPACE}::mulhi({result_name}.x * {result_name}.y, {result_name}.z) != 0u ||"
292                    )?;
293                    writeln!(self.out, "{level3}({result_name}.x * {result_name}.y * {result_name}.z) > {max_total}u")?;
294
295                    writeln!(self.out, "{level2}) {{")?;
296                    writeln!(self.out, "{level3}{result_name} = {NAMESPACE}::uint3(0u);")?;
297                    writeln!(self.out, "{level2}}}")?;
298                }
299                writeln!(
300                    self.out,
301                    "{level2}{grid_name}.set_threadgroups_per_grid({result_name});"
302                )?;
303            }
304            writeln!(self.out, "{indent}}}")?;
305            writeln!(self.out, "{indent}return;")?;
306        } else if let Some(ref info) = ep.mesh_info {
307            let mesh_out = out_mesh_info.as_ref().unwrap();
308            let out_ty = module.global_variables[info.output_variable].ty;
309            let mesh_out_name = mesh_out_name.unwrap();
310            let mesh_variable_name = mesh_variable_name.unwrap();
311            // The output type is guaranteed to be a struct with exactly 4 members
312            let crate::TypeInner::Struct { ref members, .. } = module.types[out_ty].inner else {
313                unreachable!();
314            };
315            let get_out_value = |bi| {
316                let member_idx = members
317                    .iter()
318                    .position(|a| a.binding == Some(crate::Binding::BuiltIn(bi)))
319                    .unwrap() as u32;
320                format!(
321                    "{}.{}",
322                    mesh_variable_name,
323                    self.names[&NameKey::StructMember(out_ty, member_idx)]
324                )
325            };
326            let vert_count = format!(
327                "{NAMESPACE}::min({}, {}u)",
328                get_out_value(crate::BuiltIn::VertexCount),
329                info.max_vertices
330            );
331            let prim_count = format!(
332                "{NAMESPACE}::min({}, {}u)",
333                get_out_value(crate::BuiltIn::PrimitiveCount),
334                info.max_primitives
335            );
336            let workgroup_size: u32 = ep.workgroup_size.iter().product();
337            {
338                let vert_index = self.namer.call("vertexIndex");
339                let in_array = get_out_value(crate::BuiltIn::Vertices);
340                writeln!(
341                    self.out,
342                    "{indent}for(uint {vert_index} = {local_invocation_index}; {vert_index} < {vert_count}; {vert_index} += {workgroup_size}) {{"
343                )?;
344                let out_vert = self.namer.call("vertex");
345                writeln!(
346                    self.out,
347                    "{indent}{indent}{} {out_vert};",
348                    mesh_out.out_vertex_ty_name,
349                )?;
350                for (member_idx, new_name) in mesh_out.out_vertex_member_names.iter().enumerate() {
351                    let in_value = format!(
352                        "{in_array}.{WRAPPED_ARRAY_FIELD}[{vert_index}].{}",
353                        self.names
354                            [&NameKey::StructMember(info.vertex_output_type, member_idx as u32)]
355                    );
356                    let out_value = format!("{out_vert}.{}", new_name.as_ref().unwrap());
357                    writeln!(self.out, "{indent}{indent}{out_value} = {in_value};")?;
358                }
359                writeln!(
360                    self.out,
361                    "{indent}{indent}{}.set_vertex({vert_index}, {out_vert});",
362                    mesh_out_name
363                )?;
364                writeln!(self.out, "{indent}}}")?;
365            }
366            {
367                let prim_index = self.namer.call("primitiveIndex");
368                let in_array = get_out_value(crate::BuiltIn::Primitives);
369                writeln!(
370                    self.out,
371                    "{indent}for(uint {prim_index} = {local_invocation_index}; {prim_index} < {prim_count}; {prim_index} += {workgroup_size}) {{"
372                )?;
373                let out_prim = self.namer.call("primitive");
374                writeln!(
375                    self.out,
376                    "{indent}{indent}{} {out_prim};",
377                    mesh_out.out_primitive_ty_name
378                )?;
379                for (member_idx, new_name) in mesh_out.out_primitive_member_names.iter().enumerate()
380                {
381                    let in_value = format!(
382                        "{in_array}.{WRAPPED_ARRAY_FIELD}[{prim_index}].{}",
383                        self.names
384                            [&NameKey::StructMember(info.primitive_output_type, member_idx as u32)]
385                    );
386                    if let Some(new_name) = new_name.as_ref() {
387                        let out_value = format!("{out_prim}.{new_name}");
388                        writeln!(
389                            self.out,
390                            "{indent}{}{out_value} = {in_value};",
391                            back::INDENT
392                        )?;
393                    } else {
394                        let num_indices = match info.topology {
395                            crate::MeshOutputTopology::Points => 1,
396                            crate::MeshOutputTopology::Lines => 2,
397                            crate::MeshOutputTopology::Triangles => 3,
398                        };
399                        for i in 0..num_indices {
400                            let component = if num_indices == 1 {
401                                "".to_string()
402                            } else {
403                                format!(".{}", back::COMPONENTS[i])
404                            };
405                            writeln!(
406                                self.out,
407                                "{indent}{}{}.set_index({prim_index} * {num_indices} + {i}, {in_value}{component});",
408                                back::INDENT,
409                                mesh_out_name,
410                            )?;
411                        }
412                    }
413                }
414                writeln!(
415                    self.out,
416                    "{indent}{}{}.set_primitive({prim_index}, {out_prim});",
417                    back::INDENT,
418                    mesh_out_name
419                )?;
420                writeln!(self.out, "{indent}}}")?;
421            }
422
423            writeln!(self.out, "{indent}if ({local_invocation_index} == 0u) {{")?;
424            writeln!(
425                self.out,
426                "{indent}{indent}{}.set_primitive_count({prim_count});",
427                mesh_out_name,
428            )?;
429            writeln!(self.out, "{indent}}}")?;
430        } else {
431            // Must either have task output grid (task shader) or mesh output info (mesh shader)
432            unreachable!()
433        }
434
435        writeln!(self.out, "}}")?;
436        Ok(())
437    }
438}