1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Write as _},
8 mem,
9};
10
11use super::{
12 help,
13 help::{
14 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15 WrappedZeroValue,
16 },
17 storage::StoreValue,
18 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21 back::{self, get_entry_points, Baked},
22 common,
23 proc::{self, index, ExternalTextureNameKey, NameKey},
24 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
// ---- Interface / special constant-buffer naming -------------------------

/// Semantic prefix written for `@location(N)` bindings that are not
/// fragment outputs (emitted as `LOC{N}`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the special constant buffer declared when
/// `Options::special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of that special constant buffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special constant buffer.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special constant buffer.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special constant buffer.
const SPECIAL_OTHER: &str = "other";

// ---- Names of helper functions / variables emitted by the backend -------

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Heap that non-comparison sampler globals index into.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Heap that comparison sampler globals index into.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
56enum Index {
57 Expression(Handle<crate::Expression>),
58 Static(u32),
59}
60
61pub(super) struct EpStructMember {
62 pub(super) name: String,
63 pub(super) ty: Handle<crate::Type>,
64 pub(super) binding: Option<crate::Binding>,
67 pub(super) index: u32,
68}
69
70pub(super) struct EntryPointBinding {
73 pub(super) arg_name: String,
76 pub(super) ty_name: String,
78 pub(super) members: Vec<EpStructMember>,
80 pub(super) local_invocation_index_name: Option<String>,
81}
82
83pub(super) struct EntryPointInterface {
84 pub(crate) input: Option<EntryPointBinding>,
89 pub(crate) output: Option<EntryPointBinding>,
93 pub(crate) mesh_vertices: Option<EntryPointBinding>,
94 pub(crate) mesh_primitives: Option<EntryPointBinding>,
95 pub(crate) mesh_indices: Option<EntryPointBinding>,
96}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100 Location(u32),
101 BuiltIn(crate::BuiltIn),
102 Other,
103}
104
105impl InterfaceKey {
106 const fn new(binding: Option<&crate::Binding>) -> Self {
107 match binding {
108 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110 None => Self::Other,
111 }
112 }
113}
114
115#[derive(Copy, Clone, PartialEq)]
116pub(super) enum Io {
117 Input,
118 Output,
119 MeshVertices,
120 MeshPrimitives,
121}
122
123pub(super) struct NestedEntryPointArgs {
125 pub user_args: Vec<String>,
127 pub task_payload: Option<String>,
128 pub local_invocation_index: String,
129}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133 return false;
134 };
135 matches!(
136 builtin,
137 crate::BuiltIn::SubgroupSize
138 | crate::BuiltIn::SubgroupInvocationId
139 | crate::BuiltIn::NumSubgroups
140 | crate::BuiltIn::SubgroupId
141 )
142}
143
/// Names needed to index a sampler binding array through the sampler heap.
///
/// NOTE(review): constructed and consumed outside this chunk; the field
/// descriptions are inferred from names — confirm.
struct BindingArraySamplerInfo {
    /// Sampler heap to index; presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`.
    sampler_heap_name: &'static str,
    /// Name of the per-group sampler index buffer.
    sampler_index_buffer_name: String,
    /// Name of the constant holding the binding array's base index.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156 Self {
157 out,
158 names: crate::FastHashMap::default(),
159 namer: proc::Namer::default(),
160 options,
161 pipeline_options,
162 entry_point_io: crate::FastHashMap::default(),
163 named_expressions: crate::NamedExpressions::default(),
164 wrapped: super::Wrapped::default(),
165 written_committed_intersection: false,
166 written_candidate_intersection: false,
167 continue_ctx: back::continue_forward::ContinueCtx::default(),
168 temp_access_chain: Vec::new(),
169 need_bake_expressions: Default::default(),
170 function_task_payload_var: Default::default(),
171 }
172 }
173
174 fn reset(&mut self, module: &Module) {
175 self.names.clear();
176 self.namer.reset(
177 module,
178 &super::keywords::RESERVED_SET,
179 proc::KeywordSet::empty(),
180 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181 super::keywords::RESERVED_PREFIXES,
182 &mut self.names,
183 );
184 self.entry_point_io.clear();
185 self.named_expressions.clear();
186 self.wrapped.clear();
187 self.written_committed_intersection = false;
188 self.written_candidate_intersection = false;
189 self.continue_ctx.clear();
190 self.need_bake_expressions.clear();
191 self.function_task_payload_var.clear();
192 }
193
194 fn gen_force_bounded_loop_statements(
202 &mut self,
203 level: back::Level,
204 ) -> Option<(String, String)> {
205 if !self.options.force_loop_bounding {
206 return None;
207 }
208
209 let loop_bound_name = self.namer.call("loop_bound");
210 let max = u32::MAX;
211 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214 let level = level.next();
215 let break_and_inc = format!(
216 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218 );
219
220 Some((decl, break_and_inc))
221 }
222
223 fn update_expressions_to_bake(
228 &mut self,
229 module: &Module,
230 func: &crate::Function,
231 info: &valid::FunctionInfo,
232 ) {
233 use crate::Expression;
234 self.need_bake_expressions.clear();
235 for (exp_handle, expr) in func.expressions.iter() {
236 let expr_info = &info[exp_handle];
237 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238 if min_ref_count <= expr_info.ref_count {
239 self.need_bake_expressions.insert(exp_handle);
240 }
241
242 if let Expression::Math { fun, arg, arg1, .. } = *expr {
243 match fun {
244 crate::MathFunction::Asinh
245 | crate::MathFunction::Acosh
246 | crate::MathFunction::Atanh
247 | crate::MathFunction::Unpack2x16float
248 | crate::MathFunction::Unpack2x16snorm
249 | crate::MathFunction::Unpack2x16unorm
250 | crate::MathFunction::Unpack4x8snorm
251 | crate::MathFunction::Unpack4x8unorm
252 | crate::MathFunction::Unpack4xI8
253 | crate::MathFunction::Unpack4xU8
254 | crate::MathFunction::Pack2x16float
255 | crate::MathFunction::Pack2x16snorm
256 | crate::MathFunction::Pack2x16unorm
257 | crate::MathFunction::Pack4x8snorm
258 | crate::MathFunction::Pack4x8unorm
259 | crate::MathFunction::Pack4xI8
260 | crate::MathFunction::Pack4xU8
261 | crate::MathFunction::Pack4xI8Clamp
262 | crate::MathFunction::Pack4xU8Clamp => {
263 self.need_bake_expressions.insert(arg);
264 }
265 crate::MathFunction::CountLeadingZeros => {
266 let inner = info[exp_handle].ty.inner_with(&module.types);
267 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
268 self.need_bake_expressions.insert(arg);
269 }
270 }
271 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
272 self.need_bake_expressions.insert(arg);
273 self.need_bake_expressions.insert(arg1.unwrap());
274 }
275 _ => {}
276 }
277 }
278
279 if let Expression::Derivative { axis, ctrl, expr } = *expr {
280 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
281 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
282 self.need_bake_expressions.insert(expr);
283 }
284 }
285
286 if let Expression::GlobalVariable(_) = *expr {
287 let inner = info[exp_handle].ty.inner_with(&module.types);
288
289 if let TypeInner::Sampler { .. } = *inner {
290 self.need_bake_expressions.insert(exp_handle);
291 }
292 }
293 }
294 for statement in func.body.iter() {
295 match *statement {
296 crate::Statement::SubgroupCollectiveOperation {
297 op: _,
298 collective_op: crate::CollectiveOperation::InclusiveScan,
299 argument,
300 result: _,
301 } => {
302 self.need_bake_expressions.insert(argument);
303 }
304 crate::Statement::Atomic {
305 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
306 ..
307 } => {
308 self.need_bake_expressions.insert(cmp);
309 }
310 _ => {}
311 }
312 }
313 }
314
315 pub fn write(
316 &mut self,
317 module: &Module,
318 module_info: &valid::ModuleInfo,
319 fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
320 ) -> Result<super::ReflectionInfo, Error> {
321 self.reset(module);
322
323 if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
324 return Err(Error::ShaderModelTooLow(
325 "mesh shaders".to_string(),
326 ShaderModel::V6_5,
327 ));
328 }
329
330 if let Some(ref bt) = self.options.special_constants_binding {
332 writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
333 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
334 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
335 writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
336 writeln!(self.out, "}};")?;
337 write!(
338 self.out,
339 "ConstantBuffer<{}> {}: register(b{}",
340 SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
341 )?;
342 if bt.space != 0 {
343 write!(self.out, ", space{}", bt.space)?;
344 }
345 writeln!(self.out, ");")?;
346
347 writeln!(self.out)?;
349 }
350
351 for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
352 writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
353 for i in 0..bt.size {
354 writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
355 }
356 writeln!(self.out, "}};")?;
357 writeln!(
358 self.out,
359 "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
360 group, group, bt.register, bt.space
361 )?;
362
363 writeln!(self.out)?;
365 }
366
367 let ep_results = module
369 .entry_points
370 .iter()
371 .map(|ep| (ep.stage, ep.function.result.clone()))
372 .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
373
374 self.write_all_mat_cx2_typedefs_and_functions(module)?;
375
376 for (handle, ty) in module.types.iter() {
378 if let TypeInner::Struct { ref members, span } = ty.inner {
379 if module.types[members.last().unwrap().ty]
380 .inner
381 .is_dynamically_sized(&module.types)
382 {
383 continue;
386 }
387
388 let ep_result = ep_results.iter().find(|e| {
389 if let Some(ref result) = e.1 {
390 result.ty == handle
391 } else {
392 false
393 }
394 });
395
396 self.write_struct(
397 module,
398 handle,
399 members,
400 span,
401 ep_result.map(|r| (r.0, Io::Output)),
402 )?;
403 writeln!(self.out)?;
404 }
405 }
406
407 self.write_special_functions(module)?;
408
409 self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
410 self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
411
412 let mut constants = module
414 .constants
415 .iter()
416 .filter(|&(_, c)| c.name.is_some())
417 .peekable();
418 while let Some((handle, _)) = constants.next() {
419 self.write_global_constant(module, handle)?;
420 if constants.peek().is_none() {
422 writeln!(self.out)?;
423 }
424 }
425
426 for (global, _) in module.global_variables.iter() {
428 self.write_global(module, global)?;
429 }
430
431 if !module.global_variables.is_empty() {
432 writeln!(self.out)?;
434 }
435
436 let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
437 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
438
439 for index in ep_range.clone() {
441 let ep = &module.entry_points[index];
442 let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
443 let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
444 self.entry_point_io.insert(index, ep_io);
445 }
446
447 for (handle, function) in module.functions.iter() {
449 let info = &module_info[handle];
450
451 if !self.options.fake_missing_bindings {
453 if let Some((var_handle, _)) =
454 module
455 .global_variables
456 .iter()
457 .find(|&(var_handle, var)| match var.binding {
458 Some(ref binding) if !info[var_handle].is_empty() => {
459 self.options.resolve_resource_binding(binding).is_err()
460 && self
461 .options
462 .resolve_external_texture_resource_binding(binding)
463 .is_err()
464 }
465 _ => false,
466 })
467 {
468 log::debug!(
469 "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
470 handle,
471 function.name,
472 var_handle
473 );
474 continue;
475 }
476 }
477
478 let ctx = back::FunctionCtx {
479 ty: back::FunctionType::Function(handle),
480 info,
481 expressions: &function.expressions,
482 named_expressions: &function.named_expressions,
483 };
484 let name = self.names[&NameKey::Function(handle)].clone();
485
486 self.write_wrapped_functions(module, &ctx)?;
487
488 self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;
489
490 writeln!(self.out)?;
491 }
492
493 let mut translated_ep_names = Vec::with_capacity(ep_range.len());
494
495 for index in ep_range {
497 let ep = &module.entry_points[index];
498 let info = module_info.get_entry_point(index);
499
500 if !self.options.fake_missing_bindings {
501 let mut ep_error = None;
502 for (var_handle, var) in module.global_variables.iter() {
503 match var.binding {
504 Some(ref binding) if !info[var_handle].is_empty() => {
505 if let Err(err) = self.options.resolve_resource_binding(binding) {
506 if self
507 .options
508 .resolve_external_texture_resource_binding(binding)
509 .is_err()
510 {
511 ep_error = Some(err);
512 break;
513 }
514 }
515 }
516 _ => {}
517 }
518 }
519 if let Some(err) = ep_error {
520 translated_ep_names.push(Err(err));
521 continue;
522 }
523 }
524
525 let ctx = back::FunctionCtx {
526 ty: back::FunctionType::EntryPoint(index as u16),
527 info,
528 expressions: &ep.function.expressions,
529 named_expressions: &ep.function.named_expressions,
530 };
531
532 self.write_wrapped_functions(module, &ctx)?;
533
534 let mut attribute_string = String::new();
537 if ep.stage.compute_like() {
538 let num_threads = ep.workgroup_size;
540 writeln!(
541 attribute_string,
542 "[numthreads({}, {}, {})]",
543 num_threads[0], num_threads[1], num_threads[2]
544 )?;
545 }
546 if let Some(ref info) = ep.mesh_info {
547 let topology_str = match info.topology {
548 crate::MeshOutputTopology::Points => unreachable!(),
549 crate::MeshOutputTopology::Lines => "line",
550 crate::MeshOutputTopology::Triangles => "triangle",
551 };
552 writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
553 }
554
555 let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
556 self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;
557
558 if index < module.entry_points.len() - 1 {
559 writeln!(self.out)?;
560 }
561
562 translated_ep_names.push(Ok(name));
563 }
564
565 Ok(super::ReflectionInfo {
566 entry_point_names: translated_ep_names,
567 })
568 }
569
570 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
571 match *binding {
572 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
573 write!(self.out, "precise ")?;
574 }
575 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
576 write!(self.out, "noperspective ")?;
577 }
578 crate::Binding::Location {
579 interpolation,
580 sampling,
581 ..
582 } => {
583 if let Some(interpolation) = interpolation {
584 if let Some(string) = interpolation.to_hlsl_str() {
585 write!(self.out, "{string} ")?
586 }
587 }
588
589 if let Some(sampling) = sampling {
590 if let Some(string) = sampling.to_hlsl_str() {
591 write!(self.out, "{string} ")?
592 }
593 }
594 }
595 crate::Binding::BuiltIn(_) => {}
596 }
597
598 Ok(())
599 }
600
601 pub(super) fn write_semantic(
604 &mut self,
605 binding: &Option<crate::Binding>,
606 stage: Option<(ShaderStage, Io)>,
607 ) -> BackendResult {
608 let is_per_primitive = match *binding {
609 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
610 if builtin == crate::BuiltIn::ViewIndex
611 && self.options.shader_model < ShaderModel::V6_1
612 {
613 return Err(Error::ShaderModelTooLow(
614 "used @builtin(view_index) or SV_ViewID".to_string(),
615 ShaderModel::V6_1,
616 ));
617 }
618 if let Some(builtin_str) = builtin.to_hlsl_str()? {
619 write!(self.out, " : {builtin_str}")?;
620 }
621 false
622 }
623 Some(crate::Binding::Location {
624 blend_src: Some(1),
625 per_primitive,
626 ..
627 }) => {
628 write!(self.out, " : SV_Target1")?;
629 per_primitive
630 }
631 Some(crate::Binding::Location {
632 location,
633 per_primitive,
634 ..
635 }) => {
636 if stage == Some((ShaderStage::Fragment, Io::Output)) {
637 write!(self.out, " : SV_Target{location}")?;
638 } else {
639 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
640 }
641 per_primitive
642 }
643 _ => false,
644 };
645 if is_per_primitive {
646 write!(self.out, " : primitive")?;
647 }
648
649 Ok(())
650 }
651
652 pub(super) fn write_interface_struct(
653 &mut self,
654 module: &Module,
655 shader_stage: (ShaderStage, Io),
656 struct_name: String,
657 var_name: Option<&str>,
658 mut members: Vec<EpStructMember>,
659 ) -> Result<EntryPointBinding, Error> {
660 let struct_name = self.namer.call(&struct_name);
661 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
665
666 write!(self.out, "struct {struct_name}")?;
667 writeln!(self.out, " {{")?;
668 let mut local_invocation_index_name = None;
669 let mut subgroup_id_used = false;
670 for m in members.iter() {
671 debug_assert!(m.binding.is_some());
674
675 match m.binding {
676 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
677 subgroup_id_used = true;
678 }
679 Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
680 local_invocation_index_name = Some(m.name.clone());
681 }
682 _ => (),
683 }
684
685 if is_subgroup_builtin_binding(&m.binding) {
686 continue;
687 }
688 write!(self.out, "{}", back::INDENT)?;
689 if let Some(ref binding) = m.binding {
690 self.write_modifier(binding)?;
691 }
692 self.write_type(module, m.ty)?;
693 write!(self.out, " {}", &m.name)?;
694 self.write_semantic(&m.binding, Some(shader_stage))?;
695 writeln!(self.out, ";")?;
696 }
697 if subgroup_id_used && local_invocation_index_name.is_none() {
698 let name = self.namer.call("local_invocation_index");
699 writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
700 local_invocation_index_name = Some(name);
701 }
702 writeln!(self.out, "}};")?;
703 writeln!(self.out)?;
704
705 match shader_stage.1 {
707 Io::Input => {
708 members.sort_by_key(|m| m.index);
710 }
711 Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
712 }
714 }
715
716 Ok(EntryPointBinding {
717 arg_name: self
718 .namer
719 .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
720 ty_name: struct_name,
721 members,
722 local_invocation_index_name,
723 })
724 }
725
726 fn write_ep_input_struct(
730 &mut self,
731 module: &Module,
732 func: &crate::Function,
733 stage: ShaderStage,
734 entry_point_name: &str,
735 ) -> Result<EntryPointBinding, Error> {
736 let struct_name = format!("{stage:?}Input_{entry_point_name}");
737
738 let mut fake_members = Vec::new();
739 for arg in func.arguments.iter() {
740 match module.types[arg.ty].inner {
745 TypeInner::Struct { ref members, .. } => {
746 for member in members.iter() {
747 let name = self.namer.call_or(&member.name, "member");
748 let index = fake_members.len() as u32;
749 fake_members.push(EpStructMember {
750 name,
751 ty: member.ty,
752 binding: member.binding.clone(),
753 index,
754 });
755 }
756 }
757 _ => {
758 let member_name = self.namer.call_or(&arg.name, "member");
759 let index = fake_members.len() as u32;
760 fake_members.push(EpStructMember {
761 name: member_name,
762 ty: arg.ty,
763 binding: arg.binding.clone(),
764 index,
765 });
766 }
767 }
768 }
769
770 self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
771 }
772
773 fn write_ep_output_struct(
777 &mut self,
778 module: &Module,
779 result: &crate::FunctionResult,
780 stage: ShaderStage,
781 entry_point_name: &str,
782 frag_ep: Option<&FragmentEntryPoint<'_>>,
783 ) -> Result<EntryPointBinding, Error> {
784 let struct_name = format!("{stage:?}Output_{entry_point_name}");
785
786 let empty = [];
787 let members = match module.types[result.ty].inner {
788 TypeInner::Struct { ref members, .. } => members,
789 ref other => {
790 log::error!("Unexpected {other:?} output type without a binding");
791 &empty[..]
792 }
793 };
794
795 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
800 let mut fs_input_locs = Vec::new();
801 for arg in frag_ep.func.arguments.iter() {
802 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
803 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
804 Some(crate::Binding::BuiltIn(_)) | None => {}
805 };
806
807 match frag_ep.module.types[arg.ty].inner {
810 TypeInner::Struct { ref members, .. } => {
811 for member in members.iter() {
812 push_if_location(&member.binding);
813 }
814 }
815 _ => push_if_location(&arg.binding),
816 }
817 }
818 fs_input_locs.sort();
819 Some(fs_input_locs)
820 } else {
821 None
822 };
823
824 let mut fake_members = Vec::new();
825 for (index, member) in members.iter().enumerate() {
826 if let Some(ref fs_input_locs) = fs_input_locs {
827 match member.binding {
828 Some(crate::Binding::Location { location, .. }) => {
829 if fs_input_locs.binary_search(&location).is_err() {
830 continue;
831 }
832 }
833 Some(crate::Binding::BuiltIn(_)) | None => {}
834 }
835 }
836
837 let member_name = self.namer.call_or(&member.name, "member");
838 fake_members.push(EpStructMember {
839 name: member_name,
840 ty: member.ty,
841 binding: member.binding.clone(),
842 index: index as u32,
843 });
844 }
845
846 self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
847 }
848
849 fn write_ep_interface(
853 &mut self,
854 module: &Module,
855 ep: &crate::EntryPoint,
856 ep_name: &str,
857 frag_ep: Option<&FragmentEntryPoint<'_>>,
858 ) -> Result<EntryPointInterface, Error> {
859 let func = &ep.function;
860 let stage = ep.stage;
861 Ok(EntryPointInterface {
862 input: if !func.arguments.is_empty()
863 && (stage == ShaderStage::Fragment
864 || func
865 .arguments
866 .iter()
867 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
868 {
869 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
870 } else {
871 None
872 },
873 output: match func.result {
874 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
875 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
876 }
877 _ => None,
878 },
879 mesh_vertices: if let Some(ref info) = ep.mesh_info {
880 Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
881 } else {
882 None
883 },
884 mesh_primitives: if let Some(ref info) = ep.mesh_info {
885 Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
886 } else {
887 None
888 },
889 mesh_indices: if let Some(ref info) = ep.mesh_info {
890 Some(self.write_ep_mesh_output_indices(info.topology)?)
891 } else {
892 None
893 },
894 })
895 }
896
897 fn write_ep_argument_initialization(
898 &mut self,
899 ep: &crate::EntryPoint,
900 ep_input: &EntryPointBinding,
901 fake_member: &EpStructMember,
902 ) -> BackendResult {
903 match fake_member.binding {
904 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
905 write!(self.out, "WaveGetLaneCount()")?
906 }
907 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
908 write!(self.out, "WaveGetLaneIndex()")?
909 }
910 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
911 self.out,
912 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
913 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
914 )?,
915 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
916 write!(
917 self.out,
918 "{}.{} / WaveGetLaneCount()",
919 ep_input.arg_name,
920 ep_input.local_invocation_index_name.as_ref().unwrap()
922 )?;
923 }
924 _ => {
925 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
926 }
927 }
928 Ok(())
929 }
930
931 fn write_ep_arguments_initialization(
933 &mut self,
934 module: &Module,
935 func: &crate::Function,
936 ep_index: u16,
937 ) -> BackendResult {
938 let ep = &module.entry_points[ep_index as usize];
939 let ep_input = match self
940 .entry_point_io
941 .get_mut(&(ep_index as usize))
942 .unwrap()
943 .input
944 .take()
945 {
946 Some(ep_input) => ep_input,
947 None => return Ok(()),
948 };
949 let mut fake_iter = ep_input.members.iter();
950 for (arg_index, arg) in func.arguments.iter().enumerate() {
951 write!(self.out, "{}", back::INDENT)?;
952 self.write_type(module, arg.ty)?;
953 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
954 write!(self.out, " {arg_name}")?;
955 match module.types[arg.ty].inner {
956 TypeInner::Array { base, size, .. } => {
957 self.write_array_size(module, base, size)?;
958 write!(self.out, " = ")?;
959 self.write_ep_argument_initialization(
960 ep,
961 &ep_input,
962 fake_iter.next().unwrap(),
963 )?;
964 writeln!(self.out, ";")?;
965 }
966 TypeInner::Struct { ref members, .. } => {
967 write!(self.out, " = {{ ")?;
968 for index in 0..members.len() {
969 if index != 0 {
970 write!(self.out, ", ")?;
971 }
972 self.write_ep_argument_initialization(
973 ep,
974 &ep_input,
975 fake_iter.next().unwrap(),
976 )?;
977 }
978 writeln!(self.out, " }};")?;
979 }
980 _ => {
981 write!(self.out, " = ")?;
982 self.write_ep_argument_initialization(
983 ep,
984 &ep_input,
985 fake_iter.next().unwrap(),
986 )?;
987 writeln!(self.out, ";")?;
988 }
989 }
990 }
991 assert!(fake_iter.next().is_none());
992 Ok(())
993 }
994
995 fn write_global(
999 &mut self,
1000 module: &Module,
1001 handle: Handle<crate::GlobalVariable>,
1002 ) -> BackendResult {
1003 let global = &module.global_variables[handle];
1004 let inner = &module.types[global.ty].inner;
1005
1006 let handle_ty = match *inner {
1007 TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
1008 _ => inner,
1009 };
1010
1011 let is_external_texture = matches!(
1015 *handle_ty,
1016 TypeInner::Image {
1017 class: crate::ImageClass::External,
1018 ..
1019 }
1020 );
1021 if is_external_texture {
1022 return self.write_global_external_texture(module, handle, global);
1023 }
1024
1025 if let Some(ref binding) = global.binding {
1026 if let Err(err) = self.options.resolve_resource_binding(binding) {
1027 log::debug!(
1028 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1029 handle,
1030 global.name,
1031 err,
1032 );
1033 return Ok(());
1034 }
1035 }
1036
1037 let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
1039
1040 if is_sampler {
1041 return self.write_global_sampler(module, handle, global);
1042 }
1043
1044 let register_ty = match global.space {
1046 crate::AddressSpace::Function => unreachable!("Function address space"),
1047 crate::AddressSpace::Private => {
1048 write!(self.out, "static ")?;
1049 self.write_type(module, global.ty)?;
1050 ""
1051 }
1052 crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
1053 write!(self.out, "groupshared ")?;
1054 self.write_type(module, global.ty)?;
1055 ""
1056 }
1057 crate::AddressSpace::Uniform => {
1058 write!(self.out, "cbuffer")?;
1061 "b"
1062 }
1063 crate::AddressSpace::Storage { access } => {
1064 if global
1065 .memory_decorations
1066 .contains(crate::MemoryDecorations::COHERENT)
1067 {
1068 write!(self.out, "globallycoherent ")?;
1069 }
1070 let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
1071 ("RW", "u")
1072 } else {
1073 ("", "t")
1074 };
1075 write!(self.out, "{prefix}ByteAddressBuffer")?;
1076 register
1077 }
1078 crate::AddressSpace::Handle => {
1079 let register = match *handle_ty {
1080 TypeInner::Image {
1082 class: crate::ImageClass::Storage { .. },
1083 ..
1084 } => "u",
1085 _ => "t",
1086 };
1087 self.write_type(module, global.ty)?;
1088 register
1089 }
1090 crate::AddressSpace::Immediate => {
1091 write!(self.out, "ConstantBuffer<")?;
1093 "b"
1094 }
1095 crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
1096 unimplemented!()
1097 }
1098 };
1099
1100 if global.space == crate::AddressSpace::Immediate {
1103 self.write_global_type(module, global.ty)?;
1104
1105 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1107 self.write_array_size(module, base, size)?;
1108 }
1109
1110 write!(self.out, ">")?;
1112 }
1113
1114 let name = &self.names[&NameKey::GlobalVariable(handle)];
1115 write!(self.out, " {name}")?;
1116
1117 if global.space == crate::AddressSpace::Immediate {
1120 match module.types[global.ty].inner {
1121 TypeInner::Struct { .. } => {}
1122 _ => {
1123 return Err(Error::Unimplemented(format!(
1124 "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1125 )));
1126 }
1127 }
1128
1129 let target = self
1130 .options
1131 .immediates_target
1132 .as_ref()
1133 .expect("No bind target was defined for the immediates block");
1134 write!(self.out, ": register(b{}", target.register)?;
1135 if target.space != 0 {
1136 write!(self.out, ", space{}", target.space)?;
1137 }
1138 write!(self.out, ")")?;
1139 }
1140
1141 if let Some(ref binding) = global.binding {
1142 let bt = self.options.resolve_resource_binding(binding).unwrap();
1144
1145 if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1147 if let Some(overridden_size) = bt.binding_array_size {
1148 write!(self.out, "[{overridden_size}]")?;
1149 } else {
1150 self.write_array_size(module, base, size)?;
1151 }
1152 }
1153
1154 write!(self.out, " : register({}{}", register_ty, bt.register)?;
1155 if bt.space != 0 {
1156 write!(self.out, ", space{}", bt.space)?;
1157 }
1158 write!(self.out, ")")?;
1159 } else {
1160 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1162 self.write_array_size(module, base, size)?;
1163 }
1164 if global.space == crate::AddressSpace::Private {
1165 write!(self.out, " = ")?;
1166 if let Some(init) = global.init {
1167 self.write_const_expression(module, init, &module.global_expressions)?;
1168 } else {
1169 self.write_default_init(module, global.ty)?;
1170 }
1171 }
1172 }
1173
1174 if global.space == crate::AddressSpace::Uniform {
1175 write!(self.out, " {{ ")?;
1176
1177 self.write_global_type(module, global.ty)?;
1178
1179 write!(
1180 self.out,
1181 " {}",
1182 &self.names[&NameKey::GlobalVariable(handle)]
1183 )?;
1184
1185 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1187 self.write_array_size(module, base, size)?;
1188 }
1189
1190 writeln!(self.out, "; }}")?;
1191 } else {
1192 writeln!(self.out, ";")?;
1193 }
1194
1195 Ok(())
1196 }
1197
1198 fn write_global_sampler(
1199 &mut self,
1200 module: &Module,
1201 handle: Handle<crate::GlobalVariable>,
1202 global: &crate::GlobalVariable,
1203 ) -> BackendResult {
1204 let binding = *global.binding.as_ref().unwrap();
1205
1206 let key = super::SamplerIndexBufferKey {
1207 group: binding.group,
1208 };
1209 self.write_wrapped_sampler_buffer(key)?;
1210
1211 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1213
1214 match module.types[global.ty].inner {
1215 TypeInner::Sampler { comparison } => {
1216 write!(self.out, "static const ")?;
1223 self.write_type(module, global.ty)?;
1224
1225 let heap_var = if comparison {
1226 COMPARISON_SAMPLER_HEAP_VAR
1227 } else {
1228 SAMPLER_HEAP_VAR
1229 };
1230
1231 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1232 let name = &self.names[&NameKey::GlobalVariable(handle)];
1233 writeln!(
1234 self.out,
1235 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1236 register = bt.register
1237 )?;
1238 }
1239 TypeInner::BindingArray { .. } => {
1240 let name = &self.names[&NameKey::GlobalVariable(handle)];
1246 writeln!(
1247 self.out,
1248 "static const uint {name} = {register};",
1249 register = bt.register
1250 )?;
1251 }
1252 _ => unreachable!(),
1253 };
1254
1255 Ok(())
1256 }
1257
1258 fn write_global_external_texture(
1262 &mut self,
1263 module: &Module,
1264 handle: Handle<crate::GlobalVariable>,
1265 global: &crate::GlobalVariable,
1266 ) -> BackendResult {
1267 let res_binding = global
1268 .binding
1269 .as_ref()
1270 .expect("External texture global variables must have a resource binding");
1271 let ext_tex_bindings = match self
1272 .options
1273 .resolve_external_texture_resource_binding(res_binding)
1274 {
1275 Ok(bindings) => bindings,
1276 Err(err) => {
1277 log::debug!(
1278 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1279 handle,
1280 global.name,
1281 err,
1282 );
1283 return Ok(());
1284 }
1285 };
1286
1287 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1288 write!(
1289 self.out,
1290 "Texture2D<float4> {}: register(t{}",
1291 name, bt.register
1292 )?;
1293 if bt.space != 0 {
1294 write!(self.out, ", space{}", bt.space)?;
1295 }
1296 writeln!(self.out, ");")?;
1297 Ok(())
1298 };
1299 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1300 let plane_name = &self.names
1301 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1302 write_plane(bt, plane_name)?;
1303 }
1304
1305 let params_name = &self.names
1306 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1307 let params_ty_name =
1308 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1309 write!(
1310 self.out,
1311 "cbuffer {}: register(b{}",
1312 params_name, ext_tex_bindings.params.register
1313 )?;
1314 if ext_tex_bindings.params.space != 0 {
1315 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1316 }
1317 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1318
1319 Ok(())
1320 }
1321
1322 fn write_global_constant(
1327 &mut self,
1328 module: &Module,
1329 handle: Handle<crate::Constant>,
1330 ) -> BackendResult {
1331 write!(self.out, "static const ")?;
1332 let constant = &module.constants[handle];
1333 self.write_type(module, constant.ty)?;
1334 let name = &self.names[&NameKey::Constant(handle)];
1335 write!(self.out, " {name}")?;
1336 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1338 self.write_array_size(module, base, size)?;
1339 }
1340 write!(self.out, " = ")?;
1341 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1342 writeln!(self.out, ";")?;
1343 Ok(())
1344 }
1345
1346 pub(super) fn write_array_size(
1347 &mut self,
1348 module: &Module,
1349 base: Handle<crate::Type>,
1350 size: crate::ArraySize,
1351 ) -> BackendResult {
1352 write!(self.out, "[")?;
1353
1354 match size.resolve(module.to_ctx())? {
1355 proc::IndexableLength::Known(size) => {
1356 write!(self.out, "{size}")?;
1357 }
1358 proc::IndexableLength::Dynamic => unreachable!(),
1359 }
1360
1361 write!(self.out, "]")?;
1362
1363 if let TypeInner::Array {
1364 base: next_base,
1365 size: next_size,
1366 ..
1367 } = module.types[base].inner
1368 {
1369 self.write_array_size(module, next_base, next_size)?;
1370 }
1371
1372 Ok(())
1373 }
1374
1375 fn write_struct(
1380 &mut self,
1381 module: &Module,
1382 handle: Handle<crate::Type>,
1383 members: &[crate::StructMember],
1384 span: u32,
1385 shader_stage: Option<(ShaderStage, Io)>,
1386 ) -> BackendResult {
1387 let struct_name = &self.names[&NameKey::Type(handle)];
1389 writeln!(self.out, "struct {struct_name} {{")?;
1390
1391 let mut last_offset = 0;
1392 for (index, member) in members.iter().enumerate() {
1393 if member.binding.is_none() && member.offset > last_offset {
1394 let padding = (member.offset - last_offset) / 4;
1398 for i in 0..padding {
1399 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1400 }
1401 }
1402 let ty_inner = &module.types[member.ty].inner;
1403 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1404
1405 write!(self.out, "{}", back::INDENT)?;
1407
1408 match module.types[member.ty].inner {
1409 TypeInner::Array { base, size, .. } => {
1410 self.write_global_type(module, member.ty)?;
1413
1414 write!(
1416 self.out,
1417 " {}",
1418 &self.names[&NameKey::StructMember(handle, index as u32)]
1419 )?;
1420 self.write_array_size(module, base, size)?;
1422 }
1423 TypeInner::Matrix {
1426 rows,
1427 columns,
1428 scalar,
1429 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1430 let vec_ty = TypeInner::Vector { size: rows, scalar };
1431 let field_name_key = NameKey::StructMember(handle, index as u32);
1432
1433 for i in 0..columns as u8 {
1434 if i != 0 {
1435 write!(self.out, "; ")?;
1436 }
1437 self.write_value_type(module, &vec_ty)?;
1438 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1439 }
1440 }
1441 _ => {
1442 if let Some(ref binding) = member.binding {
1444 self.write_modifier(binding)?;
1445 }
1446
1447 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1451 write!(self.out, "row_major ")?;
1452 }
1453
1454 self.write_type(module, member.ty)?;
1456 write!(
1457 self.out,
1458 " {}",
1459 &self.names[&NameKey::StructMember(handle, index as u32)]
1460 )?;
1461 }
1462 }
1463
1464 self.write_semantic(&member.binding, shader_stage)?;
1465 writeln!(self.out, ";")?;
1466 }
1467
1468 if members.last().unwrap().binding.is_none() && span > last_offset {
1470 let padding = (span - last_offset) / 4;
1471 for i in 0..padding {
1472 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1473 }
1474 }
1475
1476 writeln!(self.out, "}};")?;
1477 Ok(())
1478 }
1479
1480 pub(super) fn write_global_type(
1485 &mut self,
1486 module: &Module,
1487 ty: Handle<crate::Type>,
1488 ) -> BackendResult {
1489 let matrix_data = get_inner_matrix_data(module, ty);
1490
1491 if let Some(MatrixType {
1494 columns,
1495 rows: crate::VectorSize::Bi,
1496 width: 4,
1497 }) = matrix_data
1498 {
1499 write!(self.out, "__mat{}x2", columns as u8)?;
1500 } else {
1501 if matrix_data.is_some() {
1505 write!(self.out, "row_major ")?;
1506 }
1507
1508 self.write_type(module, ty)?;
1509 }
1510
1511 Ok(())
1512 }
1513
1514 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1519 let inner = &module.types[ty].inner;
1520 match *inner {
1521 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1522 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1524 self.write_type(module, base)?
1525 }
1526 ref other => self.write_value_type(module, other)?,
1527 }
1528
1529 Ok(())
1530 }
1531
1532 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1537 match *inner {
1538 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1539 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1540 }
1541 TypeInner::Vector { size, scalar } => {
1542 write!(
1543 self.out,
1544 "{}{}",
1545 scalar.to_hlsl_str()?,
1546 common::vector_size_str(size)
1547 )?;
1548 }
1549 TypeInner::Matrix {
1550 columns,
1551 rows,
1552 scalar,
1553 } => {
1554 write!(
1559 self.out,
1560 "{}{}x{}",
1561 scalar.to_hlsl_str()?,
1562 common::vector_size_str(columns),
1563 common::vector_size_str(rows),
1564 )?;
1565 }
1566 TypeInner::Image {
1567 dim,
1568 arrayed,
1569 class,
1570 } => {
1571 self.write_image_type(dim, arrayed, class)?;
1572 }
1573 TypeInner::Sampler { comparison } => {
1574 let sampler = if comparison {
1575 "SamplerComparisonState"
1576 } else {
1577 "SamplerState"
1578 };
1579 write!(self.out, "{sampler}")?;
1580 }
1581 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1585 self.write_array_size(module, base, size)?;
1586 }
1587 TypeInner::AccelerationStructure { .. } => {
1588 write!(self.out, "RaytracingAccelerationStructure")?;
1589 }
1590 TypeInner::RayQuery { .. } => {
1591 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1593 }
1594 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1595 }
1596
1597 Ok(())
1598 }
1599
1600 fn write_function(
1604 &mut self,
1605 module: &Module,
1606 name: &str,
1607 func: &crate::Function,
1608 func_ctx: &back::FunctionCtx<'_>,
1609 info: &valid::FunctionInfo,
1610 header: String,
1611 ) -> BackendResult {
1612 self.update_expressions_to_bake(module, func, info);
1615 let ep = match func_ctx.ty {
1616 back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1617 back::FunctionType::Function(_) => None,
1618 };
1619
1620 let nested = matches!(
1621 ep,
1622 Some(crate::EntryPoint {
1623 stage: ShaderStage::Task | ShaderStage::Mesh,
1624 ..
1625 })
1626 );
1627 if !nested {
1628 write!(self.out, "{header}")?;
1629 }
1630
1631 if let Some(ref result) = func.result {
1632 let array_return_type = match module.types[result.ty].inner {
1634 TypeInner::Array { base, size, .. } => {
1635 let array_return_type = self.namer.call(&format!("ret_{name}"));
1636 write!(self.out, "typedef ")?;
1637 self.write_type(module, result.ty)?;
1638 write!(self.out, " {array_return_type}")?;
1639 self.write_array_size(module, base, size)?;
1640 writeln!(self.out, ";")?;
1641 Some(array_return_type)
1642 }
1643 _ => None,
1644 };
1645
1646 if let Some(
1648 ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1649 ) = result.binding
1650 {
1651 self.write_modifier(binding)?;
1652 }
1653
1654 match func_ctx.ty {
1656 back::FunctionType::Function(_) => {
1657 if let Some(array_return_type) = array_return_type {
1658 write!(self.out, "{array_return_type}")?;
1659 } else {
1660 self.write_type(module, result.ty)?;
1661 }
1662 }
1663 back::FunctionType::EntryPoint(index) => {
1664 if let Some(ref ep_output) =
1665 self.entry_point_io.get(&(index as usize)).unwrap().output
1666 {
1667 write!(self.out, "{}", ep_output.ty_name)?;
1668 } else {
1669 self.write_type(module, result.ty)?;
1670 }
1671 }
1672 }
1673 } else {
1674 write!(self.out, "void")?;
1675 }
1676
1677 let nested_name = if nested {
1678 self.namer.call(&format!("_{name}"))
1679 } else {
1680 name.to_string()
1681 };
1682
1683 write!(self.out, " {nested_name}(")?;
1685
1686 let need_workgroup_variables_initialization =
1687 self.need_workgroup_variables_initialization(func_ctx, module);
1688
1689 let mut any_args_written = false;
1690 let mut separator = || {
1691 if any_args_written {
1692 ", "
1693 } else {
1694 any_args_written = true;
1695 ""
1696 }
1697 };
1698
1699 let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1700 let mut local_invocation_index_name = None;
1701 let mut nested_wgsl_args: Vec<String> = Vec::new();
1704 let mut nested_task_payload_name: Option<String> = None;
1705 match func_ctx.ty {
1707 back::FunctionType::Function(handle) => {
1708 for (index, arg) in func.arguments.iter().enumerate() {
1709 write!(self.out, "{}", separator())?;
1710 self.write_function_argument(module, handle, arg, index)?;
1711 }
1712 for (var_handle, var) in module.global_variables.iter() {
1714 let uses = info[var_handle];
1715 if uses.contains(valid::GlobalUse::READ)
1716 && !uses.contains(valid::GlobalUse::WRITE)
1717 && var.space == crate::AddressSpace::TaskPayload
1718 {
1719 self.function_task_payload_var.insert(handle, var_handle);
1720 write!(self.out, "{}in ", separator())?;
1721
1722 self.write_type(module, var.ty)?;
1723 let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1724 write!(self.out, " {name}")?;
1725 break;
1726 }
1727 }
1728 }
1729 back::FunctionType::EntryPoint(ep_index) => {
1730 let ep = &module.entry_points[ep_index as usize];
1731 if let Some(ref ep_input) =
1732 self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1733 {
1734 write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1735 separator();
1736 nested_wgsl_args.push(ep_input.arg_name.clone());
1737 } else {
1738 let stage = ep.stage;
1739 for (index, arg) in func.arguments.iter().enumerate() {
1740 write!(self.out, "{}", separator())?;
1741 self.write_type(module, arg.ty)?;
1742
1743 let argument_name =
1744 &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1745
1746 if arg.binding
1747 == Some(crate::Binding::BuiltIn(
1748 crate::BuiltIn::LocalInvocationIndex,
1749 ))
1750 {
1751 local_invocation_index_name = Some(argument_name.clone());
1752 }
1753
1754 nested_wgsl_args.push(argument_name.clone());
1755 write!(self.out, " {argument_name}")?;
1756 if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1757 self.write_array_size(module, base, size)?;
1758 }
1759
1760 self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1761 }
1762 }
1763 if ep.stage == ShaderStage::Mesh {
1764 if let Some(var_handle) = ep.task_payload {
1765 let var = &module.global_variables[var_handle];
1766 write!(self.out, "{}in ", separator())?;
1767 self.write_type(module, var.ty)?;
1768 let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1769 write!(self.out, " {arg_name}")?;
1770 nested_task_payload_name = Some(arg_name.clone());
1771 if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1772 self.write_array_size(module, base, size)?;
1773 }
1774 }
1775 }
1776 if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1777 let name = self.namer.call("local_invocation_index");
1778 write!(self.out, "{}uint {name}", separator())?;
1779 write!(self.out, " : SV_GroupIndex")?;
1780 local_invocation_index_name = Some(name);
1781 }
1782 }
1783 }
1784 write!(self.out, ")")?;
1786
1787 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1789 let stage = module.entry_points[index as usize].stage;
1790 if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1791 self.write_semantic(binding, Some((stage, Io::Output)))?;
1792 }
1793 }
1794
1795 writeln!(self.out)?;
1797 writeln!(self.out, "{{")?;
1798
1799 if need_workgroup_variables_initialization && !nested {
1800 let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1801 unreachable!();
1802 };
1803 writeln!(
1804 self.out,
1805 "{}if ({} == 0) {{",
1806 back::INDENT,
1807 local_invocation_index_name.as_ref().unwrap(),
1810 )?;
1811 self.write_workgroup_variables_initialization(
1812 func_ctx,
1813 module,
1814 module.entry_points[index as usize].stage,
1815 )?;
1816
1817 writeln!(self.out, "{}}}", back::INDENT)?;
1818 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1819 }
1820
1821 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1822 self.write_ep_arguments_initialization(module, func, index)?;
1823 }
1824
1825 for (handle, local) in func.local_variables.iter() {
1827 write!(self.out, "{}", back::INDENT)?;
1829
1830 self.write_type(module, local.ty)?;
1833 write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1834 if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1836 self.write_array_size(module, base, size)?;
1837 }
1838
1839 let is_ray_query = match module.types[local.ty].inner {
1840 TypeInner::RayQuery { .. } => true,
1842 _ => {
1843 write!(self.out, " = ")?;
1844 if let Some(init) = local.init {
1846 self.write_expr(module, init, func_ctx)?;
1847 } else {
1848 self.write_default_init(module, local.ty)?;
1850 }
1851 false
1852 }
1853 };
1854 writeln!(self.out, ";")?;
1856 if is_ray_query {
1858 write!(self.out, "{}", back::INDENT)?;
1859 self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1860 writeln!(
1861 self.out,
1862 " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1863 self.names[&func_ctx.name_key(handle)]
1864 )?;
1865 }
1866 }
1867
1868 if !func.local_variables.is_empty() {
1869 writeln!(self.out)?;
1870 }
1871
1872 for sta in func.body.iter() {
1874 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1876 }
1877
1878 writeln!(self.out, "}}")?;
1879
1880 if nested {
1881 self.write_nested_function_outer(
1882 module,
1883 func_ctx,
1884 &header,
1885 name,
1886 need_workgroup_variables_initialization,
1887 &nested_name,
1888 ep.unwrap(),
1889 NestedEntryPointArgs {
1890 user_args: nested_wgsl_args,
1891 task_payload: nested_task_payload_name,
1892 local_invocation_index: local_invocation_index_name.unwrap(),
1894 },
1895 )?;
1896 }
1897
1898 self.named_expressions.clear();
1899
1900 Ok(())
1901 }
1902
1903 fn write_function_argument(
1904 &mut self,
1905 module: &Module,
1906 handle: Handle<crate::Function>,
1907 arg: &crate::FunctionArgument,
1908 index: usize,
1909 ) -> BackendResult {
1910 if let TypeInner::Image {
1913 class: crate::ImageClass::External,
1914 ..
1915 } = module.types[arg.ty].inner
1916 {
1917 return self.write_function_external_texture_argument(module, handle, index);
1918 }
1919
1920 let arg_ty = match module.types[arg.ty].inner {
1922 TypeInner::Pointer { base, .. } => {
1924 write!(self.out, "inout ")?;
1926 base
1927 }
1928 _ => arg.ty,
1929 };
1930 self.write_type(module, arg_ty)?;
1931
1932 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1933
1934 write!(self.out, " {argument_name}")?;
1936 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1937 self.write_array_size(module, base, size)?;
1938 }
1939
1940 Ok(())
1941 }
1942
1943 fn write_function_external_texture_argument(
1944 &mut self,
1945 module: &Module,
1946 handle: Handle<crate::Function>,
1947 index: usize,
1948 ) -> BackendResult {
1949 let plane_names = [0, 1, 2].map(|i| {
1950 &self.names[&NameKey::ExternalTextureFunctionArgument(
1951 handle,
1952 index as u32,
1953 ExternalTextureNameKey::Plane(i),
1954 )]
1955 });
1956 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1957 handle,
1958 index as u32,
1959 ExternalTextureNameKey::Params,
1960 )];
1961 let params_ty_name =
1962 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1963 write!(
1964 self.out,
1965 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1966 plane_names[0], plane_names[1], plane_names[2],
1967 )?;
1968 Ok(())
1969 }
1970
1971 fn need_workgroup_variables_initialization(
1972 &mut self,
1973 func_ctx: &back::FunctionCtx,
1974 module: &Module,
1975 ) -> bool {
1976 self.options.zero_initialize_workgroup_memory
1977 && func_ctx.ty.is_compute_like_entry_point(module)
1978 && module.global_variables.iter().any(|(handle, var)| {
1979 !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
1980 })
1981 }
1982
1983 pub(super) fn write_workgroup_variables_initialization(
1984 &mut self,
1985 func_ctx: &back::FunctionCtx,
1986 module: &Module,
1987 stage: ShaderStage,
1988 ) -> BackendResult {
1989 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1990 let task_needs_zero =
1992 (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
1993 !func_ctx.info[handle].is_empty()
1994 && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
1995 });
1996
1997 for (handle, var) in vars {
1998 let name = &self.names[&NameKey::GlobalVariable(handle)];
1999 write!(self.out, "{}{} = ", back::Level(2), name)?;
2000 self.write_default_init(module, var.ty)?;
2001 writeln!(self.out, ";")?;
2002 }
2003 Ok(())
2004 }
2005
    /// Writes a `Switch` statement at indentation `level`.
    ///
    /// Single-body case: if every case except the last is an empty
    /// fall-through, all labels lead to the same (last) body. That body is then
    /// written once inside a `do { ... } while(false);` block, and `selector` is
    /// not written at all.
    ///
    /// Otherwise an HLSL `switch` is emitted:
    /// - Empty fall-through cases become bare labels.
    /// - A non-empty fall-through case inlines its own body plus the bodies of
    ///   all following cases, up to and including the first case that does not
    ///   fall through. Each body gets its own `{ ... }` scope.
    /// - A `break;` is added unless the (last) body already ends in a
    ///   terminator.
    ///
    /// `continue_ctx` may hand out a flag variable on entry, which is declared
    /// `false` before the switch. On exit, if it reports a pending `continue`
    /// or `break`, an `if (flag) { continue; }` or `if (flag) { break; }` is
    /// written after the switch.
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // Declare the control-flow forwarding flag, if `continue_ctx` wants one.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // True when every case but the last is an empty fall-through.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Write the shared body once inside a run-once loop; `selector` is
            // not emitted on this path.
            writeln!(self.out, "{level}do {{")?;
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // Empty fall-through cases are bare labels with no braces.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                if case.fall_through && !case.body.is_empty() {
                    // Inline this body and every following one up to the first
                    // case that does not fall through; `unwrap` assumes such a
                    // case exists.
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        // Each copied body gets its own scope. Named expressions
                        // recorded while writing it are dropped afterwards,
                        // presumably so later copies do not reference names
                        // declared inside this scope.
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // Explicit `break` unless the body already ends in a terminator.
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Re-issue a `continue`/`break` that was deferred through the flag variable.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
2150
2151 fn write_index(
2152 &mut self,
2153 module: &Module,
2154 index: Index,
2155 func_ctx: &back::FunctionCtx<'_>,
2156 ) -> BackendResult {
2157 match index {
2158 Index::Static(index) => {
2159 write!(self.out, "{index}")?;
2160 }
2161 Index::Expression(index) => {
2162 self.write_expr(module, index, func_ctx)?;
2163 }
2164 }
2165 Ok(())
2166 }
2167
2168 fn write_stmt(
2173 &mut self,
2174 module: &Module,
2175 stmt: &crate::Statement,
2176 func_ctx: &back::FunctionCtx<'_>,
2177 level: back::Level,
2178 ) -> BackendResult {
2179 use crate::Statement;
2180
2181 match *stmt {
2182 Statement::Emit(ref range) => {
2183 for handle in range.clone() {
2184 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2185 let expr_name = if ptr_class.is_some() {
2186 None
2190 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2191 Some(self.namer.call(name))
2196 } else if self.need_bake_expressions.contains(&handle) {
2197 Some(Baked(handle).to_string())
2198 } else {
2199 None
2200 };
2201
2202 if let Some(name) = expr_name {
2203 write!(self.out, "{level}")?;
2204 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2205 }
2206 }
2207 }
2208 Statement::Block(ref block) => {
2210 write!(self.out, "{level}")?;
2211 writeln!(self.out, "{{")?;
2212 for sta in block.iter() {
2213 self.write_stmt(module, sta, func_ctx, level.next())?
2215 }
2216 writeln!(self.out, "{level}}}")?
2217 }
2218 Statement::If {
2220 condition,
2221 ref accept,
2222 ref reject,
2223 } => {
2224 write!(self.out, "{level}")?;
2225 write!(self.out, "if (")?;
2226 self.write_expr(module, condition, func_ctx)?;
2227 writeln!(self.out, ") {{")?;
2228
2229 let l2 = level.next();
2230 for sta in accept {
2231 self.write_stmt(module, sta, func_ctx, l2)?;
2233 }
2234
2235 if !reject.is_empty() {
2238 writeln!(self.out, "{level}}} else {{")?;
2239
2240 for sta in reject {
2241 self.write_stmt(module, sta, func_ctx, l2)?;
2243 }
2244 }
2245
2246 writeln!(self.out, "{level}}}")?
2247 }
2248 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2250 Statement::Return { value: None } => {
2251 writeln!(self.out, "{level}return;")?;
2252 }
2253 Statement::Return { value: Some(expr) } => {
2254 let base_ty_res = &func_ctx.info[expr].ty;
2255 let mut resolved = base_ty_res.inner_with(&module.types);
2256 if let TypeInner::Pointer { base, space: _ } = *resolved {
2257 resolved = &module.types[base].inner;
2258 }
2259
2260 if let TypeInner::Struct { .. } = *resolved {
2261 let ty = base_ty_res.handle().unwrap();
2263 let struct_name = &self.names[&NameKey::Type(ty)];
2264 let variable_name = self.namer.call(&struct_name.to_lowercase());
2265 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2266 self.write_expr(module, expr, func_ctx)?;
2267 writeln!(self.out, ";")?;
2268
2269 let ep_output = match func_ctx.ty {
2271 back::FunctionType::Function(_) => None,
2272 back::FunctionType::EntryPoint(index) => self
2273 .entry_point_io
2274 .get(&(index as usize))
2275 .unwrap()
2276 .output
2277 .as_ref(),
2278 };
2279 let final_name = match ep_output {
2280 Some(ep_output) => {
2281 let final_name = self.namer.call(&variable_name);
2282 write!(
2283 self.out,
2284 "{}const {} {} = {{ ",
2285 level, ep_output.ty_name, final_name,
2286 )?;
2287 for (index, m) in ep_output.members.iter().enumerate() {
2288 if index != 0 {
2289 write!(self.out, ", ")?;
2290 }
2291 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2292 write!(self.out, "{variable_name}.{member_name}")?;
2293 }
2294 writeln!(self.out, " }};")?;
2295 final_name
2296 }
2297 None => variable_name,
2298 };
2299 writeln!(self.out, "{level}return {final_name};")?;
2300 } else {
2301 write!(self.out, "{level}return ")?;
2302 self.write_expr(module, expr, func_ctx)?;
2303 writeln!(self.out, ";")?
2304 }
2305 }
2306 Statement::Store { pointer, value } => {
2307 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2308 if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2309 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2310 self.write_storage_store(
2311 module,
2312 var_handle,
2313 StoreValue::Expression(value),
2314 func_ctx,
2315 level,
2316 None,
2317 )?;
2318 } else {
2319 enum MatrixAccess {
2325 Direct {
2326 base: Handle<crate::Expression>,
2327 index: u32,
2328 },
2329 Struct {
2330 columns: crate::VectorSize,
2331 base: Handle<crate::Expression>,
2332 },
2333 }
2334
2335 let get_members = |expr: Handle<crate::Expression>| {
2336 let resolved = func_ctx.resolve_type(expr, &module.types);
2337 match *resolved {
2338 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2339 TypeInner::Struct { ref members, .. } => Some(members),
2340 _ => None,
2341 },
2342 _ => None,
2343 }
2344 };
2345
2346 write!(self.out, "{level}")?;
2347
2348 let matrix_access_on_lhs =
2349 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2350 |(matrix_expr, vector, scalar)| match (
2351 func_ctx.resolve_type(matrix_expr, &module.types),
2352 &func_ctx.expressions[matrix_expr],
2353 ) {
2354 (
2355 &TypeInner::Pointer { base: ty, .. },
2356 &crate::Expression::AccessIndex { base, index },
2357 ) if matches!(
2358 module.types[ty].inner,
2359 TypeInner::Matrix {
2360 rows: crate::VectorSize::Bi,
2361 ..
2362 }
2363 ) && get_members(base)
2364 .map(|members| members[index as usize].binding.is_none())
2365 == Some(true) =>
2366 {
2367 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2368 }
2369 _ => {
2370 if let Some(MatrixType {
2371 columns,
2372 rows: crate::VectorSize::Bi,
2373 width: 4,
2374 }) = get_inner_matrix_of_struct_array_member(
2375 module,
2376 matrix_expr,
2377 func_ctx,
2378 true,
2379 ) {
2380 Some((
2381 MatrixAccess::Struct {
2382 columns,
2383 base: matrix_expr,
2384 },
2385 vector,
2386 scalar,
2387 ))
2388 } else {
2389 None
2390 }
2391 }
2392 },
2393 );
2394
2395 match matrix_access_on_lhs {
2396 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2397 let base_ty_res = &func_ctx.info[base].ty;
2398 let resolved = base_ty_res.inner_with(&module.types);
2399 let ty = match *resolved {
2400 TypeInner::Pointer { base, .. } => base,
2401 _ => base_ty_res.handle().unwrap(),
2402 };
2403
2404 if let Some(Index::Static(vec_index)) = vector {
2405 self.write_expr(module, base, func_ctx)?;
2406 write!(
2407 self.out,
2408 ".{}_{}",
2409 &self.names[&NameKey::StructMember(ty, index)],
2410 vec_index
2411 )?;
2412
2413 if let Some(scalar_index) = scalar {
2414 write!(self.out, "[")?;
2415 self.write_index(module, scalar_index, func_ctx)?;
2416 write!(self.out, "]")?;
2417 }
2418
2419 write!(self.out, " = ")?;
2420 self.write_expr(module, value, func_ctx)?;
2421 writeln!(self.out, ";")?;
2422 } else {
2423 let access = WrappedStructMatrixAccess { ty, index };
2424 match (&vector, &scalar) {
2425 (&Some(_), &Some(_)) => {
2426 self.write_wrapped_struct_matrix_set_scalar_function_name(
2427 access,
2428 )?;
2429 }
2430 (&Some(_), &None) => {
2431 self.write_wrapped_struct_matrix_set_vec_function_name(
2432 access,
2433 )?;
2434 }
2435 (&None, _) => {
2436 self.write_wrapped_struct_matrix_set_function_name(access)?;
2437 }
2438 }
2439
2440 write!(self.out, "(")?;
2441 self.write_expr(module, base, func_ctx)?;
2442 write!(self.out, ", ")?;
2443 self.write_expr(module, value, func_ctx)?;
2444
2445 if let Some(Index::Expression(vec_index)) = vector {
2446 write!(self.out, ", ")?;
2447 self.write_expr(module, vec_index, func_ctx)?;
2448
2449 if let Some(scalar_index) = scalar {
2450 write!(self.out, ", ")?;
2451 self.write_index(module, scalar_index, func_ctx)?;
2452 }
2453 }
2454 writeln!(self.out, ");")?;
2455 }
2456 }
2457 Some((
2458 MatrixAccess::Struct { columns, base },
2459 Some(Index::Expression(vec_index)),
2460 scalar,
2461 )) => {
2462 if scalar.is_some() {
2466 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2467 } else {
2468 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2469 }
2470 write!(self.out, "(")?;
2471 self.write_expr(module, base, func_ctx)?;
2472 write!(self.out, ", ")?;
2473 self.write_expr(module, vec_index, func_ctx)?;
2474
2475 if let Some(scalar_index) = scalar {
2476 write!(self.out, ", ")?;
2477 self.write_index(module, scalar_index, func_ctx)?;
2478 }
2479
2480 write!(self.out, ", ")?;
2481 self.write_expr(module, value, func_ctx)?;
2482
2483 writeln!(self.out, ");")?;
2484 }
2485 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2486 | Some((MatrixAccess::Struct { .. }, None, _))
2487 | None => {
2488 self.write_expr(module, pointer, func_ctx)?;
2489 write!(self.out, " = ")?;
2490
2491 if let Some(MatrixType {
2496 columns,
2497 rows: crate::VectorSize::Bi,
2498 width: 4,
2499 }) = get_inner_matrix_of_struct_array_member(
2500 module, pointer, func_ctx, false,
2501 ) {
2502 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2503 if let TypeInner::Pointer { base, .. } = *resolved {
2504 resolved = &module.types[base].inner;
2505 }
2506
2507 write!(self.out, "(__mat{}x2", columns as u8)?;
2508 if let TypeInner::Array { base, size, .. } = *resolved {
2509 self.write_array_size(module, base, size)?;
2510 }
2511 write!(self.out, ")")?;
2512 }
2513
2514 self.write_expr(module, value, func_ctx)?;
2515 writeln!(self.out, ";")?
2516 }
2517 }
2518 }
2519 }
2520 Statement::Loop {
2521 ref body,
2522 ref continuing,
2523 break_if,
2524 } => {
2525 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2526 let gate_name = (!continuing.is_empty() || break_if.is_some())
2527 .then(|| self.namer.call("loop_init"));
2528
2529 if let Some((ref decl, _)) = force_loop_bound_statements {
2530 writeln!(self.out, "{decl}")?;
2531 }
2532 if let Some(ref gate_name) = gate_name {
2533 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2534 }
2535
2536 self.continue_ctx.enter_loop();
2537 writeln!(self.out, "{level}while(true) {{")?;
2538 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2539 writeln!(self.out, "{break_and_inc}")?;
2540 }
2541 let l2 = level.next();
2542 if let Some(gate_name) = gate_name {
2543 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2544 let l3 = l2.next();
2545 for sta in continuing.iter() {
2546 self.write_stmt(module, sta, func_ctx, l3)?;
2547 }
2548 if let Some(condition) = break_if {
2549 write!(self.out, "{l3}if (")?;
2550 self.write_expr(module, condition, func_ctx)?;
2551 writeln!(self.out, ") {{")?;
2552 writeln!(self.out, "{}break;", l3.next())?;
2553 writeln!(self.out, "{l3}}}")?;
2554 }
2555 writeln!(self.out, "{l2}}}")?;
2556 writeln!(self.out, "{l2}{gate_name} = false;")?;
2557 }
2558
2559 for sta in body.iter() {
2560 self.write_stmt(module, sta, func_ctx, l2)?;
2561 }
2562
2563 writeln!(self.out, "{level}}}")?;
2564 self.continue_ctx.exit_loop();
2565 }
2566 Statement::Break => writeln!(self.out, "{level}break;")?,
2567 Statement::Continue => {
2568 if let Some(variable) = self.continue_ctx.continue_encountered() {
2569 writeln!(self.out, "{level}{variable} = true;")?;
2570 writeln!(self.out, "{level}break;")?
2571 } else {
2572 writeln!(self.out, "{level}continue;")?
2573 }
2574 }
2575 Statement::ControlBarrier(barrier) => {
2576 self.write_control_barrier(barrier, level)?;
2577 }
2578 Statement::MemoryBarrier(barrier) => {
2579 self.write_memory_barrier(barrier, level)?;
2580 }
2581 Statement::ImageStore {
2582 image,
2583 coordinate,
2584 array_index,
2585 value,
2586 } => {
2587 write!(self.out, "{level}")?;
2588 self.write_expr(module, image, func_ctx)?;
2589
2590 write!(self.out, "[")?;
2591 if let Some(index) = array_index {
2592 write!(self.out, "int3(")?;
2594 self.write_expr(module, coordinate, func_ctx)?;
2595 write!(self.out, ", ")?;
2596 self.write_expr(module, index, func_ctx)?;
2597 write!(self.out, ")")?;
2598 } else {
2599 self.write_expr(module, coordinate, func_ctx)?;
2600 }
2601 write!(self.out, "]")?;
2602
2603 write!(self.out, " = ")?;
2604 self.write_expr(module, value, func_ctx)?;
2605 writeln!(self.out, ";")?;
2606 }
2607 Statement::Call {
2608 function,
2609 ref arguments,
2610 result,
2611 } => {
2612 write!(self.out, "{level}")?;
2613
2614 if let Some(expr) = result {
2615 write!(self.out, "const ")?;
2616 let name = Baked(expr).to_string();
2617 let expr_ty = &func_ctx.info[expr].ty;
2618 let ty_inner = match *expr_ty {
2619 proc::TypeResolution::Handle(handle) => {
2620 self.write_type(module, handle)?;
2621 &module.types[handle].inner
2622 }
2623 proc::TypeResolution::Value(ref value) => {
2624 self.write_value_type(module, value)?;
2625 value
2626 }
2627 };
2628 write!(self.out, " {name}")?;
2629 if let TypeInner::Array { base, size, .. } = *ty_inner {
2630 self.write_array_size(module, base, size)?;
2631 }
2632 write!(self.out, " = ")?;
2633 self.named_expressions.insert(expr, name);
2634 }
2635 let func_name = &self.names[&NameKey::Function(function)];
2636 write!(self.out, "{func_name}(")?;
2637 let mut any_args_written = false;
2638 let mut separator = || {
2639 if any_args_written {
2640 ", "
2641 } else {
2642 any_args_written = true;
2643 ""
2644 }
2645 };
2646 for argument in arguments {
2647 write!(self.out, "{}", separator())?;
2648 self.write_expr(module, *argument, func_ctx)?;
2649 }
2650 if let Some(&var) = self.function_task_payload_var.get(&function) {
2651 let name = &self.names[&NameKey::GlobalVariable(var)];
2652 write!(self.out, "{}{name}", separator())?;
2654 }
2655 writeln!(self.out, ");")?;
2656 }
2657 Statement::Atomic {
2658 pointer,
2659 ref fun,
2660 value,
2661 result,
2662 } => {
2663 write!(self.out, "{level}")?;
2664 let res_var_info = if let Some(res_handle) = result {
2665 let name = Baked(res_handle).to_string();
2666 match func_ctx.info[res_handle].ty {
2667 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2668 proc::TypeResolution::Value(ref value) => {
2669 self.write_value_type(module, value)?
2670 }
2671 };
2672 write!(self.out, " {name}; ")?;
2673 self.named_expressions.insert(res_handle, name.clone());
2674 Some((res_handle, name))
2675 } else {
2676 None
2677 };
2678 let pointer_space = func_ctx
2679 .resolve_type(pointer, &module.types)
2680 .pointer_space()
2681 .unwrap();
2682 let fun_str = fun.to_hlsl_suffix();
2683 let compare_expr = match *fun {
2684 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2685 _ => None,
2686 };
2687 match pointer_space {
2688 crate::AddressSpace::WorkGroup => {
2689 write!(self.out, "Interlocked{fun_str}(")?;
2690 self.write_expr(module, pointer, func_ctx)?;
2691 self.emit_hlsl_atomic_tail(
2692 module,
2693 func_ctx,
2694 fun,
2695 compare_expr,
2696 value,
2697 &res_var_info,
2698 )?;
2699 }
2700 crate::AddressSpace::Storage { .. } => {
2701 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2702 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2703 let width = match func_ctx.resolve_type(value, &module.types) {
2704 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2705 _ => "",
2706 };
2707 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2708 let chain = mem::take(&mut self.temp_access_chain);
2709 self.write_storage_address(module, &chain, func_ctx)?;
2710 self.temp_access_chain = chain;
2711 self.emit_hlsl_atomic_tail(
2712 module,
2713 func_ctx,
2714 fun,
2715 compare_expr,
2716 value,
2717 &res_var_info,
2718 )?;
2719 }
2720 ref other => {
2721 return Err(Error::Custom(format!(
2722 "invalid address space {other:?} for atomic statement"
2723 )))
2724 }
2725 }
2726 if let Some(cmp) = compare_expr {
2727 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2728 write!(
2729 self.out,
2730 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2731 )?;
2732 self.write_expr(module, cmp, func_ctx)?;
2733 writeln!(self.out, ");")?;
2734 }
2735 }
2736 }
2737 Statement::ImageAtomic {
2738 image,
2739 coordinate,
2740 array_index,
2741 fun,
2742 value,
2743 } => {
2744 write!(self.out, "{level}")?;
2745
2746 let fun_str = fun.to_hlsl_suffix();
2747 write!(self.out, "Interlocked{fun_str}(")?;
2748 self.write_expr(module, image, func_ctx)?;
2749 write!(self.out, "[")?;
2750 self.write_texture_coordinates(
2751 "int",
2752 coordinate,
2753 array_index,
2754 None,
2755 module,
2756 func_ctx,
2757 )?;
2758 write!(self.out, "],")?;
2759
2760 self.write_expr(module, value, func_ctx)?;
2761 writeln!(self.out, ");")?;
2762 }
2763 Statement::WorkGroupUniformLoad { pointer, result } => {
2764 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2765 write!(self.out, "{level}")?;
2766 let name = Baked(result).to_string();
2767 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2768
2769 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2770 }
2771 Statement::Switch {
2772 selector,
2773 ref cases,
2774 } => {
2775 self.write_switch(module, func_ctx, level, selector, cases)?;
2776 }
2777 Statement::RayQuery { query, ref fun } => {
2778 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2790 else {
2791 unreachable!()
2792 };
2793
2794 let tracker_expr_name = format!(
2795 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2796 self.names[&func_ctx.name_key(query_var)]
2797 );
2798
2799 match *fun {
2800 RayQueryFunction::Initialize {
2801 acceleration_structure,
2802 descriptor,
2803 } => {
2804 self.write_initialize_function(
2805 module,
2806 level,
2807 query,
2808 acceleration_structure,
2809 descriptor,
2810 &tracker_expr_name,
2811 func_ctx,
2812 )?;
2813 }
2814 RayQueryFunction::Proceed { result } => {
2815 self.write_proceed(
2816 module,
2817 level,
2818 query,
2819 result,
2820 &tracker_expr_name,
2821 func_ctx,
2822 )?;
2823 }
2824 RayQueryFunction::GenerateIntersection { hit_t } => {
2825 self.write_generate_intersection(
2826 module,
2827 level,
2828 query,
2829 hit_t,
2830 &tracker_expr_name,
2831 func_ctx,
2832 )?;
2833 }
2834 RayQueryFunction::ConfirmIntersection => {
2835 self.write_confirm_intersection(
2836 module,
2837 level,
2838 query,
2839 &tracker_expr_name,
2840 func_ctx,
2841 )?;
2842 }
2843 RayQueryFunction::Terminate => {
2844 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2845 }
2846 }
2847 }
2848 Statement::SubgroupBallot { result, predicate } => {
2849 write!(self.out, "{level}")?;
2850 let name = Baked(result).to_string();
2851 write!(self.out, "const uint4 {name} = ")?;
2852 self.named_expressions.insert(result, name);
2853
2854 write!(self.out, "WaveActiveBallot(")?;
2855 match predicate {
2856 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2857 None => write!(self.out, "true")?,
2858 }
2859 writeln!(self.out, ");")?;
2860 }
2861 Statement::SubgroupCollectiveOperation {
2862 op,
2863 collective_op,
2864 argument,
2865 result,
2866 } => {
2867 write!(self.out, "{level}")?;
2868 write!(self.out, "const ")?;
2869 let name = Baked(result).to_string();
2870 match func_ctx.info[result].ty {
2871 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2872 proc::TypeResolution::Value(ref value) => {
2873 self.write_value_type(module, value)?
2874 }
2875 };
2876 write!(self.out, " {name} = ")?;
2877 self.named_expressions.insert(result, name);
2878
2879 match (collective_op, op) {
2880 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2881 write!(self.out, "WaveActiveAllTrue(")?
2882 }
2883 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2884 write!(self.out, "WaveActiveAnyTrue(")?
2885 }
2886 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2887 write!(self.out, "WaveActiveSum(")?
2888 }
2889 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2890 write!(self.out, "WaveActiveProduct(")?
2891 }
2892 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2893 write!(self.out, "WaveActiveMax(")?
2894 }
2895 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2896 write!(self.out, "WaveActiveMin(")?
2897 }
2898 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2899 write!(self.out, "WaveActiveBitAnd(")?
2900 }
2901 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2902 write!(self.out, "WaveActiveBitOr(")?
2903 }
2904 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2905 write!(self.out, "WaveActiveBitXor(")?
2906 }
2907 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2908 write!(self.out, "WavePrefixSum(")?
2909 }
2910 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2911 write!(self.out, "WavePrefixProduct(")?
2912 }
2913 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2914 self.write_expr(module, argument, func_ctx)?;
2915 write!(self.out, " + WavePrefixSum(")?;
2916 }
2917 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2918 self.write_expr(module, argument, func_ctx)?;
2919 write!(self.out, " * WavePrefixProduct(")?;
2920 }
2921 _ => unimplemented!(),
2922 }
2923 self.write_expr(module, argument, func_ctx)?;
2924 writeln!(self.out, ");")?;
2925 }
2926 Statement::SubgroupGather {
2927 mode,
2928 argument,
2929 result,
2930 } => {
2931 write!(self.out, "{level}")?;
2932 write!(self.out, "const ")?;
2933 let name = Baked(result).to_string();
2934 match func_ctx.info[result].ty {
2935 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2936 proc::TypeResolution::Value(ref value) => {
2937 self.write_value_type(module, value)?
2938 }
2939 };
2940 write!(self.out, " {name} = ")?;
2941 self.named_expressions.insert(result, name);
2942 match mode {
2943 crate::GatherMode::BroadcastFirst => {
2944 write!(self.out, "WaveReadLaneFirst(")?;
2945 self.write_expr(module, argument, func_ctx)?;
2946 }
2947 crate::GatherMode::QuadBroadcast(index) => {
2948 write!(self.out, "QuadReadLaneAt(")?;
2949 self.write_expr(module, argument, func_ctx)?;
2950 write!(self.out, ", ")?;
2951 self.write_expr(module, index, func_ctx)?;
2952 }
2953 crate::GatherMode::QuadSwap(direction) => {
2954 match direction {
2955 crate::Direction::X => {
2956 write!(self.out, "QuadReadAcrossX(")?;
2957 }
2958 crate::Direction::Y => {
2959 write!(self.out, "QuadReadAcrossY(")?;
2960 }
2961 crate::Direction::Diagonal => {
2962 write!(self.out, "QuadReadAcrossDiagonal(")?;
2963 }
2964 }
2965 self.write_expr(module, argument, func_ctx)?;
2966 }
2967 _ => {
2968 write!(self.out, "WaveReadLaneAt(")?;
2969 self.write_expr(module, argument, func_ctx)?;
2970 write!(self.out, ", ")?;
2971 match mode {
2972 crate::GatherMode::BroadcastFirst => unreachable!(),
2973 crate::GatherMode::Broadcast(index)
2974 | crate::GatherMode::Shuffle(index) => {
2975 self.write_expr(module, index, func_ctx)?;
2976 }
2977 crate::GatherMode::ShuffleDown(index) => {
2978 write!(self.out, "WaveGetLaneIndex() + ")?;
2979 self.write_expr(module, index, func_ctx)?;
2980 }
2981 crate::GatherMode::ShuffleUp(index) => {
2982 write!(self.out, "WaveGetLaneIndex() - ")?;
2983 self.write_expr(module, index, func_ctx)?;
2984 }
2985 crate::GatherMode::ShuffleXor(index) => {
2986 write!(self.out, "WaveGetLaneIndex() ^ ")?;
2987 self.write_expr(module, index, func_ctx)?;
2988 }
2989 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2990 crate::GatherMode::QuadSwap(_) => unreachable!(),
2991 }
2992 }
2993 }
2994 writeln!(self.out, ");")?;
2995 }
2996 Statement::CooperativeStore { .. } => unimplemented!(),
2997 Statement::RayPipelineFunction(_) => unreachable!(),
2998 }
2999
3000 Ok(())
3001 }
3002
3003 fn write_const_expression(
3004 &mut self,
3005 module: &Module,
3006 expr: Handle<crate::Expression>,
3007 arena: &crate::Arena<crate::Expression>,
3008 ) -> BackendResult {
3009 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3010 writer.write_const_expression(module, expr, arena)
3011 })
3012 }
3013
3014 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3015 match literal {
3016 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3017 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3018 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3019 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3020 crate::Literal::I32(value) if value == i32::MIN => {
3026 write!(self.out, "int({} - 1)", value + 1)?
3027 }
3028 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3032 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3033 crate::Literal::I64(value) if value == i64::MIN => {
3035 write!(self.out, "({}L - 1L)", value + 1)?;
3036 }
3037 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3038 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3039 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3040 return Err(Error::Custom(
3041 "Abstract types should not appear in IR presented to backends".into(),
3042 ));
3043 }
3044 }
3045 Ok(())
3046 }
3047
3048 fn write_possibly_const_expression<E>(
3049 &mut self,
3050 module: &Module,
3051 expr: Handle<crate::Expression>,
3052 expressions: &crate::Arena<crate::Expression>,
3053 write_expression: E,
3054 ) -> BackendResult
3055 where
3056 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3057 {
3058 use crate::Expression;
3059
3060 match expressions[expr] {
3061 Expression::Literal(literal) => {
3062 self.write_literal(literal)?;
3063 }
3064 Expression::Constant(handle) => {
3065 let constant = &module.constants[handle];
3066 if constant.name.is_some() {
3067 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3068 } else {
3069 self.write_const_expression(module, constant.init, &module.global_expressions)?;
3070 }
3071 }
3072 Expression::ZeroValue(ty) => {
3073 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3074 write!(self.out, "()")?;
3075 }
3076 Expression::Compose { ty, ref components } => {
3077 match module.types[ty].inner {
3078 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3079 self.write_wrapped_constructor_function_name(
3080 module,
3081 WrappedConstructor { ty },
3082 )?;
3083 }
3084 _ => {
3085 self.write_type(module, ty)?;
3086 }
3087 };
3088 write!(self.out, "(")?;
3089 for (index, component) in components.iter().enumerate() {
3090 if index != 0 {
3091 write!(self.out, ", ")?;
3092 }
3093 write_expression(self, *component)?;
3094 }
3095 write!(self.out, ")")?;
3096 }
3097 Expression::Splat { size, value } => {
3098 let number_of_components = match size {
3102 crate::VectorSize::Bi => "xx",
3103 crate::VectorSize::Tri => "xxx",
3104 crate::VectorSize::Quad => "xxxx",
3105 };
3106 write!(self.out, "(")?;
3107 write_expression(self, value)?;
3108 write!(self.out, ").{number_of_components}")?
3109 }
3110 _ => {
3111 return Err(Error::Override);
3112 }
3113 }
3114
3115 Ok(())
3116 }
3117
3118 pub(super) fn write_expr(
3123 &mut self,
3124 module: &Module,
3125 expr: Handle<crate::Expression>,
3126 func_ctx: &back::FunctionCtx<'_>,
3127 ) -> BackendResult {
3128 use crate::Expression;
3129
3130 let ff_input = if self.options.special_constants_binding.is_some() {
3132 func_ctx.is_fixed_function_input(expr, module)
3133 } else {
3134 None
3135 };
3136 let closing_bracket = match ff_input {
3137 Some(crate::BuiltIn::VertexIndex) => {
3138 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3139 ")"
3140 }
3141 Some(crate::BuiltIn::InstanceIndex) => {
3142 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3143 ")"
3144 }
3145 Some(crate::BuiltIn::NumWorkGroups) => {
3146 write!(
3150 self.out,
3151 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3152 )?;
3153 return Ok(());
3154 }
3155 _ => "",
3156 };
3157
3158 if let Some(name) = self.named_expressions.get(&expr) {
3159 write!(self.out, "{name}{closing_bracket}")?;
3160 return Ok(());
3161 }
3162
3163 let expression = &func_ctx.expressions[expr];
3164
3165 match *expression {
3166 Expression::Literal(_)
3167 | Expression::Constant(_)
3168 | Expression::ZeroValue(_)
3169 | Expression::Compose { .. }
3170 | Expression::Splat { .. } => {
3171 self.write_possibly_const_expression(
3172 module,
3173 expr,
3174 func_ctx.expressions,
3175 |writer, expr| writer.write_expr(module, expr, func_ctx),
3176 )?;
3177 }
3178 Expression::Override(_) => return Err(Error::Override),
3179 Expression::Binary {
3186 op:
3187 op @ crate::BinaryOperator::Add
3188 | op @ crate::BinaryOperator::Subtract
3189 | op @ crate::BinaryOperator::Multiply,
3190 left,
3191 right,
3192 } if matches!(
3193 func_ctx.resolve_type(expr, &module.types).scalar(),
3194 Some(Scalar::I32)
3195 ) =>
3196 {
3197 write!(self.out, "asint(asuint(",)?;
3198 self.write_expr(module, left, func_ctx)?;
3199 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3200 self.write_expr(module, right, func_ctx)?;
3201 write!(self.out, "))")?;
3202 }
3203 Expression::Binary {
3206 op: crate::BinaryOperator::Multiply,
3207 left,
3208 right,
3209 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3210 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3211 {
3212 write!(self.out, "mul(")?;
3214 self.write_expr(module, right, func_ctx)?;
3215 write!(self.out, ", ")?;
3216 self.write_expr(module, left, func_ctx)?;
3217 write!(self.out, ")")?;
3218 }
3219
3220 Expression::Binary {
3232 op: crate::BinaryOperator::Divide,
3233 left,
3234 right,
3235 } if matches!(
3236 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3237 Some(ScalarKind::Sint | ScalarKind::Uint)
3238 ) =>
3239 {
3240 write!(self.out, "{DIV_FUNCTION}(")?;
3241 self.write_expr(module, left, func_ctx)?;
3242 write!(self.out, ", ")?;
3243 self.write_expr(module, right, func_ctx)?;
3244 write!(self.out, ")")?;
3245 }
3246
3247 Expression::Binary {
3248 op: crate::BinaryOperator::Modulo,
3249 left,
3250 right,
3251 } if matches!(
3252 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3253 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3254 ) =>
3255 {
3256 write!(self.out, "{MOD_FUNCTION}(")?;
3257 self.write_expr(module, left, func_ctx)?;
3258 write!(self.out, ", ")?;
3259 self.write_expr(module, right, func_ctx)?;
3260 write!(self.out, ")")?;
3261 }
3262
3263 Expression::Binary { op, left, right } => {
3264 write!(self.out, "(")?;
3265 self.write_expr(module, left, func_ctx)?;
3266 write!(self.out, " {} ", back::binary_operation_str(op))?;
3267 self.write_expr(module, right, func_ctx)?;
3268 write!(self.out, ")")?;
3269 }
3270 Expression::Access { base, index } => {
3271 if let Some(crate::AddressSpace::Storage { .. }) =
3272 func_ctx.resolve_type(expr, &module.types).pointer_space()
3273 {
3274 } else {
3276 if let Some(MatrixType {
3283 columns,
3284 rows: crate::VectorSize::Bi,
3285 width: 4,
3286 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3287 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3288 {
3289 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3290 self.write_expr(module, base, func_ctx)?;
3291 write!(self.out, ", ")?;
3292 self.write_expr(module, index, func_ctx)?;
3293 write!(self.out, ")")?;
3294 return Ok(());
3295 }
3296
3297 let resolved = func_ctx.resolve_type(base, &module.types);
3298
3299 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3300 TypeInner::BindingArray { .. } => {
3301 let uniformity = &func_ctx.info[index].uniformity;
3302
3303 (true, uniformity.non_uniform_result.is_some())
3304 }
3305 _ => (false, false),
3306 };
3307
3308 self.write_expr(module, base, func_ctx)?;
3309
3310 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3311 module, func_ctx, base, resolved,
3312 );
3313
3314 if let Some(ref info) = array_sampler_info {
3315 write!(self.out, "{}[", info.sampler_heap_name)?;
3316 } else {
3317 write!(self.out, "[")?;
3318 }
3319
3320 let needs_bound_check = self.options.restrict_indexing
3321 && !indexing_binding_array
3322 && match resolved.pointer_space() {
3323 Some(
3324 crate::AddressSpace::Function
3325 | crate::AddressSpace::Private
3326 | crate::AddressSpace::WorkGroup
3327 | crate::AddressSpace::Immediate
3328 | crate::AddressSpace::TaskPayload
3329 | crate::AddressSpace::RayPayload
3330 | crate::AddressSpace::IncomingRayPayload,
3331 )
3332 | None => true,
3333 Some(crate::AddressSpace::Uniform) => {
3334 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3336 let bind_target = self
3337 .options
3338 .resolve_resource_binding(
3339 module.global_variables[var_handle]
3340 .binding
3341 .as_ref()
3342 .unwrap(),
3343 )
3344 .unwrap();
3345 bind_target.restrict_indexing
3346 }
3347 Some(
3348 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3349 ) => unreachable!(),
3350 };
3351 let restriction_needed = if needs_bound_check {
3353 index::access_needs_check(
3354 base,
3355 index::GuardedIndex::Expression(index),
3356 module,
3357 func_ctx.expressions,
3358 func_ctx.info,
3359 )
3360 } else {
3361 None
3362 };
3363 if let Some(limit) = restriction_needed {
3364 write!(self.out, "min(uint(")?;
3365 self.write_expr(module, index, func_ctx)?;
3366 write!(self.out, "), ")?;
3367 match limit {
3368 index::IndexableLength::Known(limit) => {
3369 write!(self.out, "{}u", limit - 1)?;
3370 }
3371 index::IndexableLength::Dynamic => unreachable!(),
3372 }
3373 write!(self.out, ")")?;
3374 } else {
3375 if non_uniform_qualifier {
3376 write!(self.out, "NonUniformResourceIndex(")?;
3377 }
3378 if let Some(ref info) = array_sampler_info {
3379 write!(
3380 self.out,
3381 "{}[{} + ",
3382 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3383 )?;
3384 }
3385 self.write_expr(module, index, func_ctx)?;
3386 if array_sampler_info.is_some() {
3387 write!(self.out, "]")?;
3388 }
3389 if non_uniform_qualifier {
3390 write!(self.out, ")")?;
3391 }
3392 }
3393
3394 write!(self.out, "]")?;
3395 }
3396 }
3397 Expression::AccessIndex { base, index } => {
3398 if let Some(crate::AddressSpace::Storage { .. }) =
3399 func_ctx.resolve_type(expr, &module.types).pointer_space()
3400 {
3401 } else {
3403 if let Some(MatrixType {
3407 rows: crate::VectorSize::Bi,
3408 width: 4,
3409 ..
3410 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3411 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3412 {
3413 self.write_expr(module, base, func_ctx)?;
3414 write!(self.out, "._{index}")?;
3415 return Ok(());
3416 }
3417
3418 let base_ty_res = &func_ctx.info[base].ty;
3419 let mut resolved = base_ty_res.inner_with(&module.types);
3420 let base_ty_handle = match *resolved {
3421 TypeInner::Pointer { base, .. } => {
3422 resolved = &module.types[base].inner;
3423 Some(base)
3424 }
3425 _ => base_ty_res.handle(),
3426 };
3427
3428 if let TypeInner::Struct { ref members, .. } = *resolved {
3434 let member = &members[index as usize];
3435
3436 match module.types[member.ty].inner {
3437 TypeInner::Matrix {
3438 rows: crate::VectorSize::Bi,
3439 ..
3440 } if member.binding.is_none() => {
3441 let ty = base_ty_handle.unwrap();
3442 self.write_wrapped_struct_matrix_get_function_name(
3443 WrappedStructMatrixAccess { ty, index },
3444 )?;
3445 write!(self.out, "(")?;
3446 self.write_expr(module, base, func_ctx)?;
3447 write!(self.out, ")")?;
3448 return Ok(());
3449 }
3450 _ => {}
3451 }
3452 }
3453
3454 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3455 module, func_ctx, base, resolved,
3456 );
3457
3458 if let Some(ref info) = array_sampler_info {
3459 write!(
3460 self.out,
3461 "{}[{}",
3462 info.sampler_heap_name, info.sampler_index_buffer_name
3463 )?;
3464 }
3465
3466 self.write_expr(module, base, func_ctx)?;
3467
3468 match *resolved {
3469 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3475 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3477 }
3478 TypeInner::Matrix { .. }
3479 | TypeInner::Array { .. }
3480 | TypeInner::BindingArray { .. } => {
3481 if let Some(ref info) = array_sampler_info {
3482 write!(
3483 self.out,
3484 "[{} + {index}]",
3485 info.binding_array_base_index_name
3486 )?;
3487 } else {
3488 write!(self.out, "[{index}]")?;
3489 }
3490 }
3491 TypeInner::Struct { .. } => {
3492 let ty = base_ty_handle.unwrap();
3495
3496 write!(
3497 self.out,
3498 ".{}",
3499 &self.names[&NameKey::StructMember(ty, index)]
3500 )?
3501 }
3502 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3503 }
3504
3505 if array_sampler_info.is_some() {
3506 write!(self.out, "]")?;
3507 }
3508 }
3509 }
3510 Expression::FunctionArgument(pos) => {
3511 let ty = func_ctx.resolve_type(expr, &module.types);
3512
3513 if let TypeInner::Image {
3519 class: crate::ImageClass::External,
3520 ..
3521 } = *ty
3522 {
3523 let plane_names = [0, 1, 2].map(|i| {
3524 &self.names[&func_ctx
3525 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3526 });
3527 let params_name = &self.names[&func_ctx
3528 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3529 write!(
3530 self.out,
3531 "{}, {}, {}, {}",
3532 plane_names[0], plane_names[1], plane_names[2], params_name
3533 )?;
3534 } else {
3535 let key = func_ctx.argument_key(pos);
3536 let name = &self.names[&key];
3537 write!(self.out, "{name}")?;
3538 }
3539 }
3540 Expression::ImageSample {
3541 coordinate,
3542 image,
3543 sampler,
3544 clamp_to_edge: true,
3545 gather: None,
3546 array_index: None,
3547 offset: None,
3548 level: crate::SampleLevel::Zero,
3549 depth_ref: None,
3550 } => {
3551 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3552 self.write_expr(module, image, func_ctx)?;
3553 write!(self.out, ", ")?;
3554 self.write_expr(module, sampler, func_ctx)?;
3555 write!(self.out, ", ")?;
3556 self.write_expr(module, coordinate, func_ctx)?;
3557 write!(self.out, ")")?;
3558 }
3559 Expression::ImageSample {
3560 image,
3561 sampler,
3562 gather,
3563 coordinate,
3564 array_index,
3565 offset,
3566 level,
3567 depth_ref,
3568 clamp_to_edge,
3569 } => {
3570 if clamp_to_edge {
3571 return Err(Error::Custom(
3572 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3573 ));
3574 }
3575
3576 use crate::SampleLevel as Sl;
3577 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3578
3579 let (base_str, component_str) = match gather {
3580 Some(component) => ("Gather", COMPONENTS[component as usize]),
3581 None => ("Sample", ""),
3582 };
3583 let cmp_str = match depth_ref {
3584 Some(_) => "Cmp",
3585 None => "",
3586 };
3587 let level_str = match level {
3588 Sl::Zero if gather.is_none() => "LevelZero",
3589 Sl::Auto | Sl::Zero => "",
3590 Sl::Exact(_) => "Level",
3591 Sl::Bias(_) => "Bias",
3592 Sl::Gradient { .. } => "Grad",
3593 };
3594
3595 self.write_expr(module, image, func_ctx)?;
3596 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3597 self.write_expr(module, sampler, func_ctx)?;
3598 write!(self.out, ", ")?;
3599 self.write_texture_coordinates(
3600 "float",
3601 coordinate,
3602 array_index,
3603 None,
3604 module,
3605 func_ctx,
3606 )?;
3607
3608 if let Some(depth_ref) = depth_ref {
3609 write!(self.out, ", ")?;
3610 self.write_expr(module, depth_ref, func_ctx)?;
3611 }
3612
3613 match level {
3614 Sl::Auto | Sl::Zero => {}
3615 Sl::Exact(expr) => {
3616 write!(self.out, ", ")?;
3617 self.write_expr(module, expr, func_ctx)?;
3618 }
3619 Sl::Bias(expr) => {
3620 write!(self.out, ", ")?;
3621 self.write_expr(module, expr, func_ctx)?;
3622 }
3623 Sl::Gradient { x, y } => {
3624 write!(self.out, ", ")?;
3625 self.write_expr(module, x, func_ctx)?;
3626 write!(self.out, ", ")?;
3627 self.write_expr(module, y, func_ctx)?;
3628 }
3629 }
3630
3631 if let Some(offset) = offset {
3632 write!(self.out, ", ")?;
3633 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3635 write!(self.out, ")")?;
3636 }
3637
3638 write!(self.out, ")")?;
3639 }
3640 Expression::ImageQuery { image, query } => {
3641 if let TypeInner::Image {
3643 dim,
3644 arrayed,
3645 class,
3646 } = *func_ctx.resolve_type(image, &module.types)
3647 {
3648 let wrapped_image_query = WrappedImageQuery {
3649 dim,
3650 arrayed,
3651 class,
3652 query: query.into(),
3653 };
3654
3655 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3656 write!(self.out, "(")?;
3657 self.write_expr(module, image, func_ctx)?;
3659 if let crate::ImageQuery::Size { level: Some(level) } = query {
3660 write!(self.out, ", ")?;
3661 self.write_expr(module, level, func_ctx)?;
3662 }
3663 write!(self.out, ")")?;
3664 }
3665 }
3666 Expression::ImageLoad {
3667 image,
3668 coordinate,
3669 array_index,
3670 sample,
3671 level,
3672 } => self.write_image_load(
3673 &module,
3674 expr,
3675 func_ctx,
3676 image,
3677 coordinate,
3678 array_index,
3679 sample,
3680 level,
3681 )?,
3682 Expression::GlobalVariable(handle) => {
3683 let global_variable = &module.global_variables[handle];
3684 let ty = &module.types[global_variable.ty].inner;
3685
3686 let is_binding_array_of_samplers = match *ty {
3691 TypeInner::BindingArray { base, .. } => {
3692 let base_ty = &module.types[base].inner;
3693 matches!(*base_ty, TypeInner::Sampler { .. })
3694 }
3695 _ => false,
3696 };
3697
3698 let is_storage_space =
3699 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3700
3701 if let TypeInner::Image {
3709 class: crate::ImageClass::External,
3710 ..
3711 } = *ty
3712 {
3713 let plane_names = [0, 1, 2].map(|i| {
3714 &self.names[&NameKey::ExternalTextureGlobalVariable(
3715 handle,
3716 ExternalTextureNameKey::Plane(i),
3717 )]
3718 });
3719 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3720 handle,
3721 ExternalTextureNameKey::Params,
3722 )];
3723 write!(
3724 self.out,
3725 "{}, {}, {}, {}",
3726 plane_names[0], plane_names[1], plane_names[2], params_name
3727 )?;
3728 } else if !is_binding_array_of_samplers && !is_storage_space {
3729 let name = &self.names[&NameKey::GlobalVariable(handle)];
3730 write!(self.out, "{name}")?;
3731 }
3732 }
3733 Expression::LocalVariable(handle) => {
3734 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3735 }
3736 Expression::Load { pointer } => {
3737 match func_ctx
3738 .resolve_type(pointer, &module.types)
3739 .pointer_space()
3740 {
3741 Some(crate::AddressSpace::Storage { .. }) => {
3742 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3743 let result_ty = func_ctx.info[expr].ty.clone();
3744 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3745 }
3746 _ => {
3747 let mut close_paren = false;
3748
3749 if let Some(MatrixType {
3754 rows: crate::VectorSize::Bi,
3755 width: 4,
3756 ..
3757 }) = get_inner_matrix_of_struct_array_member(
3758 module, pointer, func_ctx, false,
3759 )
3760 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3761 {
3762 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3763 let ptr_tr = resolved.pointer_base_type();
3764 if let Some(ptr_ty) =
3765 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3766 {
3767 resolved = ptr_ty;
3768 }
3769
3770 write!(self.out, "((")?;
3771 if let TypeInner::Array { base, size, .. } = *resolved {
3772 self.write_type(module, base)?;
3773 self.write_array_size(module, base, size)?;
3774 } else {
3775 self.write_value_type(module, resolved)?;
3776 }
3777 write!(self.out, ")")?;
3778 close_paren = true;
3779 }
3780
3781 self.write_expr(module, pointer, func_ctx)?;
3782
3783 if close_paren {
3784 write!(self.out, ")")?;
3785 }
3786 }
3787 }
3788 }
3789 Expression::Unary { op, expr } => {
3790 let op_str = match op {
3792 crate::UnaryOperator::Negate => {
3793 match func_ctx.resolve_type(expr, &module.types).scalar() {
3794 Some(Scalar::I32) => NEG_FUNCTION,
3795 _ => "-",
3796 }
3797 }
3798 crate::UnaryOperator::LogicalNot => "!",
3799 crate::UnaryOperator::BitwiseNot => "~",
3800 };
3801 write!(self.out, "{op_str}(")?;
3802 self.write_expr(module, expr, func_ctx)?;
3803 write!(self.out, ")")?;
3804 }
3805 Expression::As {
3806 expr,
3807 kind,
3808 convert,
3809 } => {
3810 let inner = func_ctx.resolve_type(expr, &module.types);
3811 if inner.scalar_kind() == Some(ScalarKind::Float)
3812 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3813 && convert.is_some()
3814 {
3815 let fun_name = match (kind, convert) {
3819 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3820 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3821 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3822 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3823 _ => unreachable!(),
3824 };
3825 write!(self.out, "{fun_name}(")?;
3826 self.write_expr(module, expr, func_ctx)?;
3827 write!(self.out, ")")?;
3828 } else {
3829 let close_paren = match convert {
3830 Some(dst_width) => {
3831 let scalar = Scalar {
3832 kind,
3833 width: dst_width,
3834 };
3835 match *inner {
3836 TypeInner::Vector { size, .. } => {
3837 write!(
3838 self.out,
3839 "{}{}(",
3840 scalar.to_hlsl_str()?,
3841 common::vector_size_str(size)
3842 )?;
3843 }
3844 TypeInner::Scalar(_) => {
3845 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3846 }
3847 TypeInner::Matrix { columns, rows, .. } => {
3848 write!(
3849 self.out,
3850 "{}{}x{}(",
3851 scalar.to_hlsl_str()?,
3852 common::vector_size_str(columns),
3853 common::vector_size_str(rows)
3854 )?;
3855 }
3856 _ => {
3857 return Err(Error::Unimplemented(format!(
3858 "write_expr expression::as {inner:?}"
3859 )));
3860 }
3861 };
3862 true
3863 }
3864 None => {
3865 if inner.scalar_width() == Some(8) {
3866 false
3867 } else {
3868 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3869 true
3870 }
3871 }
3872 };
3873 self.write_expr(module, expr, func_ctx)?;
3874 if close_paren {
3875 write!(self.out, ")")?;
3876 }
3877 }
3878 }
3879 Expression::Math {
3880 fun,
3881 arg,
3882 arg1,
3883 arg2,
3884 arg3,
3885 } => {
3886 use crate::MathFunction as Mf;
3887
3888 enum Function {
3889 Asincosh { is_sin: bool },
3890 Atanh,
3891 Pack2x16float,
3892 Pack2x16snorm,
3893 Pack2x16unorm,
3894 Pack4x8snorm,
3895 Pack4x8unorm,
3896 Pack4xI8,
3897 Pack4xU8,
3898 Pack4xI8Clamp,
3899 Pack4xU8Clamp,
3900 Unpack2x16float,
3901 Unpack2x16snorm,
3902 Unpack2x16unorm,
3903 Unpack4x8snorm,
3904 Unpack4x8unorm,
3905 Unpack4xI8,
3906 Unpack4xU8,
3907 Dot4I8Packed,
3908 Dot4U8Packed,
3909 QuantizeToF16,
3910 Regular(&'static str),
3911 MissingIntOverload(&'static str),
3912 MissingIntReturnType(&'static str),
3913 CountTrailingZeros,
3914 CountLeadingZeros,
3915 }
3916
3917 let fun = match fun {
3918 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3920 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3921 _ => Function::Regular("abs"),
3922 },
3923 Mf::Min => Function::Regular("min"),
3924 Mf::Max => Function::Regular("max"),
3925 Mf::Clamp => Function::Regular("clamp"),
3926 Mf::Saturate => Function::Regular("saturate"),
3927 Mf::Cos => Function::Regular("cos"),
3929 Mf::Cosh => Function::Regular("cosh"),
3930 Mf::Sin => Function::Regular("sin"),
3931 Mf::Sinh => Function::Regular("sinh"),
3932 Mf::Tan => Function::Regular("tan"),
3933 Mf::Tanh => Function::Regular("tanh"),
3934 Mf::Acos => Function::Regular("acos"),
3935 Mf::Asin => Function::Regular("asin"),
3936 Mf::Atan => Function::Regular("atan"),
3937 Mf::Atan2 => Function::Regular("atan2"),
3938 Mf::Asinh => Function::Asincosh { is_sin: true },
3939 Mf::Acosh => Function::Asincosh { is_sin: false },
3940 Mf::Atanh => Function::Atanh,
3941 Mf::Radians => Function::Regular("radians"),
3942 Mf::Degrees => Function::Regular("degrees"),
3943 Mf::Ceil => Function::Regular("ceil"),
3945 Mf::Floor => Function::Regular("floor"),
3946 Mf::Round => Function::Regular("round"),
3947 Mf::Fract => Function::Regular("frac"),
3948 Mf::Trunc => Function::Regular("trunc"),
3949 Mf::Modf => Function::Regular(MODF_FUNCTION),
3950 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3951 Mf::Ldexp => Function::Regular("ldexp"),
3952 Mf::Exp => Function::Regular("exp"),
3954 Mf::Exp2 => Function::Regular("exp2"),
3955 Mf::Log => Function::Regular("log"),
3956 Mf::Log2 => Function::Regular("log2"),
3957 Mf::Pow => Function::Regular("pow"),
3958 Mf::Dot => Function::Regular("dot"),
3960 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3961 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3962 Mf::Cross => Function::Regular("cross"),
3964 Mf::Distance => Function::Regular("distance"),
3965 Mf::Length => Function::Regular("length"),
3966 Mf::Normalize => Function::Regular("normalize"),
3967 Mf::FaceForward => Function::Regular("faceforward"),
3968 Mf::Reflect => Function::Regular("reflect"),
3969 Mf::Refract => Function::Regular("refract"),
3970 Mf::Sign => Function::Regular("sign"),
3972 Mf::Fma => Function::Regular("mad"),
3973 Mf::Mix => Function::Regular("lerp"),
3974 Mf::Step => Function::Regular("step"),
3975 Mf::SmoothStep => Function::Regular("smoothstep"),
3976 Mf::Sqrt => Function::Regular("sqrt"),
3977 Mf::InverseSqrt => Function::Regular("rsqrt"),
3978 Mf::Transpose => Function::Regular("transpose"),
3980 Mf::Determinant => Function::Regular("determinant"),
3981 Mf::QuantizeToF16 => Function::QuantizeToF16,
3982 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3984 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3985 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3986 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3987 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3988 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3989 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3990 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3991 Mf::Pack2x16float => Function::Pack2x16float,
3993 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3994 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3995 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3996 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3997 Mf::Pack4xI8 => Function::Pack4xI8,
3998 Mf::Pack4xU8 => Function::Pack4xU8,
3999 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4000 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4001 Mf::Unpack2x16float => Function::Unpack2x16float,
4003 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4004 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4005 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4006 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4007 Mf::Unpack4xI8 => Function::Unpack4xI8,
4008 Mf::Unpack4xU8 => Function::Unpack4xU8,
4009 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4010 };
4011
4012 match fun {
4013 Function::Asincosh { is_sin } => {
4014 write!(self.out, "log(")?;
4015 self.write_expr(module, arg, func_ctx)?;
4016 write!(self.out, " + sqrt(")?;
4017 self.write_expr(module, arg, func_ctx)?;
4018 write!(self.out, " * ")?;
4019 self.write_expr(module, arg, func_ctx)?;
4020 match is_sin {
4021 true => write!(self.out, " + 1.0))")?,
4022 false => write!(self.out, " - 1.0))")?,
4023 }
4024 }
4025 Function::Atanh => {
4026 write!(self.out, "0.5 * log((1.0 + ")?;
4027 self.write_expr(module, arg, func_ctx)?;
4028 write!(self.out, ") / (1.0 - ")?;
4029 self.write_expr(module, arg, func_ctx)?;
4030 write!(self.out, "))")?;
4031 }
4032 Function::Pack2x16float => {
4033 write!(self.out, "(f32tof16(")?;
4034 self.write_expr(module, arg, func_ctx)?;
4035 write!(self.out, "[0]) | f32tof16(")?;
4036 self.write_expr(module, arg, func_ctx)?;
4037 write!(self.out, "[1]) << 16)")?;
4038 }
4039 Function::Pack2x16snorm => {
4040 let scale = 32767;
4041
4042 write!(self.out, "uint((int(round(clamp(")?;
4043 self.write_expr(module, arg, func_ctx)?;
4044 write!(
4045 self.out,
4046 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4047 )?;
4048 self.write_expr(module, arg, func_ctx)?;
4049 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4050 }
4051 Function::Pack2x16unorm => {
4052 let scale = 65535;
4053
4054 write!(self.out, "(uint(round(clamp(")?;
4055 self.write_expr(module, arg, func_ctx)?;
4056 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4057 self.write_expr(module, arg, func_ctx)?;
4058 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4059 }
4060 Function::Pack4x8snorm => {
4061 let scale = 127;
4062
4063 write!(self.out, "uint((int(round(clamp(")?;
4064 self.write_expr(module, arg, func_ctx)?;
4065 write!(
4066 self.out,
4067 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4068 )?;
4069 self.write_expr(module, arg, func_ctx)?;
4070 write!(
4071 self.out,
4072 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4073 )?;
4074 self.write_expr(module, arg, func_ctx)?;
4075 write!(
4076 self.out,
4077 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4078 )?;
4079 self.write_expr(module, arg, func_ctx)?;
4080 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4081 }
4082 Function::Pack4x8unorm => {
4083 let scale = 255;
4084
4085 write!(self.out, "(uint(round(clamp(")?;
4086 self.write_expr(module, arg, func_ctx)?;
4087 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4088 self.write_expr(module, arg, func_ctx)?;
4089 write!(
4090 self.out,
4091 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4092 )?;
4093 self.write_expr(module, arg, func_ctx)?;
4094 write!(
4095 self.out,
4096 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4097 )?;
4098 self.write_expr(module, arg, func_ctx)?;
4099 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4100 }
4101 fun @ (Function::Pack4xI8
4102 | Function::Pack4xU8
4103 | Function::Pack4xI8Clamp
4104 | Function::Pack4xU8Clamp) => {
4105 let was_signed =
4106 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4107 let clamp_bounds = match fun {
4108 Function::Pack4xI8Clamp => Some(("-128", "127")),
4109 Function::Pack4xU8Clamp => Some(("0", "255")),
4110 _ => None,
4111 };
4112 if was_signed {
4113 write!(self.out, "uint(")?;
4114 }
4115 let write_arg = |this: &mut Self| -> BackendResult {
4116 if let Some((min, max)) = clamp_bounds {
4117 write!(this.out, "clamp(")?;
4118 this.write_expr(module, arg, func_ctx)?;
4119 write!(this.out, ", {min}, {max})")?;
4120 } else {
4121 this.write_expr(module, arg, func_ctx)?;
4122 }
4123 Ok(())
4124 };
4125 write!(self.out, "(")?;
4126 write_arg(self)?;
4127 write!(self.out, "[0] & 0xFF) | ((")?;
4128 write_arg(self)?;
4129 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4130 write_arg(self)?;
4131 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4132 write_arg(self)?;
4133 write!(self.out, "[3] & 0xFF) << 24)")?;
4134 if was_signed {
4135 write!(self.out, ")")?;
4136 }
4137 }
4138
4139 Function::Unpack2x16float => {
4140 write!(self.out, "float2(f16tof32(")?;
4141 self.write_expr(module, arg, func_ctx)?;
4142 write!(self.out, "), f16tof32((")?;
4143 self.write_expr(module, arg, func_ctx)?;
4144 write!(self.out, ") >> 16))")?;
4145 }
4146 Function::Unpack2x16snorm => {
4147 let scale = 32767;
4148
4149 write!(self.out, "(float2(int2(")?;
4150 self.write_expr(module, arg, func_ctx)?;
4151 write!(self.out, " << 16, ")?;
4152 self.write_expr(module, arg, func_ctx)?;
4153 write!(self.out, ") >> 16) / {scale}.0)")?;
4154 }
4155 Function::Unpack2x16unorm => {
4156 let scale = 65535;
4157
4158 write!(self.out, "(float2(")?;
4159 self.write_expr(module, arg, func_ctx)?;
4160 write!(self.out, " & 0xFFFF, ")?;
4161 self.write_expr(module, arg, func_ctx)?;
4162 write!(self.out, " >> 16) / {scale}.0)")?;
4163 }
4164 Function::Unpack4x8snorm => {
4165 let scale = 127;
4166
4167 write!(self.out, "(float4(int4(")?;
4168 self.write_expr(module, arg, func_ctx)?;
4169 write!(self.out, " << 24, ")?;
4170 self.write_expr(module, arg, func_ctx)?;
4171 write!(self.out, " << 16, ")?;
4172 self.write_expr(module, arg, func_ctx)?;
4173 write!(self.out, " << 8, ")?;
4174 self.write_expr(module, arg, func_ctx)?;
4175 write!(self.out, ") >> 24) / {scale}.0)")?;
4176 }
4177 Function::Unpack4x8unorm => {
4178 let scale = 255;
4179
4180 write!(self.out, "(float4(")?;
4181 self.write_expr(module, arg, func_ctx)?;
4182 write!(self.out, " & 0xFF, ")?;
4183 self.write_expr(module, arg, func_ctx)?;
4184 write!(self.out, " >> 8 & 0xFF, ")?;
4185 self.write_expr(module, arg, func_ctx)?;
4186 write!(self.out, " >> 16 & 0xFF, ")?;
4187 self.write_expr(module, arg, func_ctx)?;
4188 write!(self.out, " >> 24) / {scale}.0)")?;
4189 }
4190 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4191 write!(self.out, "(")?;
4192 if matches!(fun, Function::Unpack4xU8) {
4193 write!(self.out, "u")?;
4194 }
4195 write!(self.out, "int4(")?;
4196 self.write_expr(module, arg, func_ctx)?;
4197 write!(self.out, ", ")?;
4198 self.write_expr(module, arg, func_ctx)?;
4199 write!(self.out, " >> 8, ")?;
4200 self.write_expr(module, arg, func_ctx)?;
4201 write!(self.out, " >> 16, ")?;
4202 self.write_expr(module, arg, func_ctx)?;
4203 write!(self.out, " >> 24) << 24 >> 24)")?;
4204 }
4205 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4206 let arg1 = arg1.unwrap();
4207
4208 if self.options.shader_model >= ShaderModel::V6_4 {
4209 let function_name = match fun {
4211 Function::Dot4I8Packed => "dot4add_i8packed",
4212 Function::Dot4U8Packed => "dot4add_u8packed",
4213 _ => unreachable!(),
4214 };
4215 write!(self.out, "{function_name}(")?;
4216 self.write_expr(module, arg, func_ctx)?;
4217 write!(self.out, ", ")?;
4218 self.write_expr(module, arg1, func_ctx)?;
4219 write!(self.out, ", 0)")?;
4220 } else {
4221 write!(self.out, "dot(")?;
4223
4224 if matches!(fun, Function::Dot4U8Packed) {
4225 write!(self.out, "u")?;
4226 }
4227 write!(self.out, "int4(")?;
4228 self.write_expr(module, arg, func_ctx)?;
4229 write!(self.out, ", ")?;
4230 self.write_expr(module, arg, func_ctx)?;
4231 write!(self.out, " >> 8, ")?;
4232 self.write_expr(module, arg, func_ctx)?;
4233 write!(self.out, " >> 16, ")?;
4234 self.write_expr(module, arg, func_ctx)?;
4235 write!(self.out, " >> 24) << 24 >> 24, ")?;
4236
4237 if matches!(fun, Function::Dot4U8Packed) {
4238 write!(self.out, "u")?;
4239 }
4240 write!(self.out, "int4(")?;
4241 self.write_expr(module, arg1, func_ctx)?;
4242 write!(self.out, ", ")?;
4243 self.write_expr(module, arg1, func_ctx)?;
4244 write!(self.out, " >> 8, ")?;
4245 self.write_expr(module, arg1, func_ctx)?;
4246 write!(self.out, " >> 16, ")?;
4247 self.write_expr(module, arg1, func_ctx)?;
4248 write!(self.out, " >> 24) << 24 >> 24)")?;
4249 }
4250 }
4251 Function::QuantizeToF16 => {
4252 write!(self.out, "f16tof32(f32tof16(")?;
4253 self.write_expr(module, arg, func_ctx)?;
4254 write!(self.out, "))")?;
4255 }
4256 Function::Regular(fun_name) => {
4257 write!(self.out, "{fun_name}(")?;
4258 self.write_expr(module, arg, func_ctx)?;
4259 if let Some(arg) = arg1 {
4260 write!(self.out, ", ")?;
4261 self.write_expr(module, arg, func_ctx)?;
4262 }
4263 if let Some(arg) = arg2 {
4264 write!(self.out, ", ")?;
4265 self.write_expr(module, arg, func_ctx)?;
4266 }
4267 if let Some(arg) = arg3 {
4268 write!(self.out, ", ")?;
4269 self.write_expr(module, arg, func_ctx)?;
4270 }
4271 write!(self.out, ")")?
4272 }
4273 Function::MissingIntOverload(fun_name) => {
4276 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4277 if let Some(Scalar::I32) = scalar_kind {
4278 write!(self.out, "asint({fun_name}(asuint(")?;
4279 self.write_expr(module, arg, func_ctx)?;
4280 write!(self.out, ")))")?;
4281 } else {
4282 write!(self.out, "{fun_name}(")?;
4283 self.write_expr(module, arg, func_ctx)?;
4284 write!(self.out, ")")?;
4285 }
4286 }
4287 Function::MissingIntReturnType(fun_name) => {
4290 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4291 if let Some(Scalar::I32) = scalar_kind {
4292 write!(self.out, "asint({fun_name}(")?;
4293 self.write_expr(module, arg, func_ctx)?;
4294 write!(self.out, "))")?;
4295 } else {
4296 write!(self.out, "{fun_name}(")?;
4297 self.write_expr(module, arg, func_ctx)?;
4298 write!(self.out, ")")?;
4299 }
4300 }
4301 Function::CountTrailingZeros => {
4302 match *func_ctx.resolve_type(arg, &module.types) {
4303 TypeInner::Vector { size, scalar } => {
4304 let s = match size {
4305 crate::VectorSize::Bi => ".xx",
4306 crate::VectorSize::Tri => ".xxx",
4307 crate::VectorSize::Quad => ".xxxx",
4308 };
4309
4310 let scalar_width_bits = scalar.width * 8;
4311
4312 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4313 write!(
4314 self.out,
4315 "min(({scalar_width_bits}u){s}, firstbitlow("
4316 )?;
4317 self.write_expr(module, arg, func_ctx)?;
4318 write!(self.out, "))")?;
4319 } else {
4320 write!(
4322 self.out,
4323 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4324 )?;
4325 self.write_expr(module, arg, func_ctx)?;
4326 write!(self.out, ")))")?;
4327 }
4328 }
4329 TypeInner::Scalar(scalar) => {
4330 let scalar_width_bits = scalar.width * 8;
4331
4332 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4333 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4334 self.write_expr(module, arg, func_ctx)?;
4335 write!(self.out, "))")?;
4336 } else {
4337 write!(
4339 self.out,
4340 "asint(min({scalar_width_bits}u, firstbitlow("
4341 )?;
4342 self.write_expr(module, arg, func_ctx)?;
4343 write!(self.out, ")))")?;
4344 }
4345 }
4346 _ => unreachable!(),
4347 }
4348
4349 return Ok(());
4350 }
4351 Function::CountLeadingZeros => {
4352 match *func_ctx.resolve_type(arg, &module.types) {
4353 TypeInner::Vector { size, scalar } => {
4354 let s = match size {
4355 crate::VectorSize::Bi => ".xx",
4356 crate::VectorSize::Tri => ".xxx",
4357 crate::VectorSize::Quad => ".xxxx",
4358 };
4359
4360 let constant = scalar.width * 8 - 1;
4362
4363 if scalar.kind == ScalarKind::Uint {
4364 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4365 self.write_expr(module, arg, func_ctx)?;
4366 write!(self.out, "))")?;
4367 } else {
4368 let conversion_func = match scalar.width {
4369 4 => "asint",
4370 _ => "",
4371 };
4372 write!(self.out, "(")?;
4373 self.write_expr(module, arg, func_ctx)?;
4374 write!(
4375 self.out,
4376 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4377 )?;
4378 self.write_expr(module, arg, func_ctx)?;
4379 write!(self.out, ")))")?;
4380 }
4381 }
4382 TypeInner::Scalar(scalar) => {
4383 let constant = scalar.width * 8 - 1;
4385
4386 if let ScalarKind::Uint = scalar.kind {
4387 write!(self.out, "({constant}u - firstbithigh(")?;
4388 self.write_expr(module, arg, func_ctx)?;
4389 write!(self.out, "))")?;
4390 } else {
4391 let conversion_func = match scalar.width {
4392 4 => "asint",
4393 _ => "",
4394 };
4395 write!(self.out, "(")?;
4396 self.write_expr(module, arg, func_ctx)?;
4397 write!(
4398 self.out,
4399 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4400 )?;
4401 self.write_expr(module, arg, func_ctx)?;
4402 write!(self.out, ")))")?;
4403 }
4404 }
4405 _ => unreachable!(),
4406 }
4407
4408 return Ok(());
4409 }
4410 }
4411 }
4412 Expression::Swizzle {
4413 size,
4414 vector,
4415 pattern,
4416 } => {
4417 self.write_expr(module, vector, func_ctx)?;
4418 write!(self.out, ".")?;
4419 for &sc in pattern[..size as usize].iter() {
4420 self.out.write_char(back::COMPONENTS[sc as usize])?;
4421 }
4422 }
4423 Expression::ArrayLength(expr) => {
4424 let var_handle = match func_ctx.expressions[expr] {
4425 Expression::AccessIndex { base, index: _ } => {
4426 match func_ctx.expressions[base] {
4427 Expression::GlobalVariable(handle) => handle,
4428 _ => unreachable!(),
4429 }
4430 }
4431 Expression::GlobalVariable(handle) => handle,
4432 _ => unreachable!(),
4433 };
4434
4435 let var = &module.global_variables[var_handle];
4436 let (offset, stride) = match module.types[var.ty].inner {
4437 TypeInner::Array { stride, .. } => (0, stride),
4438 TypeInner::Struct { ref members, .. } => {
4439 let last = members.last().unwrap();
4440 let stride = match module.types[last.ty].inner {
4441 TypeInner::Array { stride, .. } => stride,
4442 _ => unreachable!(),
4443 };
4444 (last.offset, stride)
4445 }
4446 _ => unreachable!(),
4447 };
4448
4449 let storage_access = match var.space {
4450 crate::AddressSpace::Storage { access } => access,
4451 _ => crate::StorageAccess::default(),
4452 };
4453 let wrapped_array_length = WrappedArrayLength {
4454 writable: storage_access.contains(crate::StorageAccess::STORE),
4455 };
4456
4457 write!(self.out, "((")?;
4458 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4459 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4460 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4461 }
4462 Expression::Derivative { axis, ctrl, expr } => {
4463 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4464 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4465 let tail = match ctrl {
4466 Ctrl::Coarse => "coarse",
4467 Ctrl::Fine => "fine",
4468 Ctrl::None => unreachable!(),
4469 };
4470 write!(self.out, "abs(ddx_{tail}(")?;
4471 self.write_expr(module, expr, func_ctx)?;
4472 write!(self.out, ")) + abs(ddy_{tail}(")?;
4473 self.write_expr(module, expr, func_ctx)?;
4474 write!(self.out, "))")?
4475 } else {
4476 let fun_str = match (axis, ctrl) {
4477 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4478 (Axis::X, Ctrl::Fine) => "ddx_fine",
4479 (Axis::X, Ctrl::None) => "ddx",
4480 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4481 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4482 (Axis::Y, Ctrl::None) => "ddy",
4483 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4484 (Axis::Width, Ctrl::None) => "fwidth",
4485 };
4486 write!(self.out, "{fun_str}(")?;
4487 self.write_expr(module, expr, func_ctx)?;
4488 write!(self.out, ")")?
4489 }
4490 }
4491 Expression::Relational { fun, argument } => {
4492 use crate::RelationalFunction as Rf;
4493
4494 let fun_str = match fun {
4495 Rf::All => "all",
4496 Rf::Any => "any",
4497 Rf::IsNan => "isnan",
4498 Rf::IsInf => "isinf",
4499 };
4500 write!(self.out, "{fun_str}(")?;
4501 self.write_expr(module, argument, func_ctx)?;
4502 write!(self.out, ")")?
4503 }
4504 Expression::Select {
4505 condition,
4506 accept,
4507 reject,
4508 } => {
4509 write!(self.out, "(")?;
4510 self.write_expr(module, condition, func_ctx)?;
4511 write!(self.out, " ? ")?;
4512 self.write_expr(module, accept, func_ctx)?;
4513 write!(self.out, " : ")?;
4514 self.write_expr(module, reject, func_ctx)?;
4515 write!(self.out, ")")?
4516 }
4517 Expression::RayQueryGetIntersection { query, committed } => {
4518 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4520 unreachable!()
4521 };
4522
4523 let tracker_expr_name = format!(
4524 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4525 self.names[&func_ctx.name_key(query_var)]
4526 );
4527
4528 if committed {
4529 write!(self.out, "GetCommittedIntersection(")?;
4530 self.write_expr(module, query, func_ctx)?;
4531 write!(self.out, ", {tracker_expr_name})")?;
4532 } else {
4533 write!(self.out, "GetCandidateIntersection(")?;
4534 self.write_expr(module, query, func_ctx)?;
4535 write!(self.out, ", {tracker_expr_name})")?;
4536 }
4537 }
4538 Expression::RayQueryVertexPositions { .. }
4540 | Expression::CooperativeLoad { .. }
4541 | Expression::CooperativeMultiplyAdd { .. } => {
4542 unreachable!()
4543 }
4544 Expression::CallResult(_)
4546 | Expression::AtomicResult { .. }
4547 | Expression::WorkGroupUniformLoadResult { .. }
4548 | Expression::RayQueryProceedResult
4549 | Expression::SubgroupBallotResult
4550 | Expression::SubgroupOperationResult { .. } => {}
4551 }
4552
4553 if !closing_bracket.is_empty() {
4554 write!(self.out, "{closing_bracket}")?;
4555 }
4556 Ok(())
4557 }
4558
4559 #[allow(clippy::too_many_arguments)]
4560 fn write_image_load(
4561 &mut self,
4562 module: &&Module,
4563 expr: Handle<crate::Expression>,
4564 func_ctx: &back::FunctionCtx,
4565 image: Handle<crate::Expression>,
4566 coordinate: Handle<crate::Expression>,
4567 array_index: Option<Handle<crate::Expression>>,
4568 sample: Option<Handle<crate::Expression>>,
4569 level: Option<Handle<crate::Expression>>,
4570 ) -> Result<(), Error> {
4571 let mut wrapping_type = None;
4572 match *func_ctx.resolve_type(image, &module.types) {
4573 TypeInner::Image {
4574 class: crate::ImageClass::External,
4575 ..
4576 } => {
4577 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4578 self.write_expr(module, image, func_ctx)?;
4579 write!(self.out, ", ")?;
4580 self.write_expr(module, coordinate, func_ctx)?;
4581 write!(self.out, ")")?;
4582 return Ok(());
4583 }
4584 TypeInner::Image {
4585 class: crate::ImageClass::Storage { format, .. },
4586 ..
4587 } => {
4588 if format.single_component() {
4589 wrapping_type = Some(Scalar::from(format));
4590 }
4591 }
4592 _ => {}
4593 }
4594 if let Some(scalar) = wrapping_type {
4595 write!(
4596 self.out,
4597 "{}{}(",
4598 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4599 scalar.to_hlsl_str()?
4600 )?;
4601 }
4602 self.write_expr(module, image, func_ctx)?;
4604 write!(self.out, ".Load(")?;
4605
4606 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4607
4608 if let Some(sample) = sample {
4609 write!(self.out, ", ")?;
4610 self.write_expr(module, sample, func_ctx)?;
4611 }
4612
4613 write!(self.out, ")")?;
4615
4616 if wrapping_type.is_some() {
4617 write!(self.out, ")")?;
4618 }
4619
4620 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4622 write!(self.out, ".x")?;
4623 }
4624 Ok(())
4625 }
4626
4627 fn sampler_binding_array_info_from_expression(
4630 &mut self,
4631 module: &Module,
4632 func_ctx: &back::FunctionCtx<'_>,
4633 base: Handle<crate::Expression>,
4634 resolved: &TypeInner,
4635 ) -> Option<BindingArraySamplerInfo> {
4636 if let TypeInner::BindingArray {
4637 base: base_ty_handle,
4638 ..
4639 } = *resolved
4640 {
4641 let base_ty = &module.types[base_ty_handle].inner;
4642 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4643 let base = &func_ctx.expressions[base];
4644
4645 if let crate::Expression::GlobalVariable(handle) = *base {
4646 let variable = &module.global_variables[handle];
4647
4648 let sampler_heap_name = match comparison {
4649 true => COMPARISON_SAMPLER_HEAP_VAR,
4650 false => SAMPLER_HEAP_VAR,
4651 };
4652
4653 return Some(BindingArraySamplerInfo {
4654 sampler_heap_name,
4655 sampler_index_buffer_name: self
4656 .wrapped
4657 .sampler_index_buffers
4658 .get(&super::SamplerIndexBufferKey {
4659 group: variable.binding.unwrap().group,
4660 })
4661 .unwrap()
4662 .clone(),
4663 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4664 .clone(),
4665 });
4666 }
4667 }
4668 }
4669
4670 None
4671 }
4672
4673 fn write_named_expr(
4674 &mut self,
4675 module: &Module,
4676 handle: Handle<crate::Expression>,
4677 name: String,
4678 named: Handle<crate::Expression>,
4681 ctx: &back::FunctionCtx,
4682 ) -> BackendResult {
4683 match ctx.info[named].ty {
4684 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4685 TypeInner::Struct { .. } => {
4686 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4687 write!(self.out, "{ty_name}")?;
4688 }
4689 _ => {
4690 self.write_type(module, ty_handle)?;
4691 }
4692 },
4693 proc::TypeResolution::Value(ref inner) => {
4694 self.write_value_type(module, inner)?;
4695 }
4696 }
4697
4698 let resolved = ctx.resolve_type(named, &module.types);
4699
4700 write!(self.out, " {name}")?;
4701 if let TypeInner::Array { base, size, .. } = *resolved {
4703 self.write_array_size(module, base, size)?;
4704 }
4705 write!(self.out, " = ")?;
4706 self.write_expr(module, handle, ctx)?;
4707 writeln!(self.out, ";")?;
4708 self.named_expressions.insert(named, name);
4709
4710 Ok(())
4711 }
4712
4713 pub(super) fn write_default_init(
4715 &mut self,
4716 module: &Module,
4717 ty: Handle<crate::Type>,
4718 ) -> BackendResult {
4719 write!(self.out, "(")?;
4720 self.write_type(module, ty)?;
4721 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4722 self.write_array_size(module, base, size)?;
4723 }
4724 write!(self.out, ")0")?;
4725 Ok(())
4726 }
4727
4728 pub(super) fn write_control_barrier(
4729 &mut self,
4730 barrier: crate::Barrier,
4731 level: back::Level,
4732 ) -> BackendResult {
4733 if barrier.contains(crate::Barrier::STORAGE) {
4734 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4735 }
4736 if barrier.contains(crate::Barrier::WORK_GROUP) {
4737 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4738 }
4739 if barrier.contains(crate::Barrier::SUB_GROUP) {
4740 }
4742 if barrier.contains(crate::Barrier::TEXTURE) {
4743 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4744 }
4745 Ok(())
4746 }
4747
4748 fn write_memory_barrier(
4749 &mut self,
4750 barrier: crate::Barrier,
4751 level: back::Level,
4752 ) -> BackendResult {
4753 if barrier.contains(crate::Barrier::STORAGE) {
4754 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4755 }
4756 if barrier.contains(crate::Barrier::WORK_GROUP) {
4757 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4758 }
4759 if barrier.contains(crate::Barrier::SUB_GROUP) {
4760 }
4762 if barrier.contains(crate::Barrier::TEXTURE) {
4763 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4764 }
4765 Ok(())
4766 }
4767
4768 fn emit_hlsl_atomic_tail(
4770 &mut self,
4771 module: &Module,
4772 func_ctx: &back::FunctionCtx<'_>,
4773 fun: &crate::AtomicFunction,
4774 compare_expr: Option<Handle<crate::Expression>>,
4775 value: Handle<crate::Expression>,
4776 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4777 ) -> BackendResult {
4778 if let Some(cmp) = compare_expr {
4779 write!(self.out, ", ")?;
4780 self.write_expr(module, cmp, func_ctx)?;
4781 }
4782 write!(self.out, ", ")?;
4783 if let crate::AtomicFunction::Subtract = *fun {
4784 write!(self.out, "-")?;
4786 }
4787 self.write_expr(module, value, func_ctx)?;
4788 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4789 write!(self.out, ", ")?;
4790 if compare_expr.is_some() {
4791 write!(self.out, "{res_name}.old_value")?;
4792 } else {
4793 write!(self.out, "{res_name}")?;
4794 }
4795 }
4796 writeln!(self.out, ");")?;
4797 Ok(())
4798 }
4799}
4800
4801pub(super) struct MatrixType {
4802 pub(super) columns: crate::VectorSize,
4803 pub(super) rows: crate::VectorSize,
4804 pub(super) width: crate::Bytes,
4805}
4806
4807pub(super) fn get_inner_matrix_data(
4808 module: &Module,
4809 handle: Handle<crate::Type>,
4810) -> Option<MatrixType> {
4811 match module.types[handle].inner {
4812 TypeInner::Matrix {
4813 columns,
4814 rows,
4815 scalar,
4816 } => Some(MatrixType {
4817 columns,
4818 rows,
4819 width: scalar.width,
4820 }),
4821 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4822 _ => None,
4823 }
4824}
4825
4826fn find_matrix_in_access_chain(
4830 module: &Module,
4831 base: Handle<crate::Expression>,
4832 func_ctx: &back::FunctionCtx<'_>,
4833) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4834 let mut current_base = base;
4835 let mut vector = None;
4836 let mut scalar = None;
4837 loop {
4838 let resolved_tr = func_ctx
4839 .resolve_type(current_base, &module.types)
4840 .pointer_base_type();
4841 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4842
4843 match *resolved {
4844 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4845 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4846 _ => return None,
4847 }
4848
4849 let index;
4850 (current_base, index) = match func_ctx.expressions[current_base] {
4851 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4852 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4853 _ => return None,
4854 };
4855
4856 match *resolved {
4857 TypeInner::Scalar(_) => scalar = Some(index),
4858 TypeInner::Vector { .. } => vector = Some(index),
4859 _ => unreachable!(),
4860 }
4861 }
4862}
4863
4864pub(super) fn get_inner_matrix_of_struct_array_member(
4869 module: &Module,
4870 base: Handle<crate::Expression>,
4871 func_ctx: &back::FunctionCtx<'_>,
4872 direct: bool,
4873) -> Option<MatrixType> {
4874 let mut mat_data = None;
4875 let mut array_base = None;
4876
4877 let mut current_base = base;
4878 loop {
4879 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4880 if let TypeInner::Pointer { base, .. } = *resolved {
4881 resolved = &module.types[base].inner;
4882 };
4883
4884 match *resolved {
4885 TypeInner::Matrix {
4886 columns,
4887 rows,
4888 scalar,
4889 } => {
4890 mat_data = Some(MatrixType {
4891 columns,
4892 rows,
4893 width: scalar.width,
4894 })
4895 }
4896 TypeInner::Array { base, .. } => {
4897 array_base = Some(base);
4898 }
4899 TypeInner::Struct { .. } => {
4900 if let Some(array_base) = array_base {
4901 if direct {
4902 return mat_data;
4903 } else {
4904 return get_inner_matrix_data(module, array_base);
4905 }
4906 }
4907
4908 break;
4909 }
4910 _ => break,
4911 }
4912
4913 current_base = match func_ctx.expressions[current_base] {
4914 crate::Expression::Access { base, .. } => base,
4915 crate::Expression::AccessIndex { base, .. } => base,
4916 _ => break,
4917 };
4918 }
4919 None
4920}
4921
4922fn get_global_uniform_matrix(
4925 module: &Module,
4926 base: Handle<crate::Expression>,
4927 func_ctx: &back::FunctionCtx<'_>,
4928) -> Option<MatrixType> {
4929 let base_tr = func_ctx
4930 .resolve_type(base, &module.types)
4931 .pointer_base_type();
4932 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4933 match (&func_ctx.expressions[base], base_ty) {
4934 (
4935 &crate::Expression::GlobalVariable(handle),
4936 Some(&TypeInner::Matrix {
4937 columns,
4938 rows,
4939 scalar,
4940 }),
4941 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4942 Some(MatrixType {
4943 columns,
4944 rows,
4945 width: scalar.width,
4946 })
4947 }
4948 _ => None,
4949 }
4950}
4951
4952fn get_inner_matrix_of_global_uniform(
4957 module: &Module,
4958 base: Handle<crate::Expression>,
4959 func_ctx: &back::FunctionCtx<'_>,
4960) -> Option<MatrixType> {
4961 let mut mat_data = None;
4962 let mut array_base = None;
4963
4964 let mut current_base = base;
4965 loop {
4966 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4967 if let TypeInner::Pointer { base, .. } = *resolved {
4968 resolved = &module.types[base].inner;
4969 };
4970
4971 match *resolved {
4972 TypeInner::Matrix {
4973 columns,
4974 rows,
4975 scalar,
4976 } => {
4977 mat_data = Some(MatrixType {
4978 columns,
4979 rows,
4980 width: scalar.width,
4981 })
4982 }
4983 TypeInner::Array { base, .. } => {
4984 array_base = Some(base);
4985 }
4986 _ => break,
4987 }
4988
4989 current_base = match func_ctx.expressions[current_base] {
4990 crate::Expression::Access { base, .. } => base,
4991 crate::Expression::AccessIndex { base, .. } => base,
4992 crate::Expression::GlobalVariable(handle)
4993 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4994 {
4995 return mat_data.or_else(|| {
4996 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4997 })
4998 }
4999 _ => break,
5000 };
5001 }
5002 None
5003}