naga/back/hlsl/
mesh_shader.rs

1use core::fmt;
2
3use alloc::{
4    format,
5    string::{String, ToString},
6    vec::Vec,
7};
8
9use crate::{
10    back::{
11        self,
12        hlsl::{
13            writer::{EntryPointBinding, EpStructMember, Io, NestedEntryPointArgs},
14            BackendResult, Error,
15        },
16    },
17    proc::NameKey,
18    Handle, Module, ShaderStage, TypeInner,
19};
20
21impl NestedEntryPointArgs {
22    pub fn write_call_args(&self, out: &mut impl fmt::Write) -> fmt::Result {
23        let all_args = self
24            .user_args
25            .iter()
26            .map(String::as_str)
27            .chain(self.task_payload.as_deref())
28            .chain(core::iter::once(self.local_invocation_index.as_str()));
29        for (i, arg) in all_args.enumerate() {
30            if i != 0 {
31                write!(out, ", ")?;
32            }
33            write!(out, "{arg}")?;
34        }
35        Ok(())
36    }
37}
38
39impl<W: fmt::Write> super::Writer<'_, W> {
40    #[expect(clippy::too_many_arguments)]
41    fn write_mesh_shader_wrapper(
42        &mut self,
43        module: &Module,
44        func_ctx: &back::FunctionCtx,
45        need_workgroup_variables_initialization: bool,
46        nested_name: &str,
47        entry_point: &crate::EntryPoint,
48        args: NestedEntryPointArgs,
49        mut separator_if_needed: impl FnMut() -> &'static str,
50    ) -> BackendResult {
51        let Some(ref mesh_info) = entry_point.mesh_info else {
52            unreachable!()
53        };
54        let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
55            unreachable!()
56        };
57        // Mesh shader wrapper
58        let mesh_interface = self.entry_point_io.get(&(ep_index as usize)).unwrap();
59        let vert_info = mesh_interface.mesh_vertices.as_ref().unwrap();
60        let prim_info = mesh_interface.mesh_primitives.as_ref().unwrap();
61        let indices_info = mesh_interface.mesh_indices.as_ref().unwrap();
62        // Write something of the form `out indices uint3 indices_var[num_primitives]`
63        write!(
64            self.out,
65            "{}out indices {} {}[{}]",
66            separator_if_needed(),
67            indices_info.ty_name,
68            indices_info.arg_name,
69            mesh_info.max_primitives
70        )?;
71        // Write something of the form `out vertices VertexType vertices_var[num_vertices]`
72        write!(
73            self.out,
74            ", out vertices {} {}[{}]",
75            vert_info.ty_name, vert_info.arg_name, mesh_info.max_vertices
76        )?;
77        // Write something of the form `out primitives PrimitiveType} primitives_var[num_primitives]`
78        write!(
79            self.out,
80            ", out primitives {} {}[{}]",
81            prim_info.ty_name, prim_info.arg_name, mesh_info.max_primitives
82        )?;
83        if let Some(task_payload) = entry_point.task_payload {
84            // Write the outer-function `in payload` arg.  The name is already in
85            // args.task_payload, having been collected when the inner function
86            // signature was written in write_function (writer.rs).
87            write!(self.out, ", in payload ")?;
88            let var = &module.global_variables[task_payload];
89            self.write_type(module, var.ty)?;
90            let name = &self.names[&NameKey::GlobalVariable(task_payload)];
91            write!(self.out, " {name}")?;
92            if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
93                self.write_array_size(module, base, size)?;
94            }
95        }
96        writeln!(self.out, ") {{")?;
97        if need_workgroup_variables_initialization {
98            writeln!(
99                self.out,
100                "{}if ({} == 0) {{",
101                back::INDENT,
102                args.local_invocation_index,
103            )?;
104            self.write_workgroup_variables_initialization(
105                func_ctx,
106                module,
107                module.entry_points[ep_index as usize].stage,
108            )?;
109            writeln!(self.out, "{}}}", back::INDENT)?;
110            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
111        }
112        write!(self.out, "{}{nested_name}(", back::INDENT)?;
113        args.write_call_args(&mut self.out)?;
114        writeln!(self.out, ");")?;
115        writeln!(
116            self.out,
117            "{}GroupMemoryBarrierWithGroupSync();",
118            back::INDENT
119        )?;
120
121        let ep = &module.entry_points[ep_index as usize];
122        let mesh_info = ep.mesh_info.as_ref().unwrap();
123        let io = self.entry_point_io.get(&(ep_index as usize)).unwrap();
124
125        let var_name = &self.names[&NameKey::GlobalVariable(mesh_info.output_variable)];
126        let var_type = module.global_variables[mesh_info.output_variable].ty;
127        let wg_size: u32 = ep.workgroup_size.iter().product();
128
129        let get_var_member_name = |bi, var_type| {
130            // The mesh shader output type must be a struct with exactly 4 members.
131            let TypeInner::Struct { ref members, .. } = module.types[var_type].inner else {
132                unreachable!()
133            };
134            let idx = members
135                .iter()
136                .position(|f| f.binding == Some(crate::Binding::BuiltIn(bi)))
137                .unwrap();
138            self.names[&NameKey::StructMember(var_type, idx as u32)].clone()
139        };
140
141        let vert_count = format!(
142            "{var_name}.{}",
143            get_var_member_name(crate::BuiltIn::VertexCount, var_type),
144        );
145        let prim_count = format!(
146            "{var_name}.{}",
147            get_var_member_name(crate::BuiltIn::PrimitiveCount, var_type),
148        );
149
150        let level = back::Level(1);
151
152        writeln!(
153            self.out,
154            "{level}SetMeshOutputCounts({vert_count}, {prim_count});"
155        )?;
156
157        // We need separate loops for vertices and primitives writing
158        struct OutputArray<'a> {
159            array_bi: crate::BuiltIn,
160            count: String,
161            io_interface: &'a EntryPointBinding,
162            is_primitive: bool,
163            index_name: &'static str,
164            ty: Handle<crate::Type>,
165        }
166        let output_arrays = [
167            OutputArray {
168                array_bi: crate::BuiltIn::Vertices,
169                count: vert_count,
170                io_interface: io.mesh_vertices.as_ref().unwrap(),
171                is_primitive: false,
172                index_name: "vertIndex",
173                ty: mesh_info.vertex_output_type,
174            },
175            OutputArray {
176                array_bi: crate::BuiltIn::Primitives,
177                count: prim_count,
178                io_interface: io.mesh_primitives.as_ref().unwrap(),
179                is_primitive: true,
180                index_name: "primIndex",
181                ty: mesh_info.primitive_output_type,
182            },
183        ];
184
185        for output in output_arrays {
186            let OutputArray {
187                array_bi,
188                count,
189                io_interface,
190                is_primitive,
191                index_name,
192                ty,
193            } = output;
194            let out_var_name = &io_interface.arg_name;
195            let index_name = self.namer.call(index_name);
196            let array_name = get_var_member_name(array_bi, var_type);
197            let item_name = format!("{var_name}.{array_name}[{index_name}]");
198            writeln!(
199                self.out,
200                "{level}for (int {index_name} = {}; {index_name} < {count}; {index_name} += {}) {{",
201                args.local_invocation_index, wg_size
202            )?;
203
204            // Loop body, uses more indentation
205            {
206                let level = level.next();
207                for member in &io_interface.members {
208                    let out_member_name = &member.name;
209                    let in_member_name = &self.names[&NameKey::StructMember(ty, member.index)];
210                    writeln!(self.out, "{level}{out_var_name}[{index_name}].{out_member_name} = {item_name}.{in_member_name};",)?;
211                }
212                if is_primitive {
213                    let indices_member_name = get_var_member_name(
214                        mesh_info.topology.to_builtin(),
215                        mesh_info.primitive_output_type,
216                    );
217                    let indices_var_name = &io.mesh_indices.as_ref().unwrap().arg_name;
218                    writeln!(
219                                self.out,
220                                "{level}{indices_var_name}[{index_name}] = {item_name}.{indices_member_name};",
221                            )?;
222                }
223            }
224
225            writeln!(self.out, "{level}}}")?;
226        }
227        Ok(())
228    }
229
230    fn write_task_shader_wrapper(
231        &mut self,
232        module: &Module,
233        func_ctx: &back::FunctionCtx,
234        need_workgroup_variables_initialization: bool,
235        nested_name: &str,
236        entry_point: &crate::EntryPoint,
237        args: NestedEntryPointArgs,
238    ) -> BackendResult {
239        let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
240            unreachable!()
241        };
242        // Task shader wrapper
243        writeln!(self.out, ") {{")?;
244        if need_workgroup_variables_initialization {
245            writeln!(
246                self.out,
247                "{}if ({} == 0) {{",
248                back::INDENT,
249                args.local_invocation_index,
250            )?;
251            self.write_workgroup_variables_initialization(
252                func_ctx,
253                module,
254                module.entry_points[ep_index as usize].stage,
255            )?;
256            writeln!(self.out, "{}}}", back::INDENT)?;
257            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
258        }
259        let grid_size = self.namer.call("gridSize");
260        write!(
261            self.out,
262            "{}uint3 {grid_size} = {nested_name}(",
263            back::INDENT
264        )?;
265        args.write_call_args(&mut self.out)?;
266        writeln!(self.out, ");")?;
267        writeln!(
268            self.out,
269            "{}GroupMemoryBarrierWithGroupSync();",
270            back::INDENT
271        )?;
272        if let Some(limits) = self.options.task_dispatch_limits {
273            let level = back::Level(2);
274            writeln!(self.out, "{}if (", back::INDENT)?;
275
276            let max_per_dim = limits.max_mesh_workgroups_per_dim.min(2 << 21);
277            let max_total = limits.max_mesh_workgroups_total;
278            for i in 0..3 {
279                writeln!(
280                    self.out,
281                    "{level}{grid_size}.{} > {max_per_dim} ||",
282                    back::COMPONENTS[i],
283                )?;
284            }
285            writeln!(
286                self.out,
287                "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) > 0xffffffffull ||"
288            )?;
289            writeln!(
290                    self.out,
291                    "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) * ((uint64_t){grid_size}.z) > {max_total}",
292                )?;
293
294            writeln!(self.out, "{}) {{", back::INDENT)?;
295            writeln!(self.out, "{level}{grid_size} = uint3(0, 0, 0);")?;
296            writeln!(self.out, "{}}}", back::INDENT)?;
297        }
298        writeln!(
299            self.out,
300            "{}DispatchMesh({grid_size}.x, {grid_size}.y, {grid_size}.z, {});",
301            back::INDENT,
302            self.names[&NameKey::GlobalVariable(entry_point.task_payload.unwrap())]
303        )?;
304        Ok(())
305    }
306    /// Mesh and task entry points must all return at the same `return` statement,
307    /// so we have a nested function that can return wherever. This writes the caller,
308    /// or the actual entry point.
309    #[expect(clippy::too_many_arguments)]
310    pub(super) fn write_nested_function_outer(
311        &mut self,
312        module: &Module,
313        func_ctx: &back::FunctionCtx,
314        header: &str,
315        name: &str,
316        need_workgroup_variables_initialization: bool,
317        nested_name: &str,
318        entry_point: &crate::EntryPoint,
319        // Built in write_function alongside the inner function signature, so the
320        // call-site argument order is guaranteed to match the declaration order.
321        args: NestedEntryPointArgs,
322    ) -> BackendResult {
323        let mut any_args_written = false;
324        let mut separator_if_needed = || {
325            if any_args_written {
326                ", "
327            } else {
328                any_args_written = true;
329                ""
330            }
331        };
332
333        let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
334            unreachable!();
335        };
336        let stage = module.entry_points[ep_index as usize].stage;
337        write!(self.out, "{header}")?;
338        write!(self.out, "void {name}(")?;
339        // Write the outer function's argument list with full type annotations and
340        // semantics.  Arg names come from self.names and are the same names that
341        // were collected into `args` when writing the inner function signature.
342        if let Some(ref ep_input) = self.entry_point_io.get(&(ep_index as usize)).unwrap().input {
343            write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
344        } else {
345            for (index, arg) in entry_point.function.arguments.iter().enumerate() {
346                write!(self.out, "{}", separator_if_needed())?;
347                self.write_type(module, arg.ty)?;
348
349                let argument_name =
350                    &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
351
352                write!(self.out, " {argument_name}")?;
353                if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
354                    self.write_array_size(module, base, size)?;
355                }
356
357                self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
358            }
359        }
360        if need_workgroup_variables_initialization || stage == ShaderStage::Mesh {
361            write!(
362                self.out,
363                "{}uint {} : SV_GroupIndex",
364                separator_if_needed(),
365                args.local_invocation_index,
366            )?;
367        }
368        if entry_point.stage == ShaderStage::Mesh {
369            self.write_mesh_shader_wrapper(
370                module,
371                func_ctx,
372                need_workgroup_variables_initialization,
373                nested_name,
374                entry_point,
375                args,
376                separator_if_needed,
377            )?;
378        } else {
379            self.write_task_shader_wrapper(
380                module,
381                func_ctx,
382                need_workgroup_variables_initialization,
383                nested_name,
384                entry_point,
385                args,
386            )?;
387        }
388
389        writeln!(self.out, "}}")?;
390        Ok(())
391    }
392
393    pub(super) fn write_ep_mesh_output_struct(
394        &mut self,
395        module: &Module,
396        entry_point_name: &str,
397        is_primitive: bool,
398        mesh_info: &crate::MeshStageInfo,
399    ) -> Result<EntryPointBinding, Error> {
400        let (in_type, io, var_prefix, arg_name) = if is_primitive {
401            (
402                mesh_info.primitive_output_type,
403                Io::MeshPrimitives,
404                "Primitive",
405                "primitives",
406            )
407        } else {
408            (
409                mesh_info.vertex_output_type,
410                Io::MeshVertices,
411                "Vertex",
412                "vertices",
413            )
414        };
415        let struct_name = format!("Mesh{var_prefix}Output_{entry_point_name}",);
416
417        // Mesh shader output types must be structs; this is validated by naga
418        let members = match module.types[in_type].inner {
419            TypeInner::Struct { ref members, .. } => members,
420            _ => unreachable!(),
421        };
422        let mut out_members = Vec::new();
423        for (index, member) in members.iter().enumerate() {
424            if matches!(
425                member.binding,
426                Some(crate::Binding::BuiltIn(
427                    crate::BuiltIn::PointIndex
428                        | crate::BuiltIn::LineIndices
429                        | crate::BuiltIn::TriangleIndices
430                ))
431            ) {
432                continue;
433            }
434            let member_name = self.namer.call_or(&member.name, "member");
435            out_members.push(EpStructMember {
436                name: member_name,
437                ty: member.ty,
438                binding: member.binding.clone(),
439                index: index as u32,
440            })
441        }
442        self.write_interface_struct(
443            module,
444            (ShaderStage::Mesh, io),
445            struct_name,
446            Some(arg_name),
447            out_members,
448        )
449    }
450
451    pub(super) fn write_ep_mesh_output_indices(
452        &mut self,
453        topology: crate::MeshOutputTopology,
454    ) -> Result<EntryPointBinding, Error> {
455        let (indices_name, indices_type) = match topology {
456            // Points require a capability that isn't supported in the HLSL writer
457            crate::MeshOutputTopology::Points => unreachable!(),
458            crate::MeshOutputTopology::Lines => (self.namer.call("lineIndices"), "uint2"),
459            crate::MeshOutputTopology::Triangles => (self.namer.call("triangleIndices"), "uint3"),
460        };
461        Ok(EntryPointBinding {
462            ty_name: indices_type.to_string(),
463            arg_name: indices_name,
464            members: Vec::new(),
465            local_invocation_index_name: None,
466        })
467    }
468}