1use core::fmt;
2
3use alloc::{
4 format,
5 string::{String, ToString},
6 vec::Vec,
7};
8
9use crate::{
10 back::{
11 self,
12 hlsl::{
13 writer::{EntryPointBinding, EpStructMember, Io, NestedEntryPointArgs},
14 BackendResult, Error,
15 },
16 },
17 proc::NameKey,
18 Handle, Module, ShaderStage, TypeInner,
19};
20
21impl NestedEntryPointArgs {
22 pub fn write_call_args(&self, out: &mut impl fmt::Write) -> fmt::Result {
23 let all_args = self
24 .user_args
25 .iter()
26 .map(String::as_str)
27 .chain(self.task_payload.as_deref())
28 .chain(core::iter::once(self.local_invocation_index.as_str()));
29 for (i, arg) in all_args.enumerate() {
30 if i != 0 {
31 write!(out, ", ")?;
32 }
33 write!(out, "{arg}")?;
34 }
35 Ok(())
36 }
37}
38
39impl<W: fmt::Write> super::Writer<'_, W> {
40 #[expect(clippy::too_many_arguments)]
41 fn write_mesh_shader_wrapper(
42 &mut self,
43 module: &Module,
44 func_ctx: &back::FunctionCtx,
45 need_workgroup_variables_initialization: bool,
46 nested_name: &str,
47 entry_point: &crate::EntryPoint,
48 args: NestedEntryPointArgs,
49 mut separator_if_needed: impl FnMut() -> &'static str,
50 ) -> BackendResult {
51 let Some(ref mesh_info) = entry_point.mesh_info else {
52 unreachable!()
53 };
54 let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
55 unreachable!()
56 };
57 let mesh_interface = self.entry_point_io.get(&(ep_index as usize)).unwrap();
59 let vert_info = mesh_interface.mesh_vertices.as_ref().unwrap();
60 let prim_info = mesh_interface.mesh_primitives.as_ref().unwrap();
61 let indices_info = mesh_interface.mesh_indices.as_ref().unwrap();
62 write!(
64 self.out,
65 "{}out indices {} {}[{}]",
66 separator_if_needed(),
67 indices_info.ty_name,
68 indices_info.arg_name,
69 mesh_info.max_primitives
70 )?;
71 write!(
73 self.out,
74 ", out vertices {} {}[{}]",
75 vert_info.ty_name, vert_info.arg_name, mesh_info.max_vertices
76 )?;
77 write!(
79 self.out,
80 ", out primitives {} {}[{}]",
81 prim_info.ty_name, prim_info.arg_name, mesh_info.max_primitives
82 )?;
83 if let Some(task_payload) = entry_point.task_payload {
84 write!(self.out, ", in payload ")?;
88 let var = &module.global_variables[task_payload];
89 self.write_type(module, var.ty)?;
90 let name = &self.names[&NameKey::GlobalVariable(task_payload)];
91 write!(self.out, " {name}")?;
92 if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
93 self.write_array_size(module, base, size)?;
94 }
95 }
96 writeln!(self.out, ") {{")?;
97 if need_workgroup_variables_initialization {
98 writeln!(
99 self.out,
100 "{}if ({} == 0) {{",
101 back::INDENT,
102 args.local_invocation_index,
103 )?;
104 self.write_workgroup_variables_initialization(
105 func_ctx,
106 module,
107 module.entry_points[ep_index as usize].stage,
108 )?;
109 writeln!(self.out, "{}}}", back::INDENT)?;
110 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
111 }
112 write!(self.out, "{}{nested_name}(", back::INDENT)?;
113 args.write_call_args(&mut self.out)?;
114 writeln!(self.out, ");")?;
115 writeln!(
116 self.out,
117 "{}GroupMemoryBarrierWithGroupSync();",
118 back::INDENT
119 )?;
120
121 let ep = &module.entry_points[ep_index as usize];
122 let mesh_info = ep.mesh_info.as_ref().unwrap();
123 let io = self.entry_point_io.get(&(ep_index as usize)).unwrap();
124
125 let var_name = &self.names[&NameKey::GlobalVariable(mesh_info.output_variable)];
126 let var_type = module.global_variables[mesh_info.output_variable].ty;
127 let wg_size: u32 = ep.workgroup_size.iter().product();
128
129 let get_var_member_name = |bi, var_type| {
130 let TypeInner::Struct { ref members, .. } = module.types[var_type].inner else {
132 unreachable!()
133 };
134 let idx = members
135 .iter()
136 .position(|f| f.binding == Some(crate::Binding::BuiltIn(bi)))
137 .unwrap();
138 self.names[&NameKey::StructMember(var_type, idx as u32)].clone()
139 };
140
141 let vert_count = format!(
142 "{var_name}.{}",
143 get_var_member_name(crate::BuiltIn::VertexCount, var_type),
144 );
145 let prim_count = format!(
146 "{var_name}.{}",
147 get_var_member_name(crate::BuiltIn::PrimitiveCount, var_type),
148 );
149
150 let level = back::Level(1);
151
152 writeln!(
153 self.out,
154 "{level}SetMeshOutputCounts({vert_count}, {prim_count});"
155 )?;
156
157 struct OutputArray<'a> {
159 array_bi: crate::BuiltIn,
160 count: String,
161 io_interface: &'a EntryPointBinding,
162 is_primitive: bool,
163 index_name: &'static str,
164 ty: Handle<crate::Type>,
165 }
166 let output_arrays = [
167 OutputArray {
168 array_bi: crate::BuiltIn::Vertices,
169 count: vert_count,
170 io_interface: io.mesh_vertices.as_ref().unwrap(),
171 is_primitive: false,
172 index_name: "vertIndex",
173 ty: mesh_info.vertex_output_type,
174 },
175 OutputArray {
176 array_bi: crate::BuiltIn::Primitives,
177 count: prim_count,
178 io_interface: io.mesh_primitives.as_ref().unwrap(),
179 is_primitive: true,
180 index_name: "primIndex",
181 ty: mesh_info.primitive_output_type,
182 },
183 ];
184
185 for output in output_arrays {
186 let OutputArray {
187 array_bi,
188 count,
189 io_interface,
190 is_primitive,
191 index_name,
192 ty,
193 } = output;
194 let out_var_name = &io_interface.arg_name;
195 let index_name = self.namer.call(index_name);
196 let array_name = get_var_member_name(array_bi, var_type);
197 let item_name = format!("{var_name}.{array_name}[{index_name}]");
198 writeln!(
199 self.out,
200 "{level}for (int {index_name} = {}; {index_name} < {count}; {index_name} += {}) {{",
201 args.local_invocation_index, wg_size
202 )?;
203
204 {
206 let level = level.next();
207 for member in &io_interface.members {
208 let out_member_name = &member.name;
209 let in_member_name = &self.names[&NameKey::StructMember(ty, member.index)];
210 writeln!(self.out, "{level}{out_var_name}[{index_name}].{out_member_name} = {item_name}.{in_member_name};",)?;
211 }
212 if is_primitive {
213 let indices_member_name = get_var_member_name(
214 mesh_info.topology.to_builtin(),
215 mesh_info.primitive_output_type,
216 );
217 let indices_var_name = &io.mesh_indices.as_ref().unwrap().arg_name;
218 writeln!(
219 self.out,
220 "{level}{indices_var_name}[{index_name}] = {item_name}.{indices_member_name};",
221 )?;
222 }
223 }
224
225 writeln!(self.out, "{level}}}")?;
226 }
227 Ok(())
228 }
229
230 fn write_task_shader_wrapper(
231 &mut self,
232 module: &Module,
233 func_ctx: &back::FunctionCtx,
234 need_workgroup_variables_initialization: bool,
235 nested_name: &str,
236 entry_point: &crate::EntryPoint,
237 args: NestedEntryPointArgs,
238 ) -> BackendResult {
239 let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
240 unreachable!()
241 };
242 writeln!(self.out, ") {{")?;
244 if need_workgroup_variables_initialization {
245 writeln!(
246 self.out,
247 "{}if ({} == 0) {{",
248 back::INDENT,
249 args.local_invocation_index,
250 )?;
251 self.write_workgroup_variables_initialization(
252 func_ctx,
253 module,
254 module.entry_points[ep_index as usize].stage,
255 )?;
256 writeln!(self.out, "{}}}", back::INDENT)?;
257 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
258 }
259 let grid_size = self.namer.call("gridSize");
260 write!(
261 self.out,
262 "{}uint3 {grid_size} = {nested_name}(",
263 back::INDENT
264 )?;
265 args.write_call_args(&mut self.out)?;
266 writeln!(self.out, ");")?;
267 writeln!(
268 self.out,
269 "{}GroupMemoryBarrierWithGroupSync();",
270 back::INDENT
271 )?;
272 if let Some(limits) = self.options.task_dispatch_limits {
273 let level = back::Level(2);
274 writeln!(self.out, "{}if (", back::INDENT)?;
275
276 let max_per_dim = limits.max_mesh_workgroups_per_dim.min(2 << 21);
277 let max_total = limits.max_mesh_workgroups_total;
278 for i in 0..3 {
279 writeln!(
280 self.out,
281 "{level}{grid_size}.{} > {max_per_dim} ||",
282 back::COMPONENTS[i],
283 )?;
284 }
285 writeln!(
286 self.out,
287 "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) > 0xffffffffull ||"
288 )?;
289 writeln!(
290 self.out,
291 "{level}((uint64_t){grid_size}.x) * ((uint64_t){grid_size}.y) * ((uint64_t){grid_size}.z) > {max_total}",
292 )?;
293
294 writeln!(self.out, "{}) {{", back::INDENT)?;
295 writeln!(self.out, "{level}{grid_size} = uint3(0, 0, 0);")?;
296 writeln!(self.out, "{}}}", back::INDENT)?;
297 }
298 writeln!(
299 self.out,
300 "{}DispatchMesh({grid_size}.x, {grid_size}.y, {grid_size}.z, {});",
301 back::INDENT,
302 self.names[&NameKey::GlobalVariable(entry_point.task_payload.unwrap())]
303 )?;
304 Ok(())
305 }
306 #[expect(clippy::too_many_arguments)]
310 pub(super) fn write_nested_function_outer(
311 &mut self,
312 module: &Module,
313 func_ctx: &back::FunctionCtx,
314 header: &str,
315 name: &str,
316 need_workgroup_variables_initialization: bool,
317 nested_name: &str,
318 entry_point: &crate::EntryPoint,
319 args: NestedEntryPointArgs,
322 ) -> BackendResult {
323 let mut any_args_written = false;
324 let mut separator_if_needed = || {
325 if any_args_written {
326 ", "
327 } else {
328 any_args_written = true;
329 ""
330 }
331 };
332
333 let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty else {
334 unreachable!();
335 };
336 let stage = module.entry_points[ep_index as usize].stage;
337 write!(self.out, "{header}")?;
338 write!(self.out, "void {name}(")?;
339 if let Some(ref ep_input) = self.entry_point_io.get(&(ep_index as usize)).unwrap().input {
343 write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
344 } else {
345 for (index, arg) in entry_point.function.arguments.iter().enumerate() {
346 write!(self.out, "{}", separator_if_needed())?;
347 self.write_type(module, arg.ty)?;
348
349 let argument_name =
350 &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
351
352 write!(self.out, " {argument_name}")?;
353 if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
354 self.write_array_size(module, base, size)?;
355 }
356
357 self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
358 }
359 }
360 if need_workgroup_variables_initialization || stage == ShaderStage::Mesh {
361 write!(
362 self.out,
363 "{}uint {} : SV_GroupIndex",
364 separator_if_needed(),
365 args.local_invocation_index,
366 )?;
367 }
368 if entry_point.stage == ShaderStage::Mesh {
369 self.write_mesh_shader_wrapper(
370 module,
371 func_ctx,
372 need_workgroup_variables_initialization,
373 nested_name,
374 entry_point,
375 args,
376 separator_if_needed,
377 )?;
378 } else {
379 self.write_task_shader_wrapper(
380 module,
381 func_ctx,
382 need_workgroup_variables_initialization,
383 nested_name,
384 entry_point,
385 args,
386 )?;
387 }
388
389 writeln!(self.out, "}}")?;
390 Ok(())
391 }
392
393 pub(super) fn write_ep_mesh_output_struct(
394 &mut self,
395 module: &Module,
396 entry_point_name: &str,
397 is_primitive: bool,
398 mesh_info: &crate::MeshStageInfo,
399 ) -> Result<EntryPointBinding, Error> {
400 let (in_type, io, var_prefix, arg_name) = if is_primitive {
401 (
402 mesh_info.primitive_output_type,
403 Io::MeshPrimitives,
404 "Primitive",
405 "primitives",
406 )
407 } else {
408 (
409 mesh_info.vertex_output_type,
410 Io::MeshVertices,
411 "Vertex",
412 "vertices",
413 )
414 };
415 let struct_name = format!("Mesh{var_prefix}Output_{entry_point_name}",);
416
417 let members = match module.types[in_type].inner {
419 TypeInner::Struct { ref members, .. } => members,
420 _ => unreachable!(),
421 };
422 let mut out_members = Vec::new();
423 for (index, member) in members.iter().enumerate() {
424 if matches!(
425 member.binding,
426 Some(crate::Binding::BuiltIn(
427 crate::BuiltIn::PointIndex
428 | crate::BuiltIn::LineIndices
429 | crate::BuiltIn::TriangleIndices
430 ))
431 ) {
432 continue;
433 }
434 let member_name = self.namer.call_or(&member.name, "member");
435 out_members.push(EpStructMember {
436 name: member_name,
437 ty: member.ty,
438 binding: member.binding.clone(),
439 index: index as u32,
440 })
441 }
442 self.write_interface_struct(
443 module,
444 (ShaderStage::Mesh, io),
445 struct_name,
446 Some(arg_name),
447 out_members,
448 )
449 }
450
451 pub(super) fn write_ep_mesh_output_indices(
452 &mut self,
453 topology: crate::MeshOutputTopology,
454 ) -> Result<EntryPointBinding, Error> {
455 let (indices_name, indices_type) = match topology {
456 crate::MeshOutputTopology::Points => unreachable!(),
458 crate::MeshOutputTopology::Lines => (self.namer.call("lineIndices"), "uint2"),
459 crate::MeshOutputTopology::Triangles => (self.namer.call("triangleIndices"), "uint3"),
460 };
461 Ok(EntryPointBinding {
462 ty_name: indices_type.to_string(),
463 arg_name: indices_name,
464 members: Vec::new(),
465 local_invocation_index_name: None,
466 })
467 }
468}