1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Write as _},
8 mem,
9};
10
11use super::{
12 help,
13 help::{
14 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15 WrappedZeroValue,
16 },
17 storage::StoreValue,
18 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21 back::{self, get_entry_points, Baked},
22 common,
23 proc::{self, index, ExternalTextureNameKey, NameKey},
24 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix for user-defined (`@location`) interface variables that are not
/// fragment-stage outputs; the location number is appended (e.g. `LOC0`). Fragment
/// outputs use `SV_Target{N}` instead.
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the optional constant buffer carrying naga's special constants.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special-constants buffer (declared only when
/// `Options::special_constants_binding` is set).
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special-constants buffer.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special-constants buffer.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special-constants buffer.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and variables emitted into the generated HLSL.
// They are `pub(crate)` so other parts of the HLSL backend can refer to them;
// the helpers themselves are not defined in this part of the file.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Sampler heap indexed when declaring non-comparison sampler globals.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Sampler heap indexed when declaring comparison sampler globals.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Prefix of ray-query tracking variables; NOTE(review): presumably one per ray
/// query, suffixed with the query's name — confirm at the use site.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix shared by the snake_case `naga_*` helper names above.
/// NOTE(review): presumably reserved so user identifiers cannot collide — confirm.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index into a composite value: either a runtime expression or a constant.
///
/// NOTE(review): not referenced in this part of the file; presumably used by the
/// access-chain / storage code elsewhere in this backend — confirm.
enum Index {
    /// Index computed by an IR expression.
    Expression(Handle<crate::Expression>),
    /// Index known as a constant.
    Static(u32),
}
60
/// One field of a generated entry-point interface struct.
pub(super) struct EpStructMember {
    /// Field name, allocated from the writer's namer.
    pub(super) name: String,
    /// Type of the field.
    pub(super) ty: Handle<crate::Type>,
    /// Interface binding (location or built-in) of the field. `write_interface_struct`
    /// debug-asserts that this is `Some`.
    pub(super) binding: Option<crate::Binding>,
    /// Position of the field in its original, pre-sort order. Input structs are
    /// re-sorted by this value after their declaration has been written.
    pub(super) index: u32,
}
69
/// A generated entry-point interface struct plus the names needed to refer to it.
pub(super) struct EntryPointBinding {
    /// Name of the argument/variable of the struct type (from the namer).
    pub(super) arg_name: String,
    /// Name of the generated struct type (from the namer).
    pub(super) ty_name: String,
    /// Fields of the struct. Input structs are in original order; other structs
    /// are in interface-key order.
    pub(super) members: Vec<EpStructMember>,
    /// Name of the field holding the local invocation index: either a
    /// `LocalInvocationIndex` member, or a synthesized `SV_GroupIndex` field added
    /// when `SubgroupId` is used without one.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// All interface structs generated for one entry point (see `write_ep_interface`).
pub(super) struct EntryPointInterface {
    /// Input struct; only generated for fragment shaders and for entry points that
    /// take subgroup built-ins, and only if there are arguments.
    pub(crate) input: Option<EntryPointBinding>,
    /// Output struct; only generated for vertex shaders whose result has no binding
    /// (i.e. the result is a struct of bound members).
    pub(crate) output: Option<EntryPointBinding>,
    /// Mesh-shader vertex output struct, present when the entry point has mesh info.
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    /// Mesh-shader primitive output struct, present when the entry point has mesh info.
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    /// Mesh-shader index output, present when the entry point has mesh info.
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
/// Sort key for members of an entry-point interface struct.
///
/// The derived `Ord` compares variants in declaration order. Sorting by this key
/// therefore places all `@location` members first (ascending by location), then
/// built-ins, then members without a binding. Do not reorder the variants.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    Location(u32),
    BuiltIn(crate::BuiltIn),
    Other,
}
104
105impl InterfaceKey {
106 const fn new(binding: Option<&crate::Binding>) -> Self {
107 match binding {
108 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110 None => Self::Other,
111 }
112 }
113}
114
/// Which part of an entry point's interface a struct or semantic describes.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    /// Entry-point inputs (arguments).
    Input,
    /// Entry-point results.
    Output,
    /// NOTE(review): presumably mesh-shader per-vertex outputs (inferred from name).
    MeshVertices,
    /// NOTE(review): presumably mesh-shader per-primitive outputs (inferred from name).
    MeshPrimitives,
}
122
/// Argument names passed along when an entry point forwards to another function.
///
/// NOTE(review): not used in this part of the file; the descriptions here are
/// inferred from the field names — confirm against the code that constructs it.
pub(super) struct NestedEntryPointArgs {
    /// Names of the user-declared arguments.
    pub user_args: Vec<String>,
    /// Name of the task payload argument, if any.
    pub task_payload: Option<String>,
    /// Name of the local invocation index argument.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133 return false;
134 };
135 matches!(
136 builtin,
137 crate::BuiltIn::SubgroupSize
138 | crate::BuiltIn::SubgroupInvocationId
139 | crate::BuiltIn::NumSubgroups
140 | crate::BuiltIn::SubgroupId
141 )
142}
143
/// Names needed to index into a binding array of samplers.
///
/// NOTE(review): not constructed in this part of the file; the field descriptions
/// are inferred from their names — confirm at the use site.
struct BindingArraySamplerInfo {
    /// Sampler heap to index; presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`.
    sampler_heap_name: &'static str,
    /// Name of the sampler index buffer for the array's bind group.
    sampler_index_buffer_name: String,
    /// Name of the variable holding the array's base index (presumably the
    /// `static const uint` emitted by `write_global_sampler` for binding arrays).
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156 Self {
157 out,
158 names: crate::FastHashMap::default(),
159 namer: proc::Namer::default(),
160 options,
161 pipeline_options,
162 entry_point_io: crate::FastHashMap::default(),
163 named_expressions: crate::NamedExpressions::default(),
164 wrapped: super::Wrapped::default(),
165 written_committed_intersection: false,
166 written_candidate_intersection: false,
167 continue_ctx: back::continue_forward::ContinueCtx::default(),
168 temp_access_chain: Vec::new(),
169 need_bake_expressions: Default::default(),
170 function_task_payload_var: Default::default(),
171 }
172 }
173
174 fn reset(&mut self, module: &Module) {
175 self.names.clear();
176 self.namer.reset(
177 module,
178 &super::keywords::RESERVED_SET,
179 proc::KeywordSet::empty(),
180 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181 super::keywords::RESERVED_PREFIXES,
182 &mut self.names,
183 );
184 self.entry_point_io.clear();
185 self.named_expressions.clear();
186 self.wrapped.clear();
187 self.written_committed_intersection = false;
188 self.written_candidate_intersection = false;
189 self.continue_ctx.clear();
190 self.need_bake_expressions.clear();
191 self.function_task_payload_var.clear();
192 }
193
194 fn gen_force_bounded_loop_statements(
202 &mut self,
203 level: back::Level,
204 ) -> Option<(String, String)> {
205 if !self.options.force_loop_bounding {
206 return None;
207 }
208
209 let loop_bound_name = self.namer.call("loop_bound");
210 let max = u32::MAX;
211 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214 let level = level.next();
215 let break_and_inc = format!(
216 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218 );
219
220 Some((decl, break_and_inc))
221 }
222
223 fn update_expressions_to_bake(
228 &mut self,
229 module: &Module,
230 func: &crate::Function,
231 info: &valid::FunctionInfo,
232 ) {
233 use crate::Expression;
234 self.need_bake_expressions.clear();
235 for (exp_handle, expr) in func.expressions.iter() {
236 let expr_info = &info[exp_handle];
237 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238 if min_ref_count <= expr_info.ref_count {
239 self.need_bake_expressions.insert(exp_handle);
240 }
241 if let Expression::Load { pointer } = *expr {
242 if info[pointer]
243 .ty
244 .inner_with(&module.types)
245 .is_atomic_pointer(&module.types)
246 {
247 self.need_bake_expressions.insert(exp_handle);
248 }
249 }
250
251 if let Expression::Math { fun, arg, arg1, .. } = *expr {
252 match fun {
253 crate::MathFunction::Asinh
254 | crate::MathFunction::Acosh
255 | crate::MathFunction::Atanh
256 | crate::MathFunction::Unpack2x16float
257 | crate::MathFunction::Unpack2x16snorm
258 | crate::MathFunction::Unpack2x16unorm
259 | crate::MathFunction::Unpack4x8snorm
260 | crate::MathFunction::Unpack4x8unorm
261 | crate::MathFunction::Unpack4xI8
262 | crate::MathFunction::Unpack4xU8
263 | crate::MathFunction::Pack2x16float
264 | crate::MathFunction::Pack2x16snorm
265 | crate::MathFunction::Pack2x16unorm
266 | crate::MathFunction::Pack4x8snorm
267 | crate::MathFunction::Pack4x8unorm
268 | crate::MathFunction::Pack4xI8
269 | crate::MathFunction::Pack4xU8
270 | crate::MathFunction::Pack4xI8Clamp
271 | crate::MathFunction::Pack4xU8Clamp => {
272 self.need_bake_expressions.insert(arg);
273 }
274 crate::MathFunction::CountLeadingZeros => {
275 let inner = info[exp_handle].ty.inner_with(&module.types);
276 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277 self.need_bake_expressions.insert(arg);
278 }
279 }
280 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281 self.need_bake_expressions.insert(arg);
282 self.need_bake_expressions.insert(arg1.unwrap());
283 }
284 _ => {}
285 }
286 }
287
288 if let Expression::Derivative { axis, ctrl, expr } = *expr {
289 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291 self.need_bake_expressions.insert(expr);
292 }
293 }
294
295 if let Expression::GlobalVariable(_) = *expr {
296 let inner = info[exp_handle].ty.inner_with(&module.types);
297
298 if let TypeInner::Sampler { .. } = *inner {
299 self.need_bake_expressions.insert(exp_handle);
300 }
301 }
302 }
303 for statement in func.body.iter() {
304 match *statement {
305 crate::Statement::SubgroupCollectiveOperation {
306 op: _,
307 collective_op: crate::CollectiveOperation::InclusiveScan,
308 argument,
309 result: _,
310 } => {
311 self.need_bake_expressions.insert(argument);
312 }
313 crate::Statement::Atomic {
314 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315 ..
316 } => {
317 self.need_bake_expressions.insert(cmp);
318 }
319 _ => {}
320 }
321 }
322 }
323
    /// Writes the whole `module` as HLSL to `self.out` and reports the translated
    /// entry-point names.
    ///
    /// Output order:
    /// 1. the special-constants cbuffer, if configured;
    /// 2. dynamic storage buffer offset cbuffers;
    /// 3. matCx2 helpers;
    /// 4. struct declarations;
    /// 5. special and wrapped helper functions for global expressions;
    /// 6. named global constants;
    /// 7. global variables;
    /// 8. entry-point interface structs;
    /// 9. regular functions;
    /// 10. the selected entry points.
    ///
    /// Unless `options.fake_missing_bindings` is set, a function or entry point that
    /// uses a bound global whose binding resolves neither as a regular resource nor
    /// as an external texture is skipped. For entry points, the resolution error is
    /// recorded in the returned `ReflectionInfo` in place of the name.
    ///
    /// # Errors
    /// * `Error::ShaderModelTooLow` if the module uses mesh shaders and the shader
    ///   model is below 6.5.
    /// * `Error::EntryPointNotFound` if the pipeline's requested entry point is missing.
    /// * Any error propagated from the individual writers.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Special-constants cbuffer: `struct NagaConstants { int first_vertex;
        // int first_instance; uint other; }` bound at the configured register.
        // `space` is omitted when it is 0.
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            writeln!(self.out)?;
        }

        // One cbuffer per configured bind group, holding `bt.size` uint offsets.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            writeln!(self.out)?;
        }

        // Stage and result of every entry point, used below to tell `write_struct`
        // which structs are entry-point outputs.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // Structs ending in a dynamically sized member are not declared.
                // NOTE(review): assumes structs are non-empty (the `unwrap` below);
                // presumably guaranteed by validation.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    continue;
                }

                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Only named constants are declared; a blank line follows the last one.
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            writeln!(self.out)?;
        }

        // Either every entry point or just the one requested by the pipeline options.
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Interface structs are written before any function so they can be referenced.
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Skip functions that use a bound global whose binding cannot be resolved.
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Record (rather than fail on) the first unresolvable binding used by
            // this entry point, and skip writing it.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Attributes placed before the entry-point function signature.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                // NOTE(review): point topology is assumed to be rejected before this
                // point; only `line` and `triangle` are emitted here.
                let topology_str = match info.topology {
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // Blank line between entry points, but not after the module's last one.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
578
579 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580 match *binding {
581 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582 write!(self.out, "precise ")?;
583 }
584 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585 write!(self.out, "noperspective ")?;
586 }
587 crate::Binding::Location {
588 interpolation,
589 sampling,
590 ..
591 } => {
592 if let Some(interpolation) = interpolation {
593 if let Some(string) = interpolation.to_hlsl_str() {
594 write!(self.out, "{string} ")?
595 }
596 }
597
598 if let Some(sampling) = sampling {
599 if let Some(string) = sampling.to_hlsl_str() {
600 write!(self.out, "{string} ")?
601 }
602 }
603 }
604 crate::Binding::BuiltIn(_) => {}
605 }
606
607 Ok(())
608 }
609
610 pub(super) fn write_semantic(
613 &mut self,
614 binding: &Option<crate::Binding>,
615 stage: Option<(ShaderStage, Io)>,
616 ) -> BackendResult {
617 let is_per_primitive = match *binding {
618 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619 if builtin == crate::BuiltIn::ViewIndex
620 && self.options.shader_model < ShaderModel::V6_1
621 {
622 return Err(Error::ShaderModelTooLow(
623 "used @builtin(view_index) or SV_ViewID".to_string(),
624 ShaderModel::V6_1,
625 ));
626 }
627 if let Some(builtin_str) = builtin.to_hlsl_str()? {
628 write!(self.out, " : {builtin_str}")?;
629 }
630 false
631 }
632 Some(crate::Binding::Location {
633 blend_src: Some(1),
634 per_primitive,
635 ..
636 }) => {
637 write!(self.out, " : SV_Target1")?;
638 per_primitive
639 }
640 Some(crate::Binding::Location {
641 location,
642 per_primitive,
643 ..
644 }) => {
645 if stage == Some((ShaderStage::Fragment, Io::Output)) {
646 write!(self.out, " : SV_Target{location}")?;
647 } else {
648 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649 }
650 per_primitive
651 }
652 _ => false,
653 };
654 if is_per_primitive {
655 write!(self.out, " : primitive")?;
656 }
657
658 Ok(())
659 }
660
661 pub(super) fn write_interface_struct(
662 &mut self,
663 module: &Module,
664 shader_stage: (ShaderStage, Io),
665 struct_name: String,
666 var_name: Option<&str>,
667 mut members: Vec<EpStructMember>,
668 ) -> Result<EntryPointBinding, Error> {
669 let struct_name = self.namer.call(&struct_name);
670 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675 write!(self.out, "struct {struct_name}")?;
676 writeln!(self.out, " {{")?;
677 let mut local_invocation_index_name = None;
678 let mut subgroup_id_used = false;
679 for m in members.iter() {
680 debug_assert!(m.binding.is_some());
683
684 match m.binding {
685 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686 subgroup_id_used = true;
687 }
688 Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689 local_invocation_index_name = Some(m.name.clone());
690 }
691 _ => (),
692 }
693
694 if is_subgroup_builtin_binding(&m.binding) {
695 continue;
696 }
697 write!(self.out, "{}", back::INDENT)?;
698 if let Some(ref binding) = m.binding {
699 self.write_modifier(binding)?;
700 }
701 self.write_type(module, m.ty)?;
702 write!(self.out, " {}", &m.name)?;
703 self.write_semantic(&m.binding, Some(shader_stage))?;
704 writeln!(self.out, ";")?;
705 }
706 if subgroup_id_used && local_invocation_index_name.is_none() {
707 let name = self.namer.call("local_invocation_index");
708 writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709 local_invocation_index_name = Some(name);
710 }
711 writeln!(self.out, "}};")?;
712 writeln!(self.out)?;
713
714 match shader_stage.1 {
716 Io::Input => {
717 members.sort_by_key(|m| m.index);
719 }
720 Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721 }
723 }
724
725 Ok(EntryPointBinding {
726 arg_name: self
727 .namer
728 .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729 ty_name: struct_name,
730 members,
731 local_invocation_index_name,
732 })
733 }
734
735 fn write_ep_input_struct(
739 &mut self,
740 module: &Module,
741 func: &crate::Function,
742 stage: ShaderStage,
743 entry_point_name: &str,
744 ) -> Result<EntryPointBinding, Error> {
745 let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747 let mut fake_members = Vec::new();
748 for arg in func.arguments.iter() {
749 match module.types[arg.ty].inner {
754 TypeInner::Struct { ref members, .. } => {
755 for member in members.iter() {
756 let name = self.namer.call_or(&member.name, "member");
757 let index = fake_members.len() as u32;
758 fake_members.push(EpStructMember {
759 name,
760 ty: member.ty,
761 binding: member.binding.clone(),
762 index,
763 });
764 }
765 }
766 _ => {
767 let member_name = self.namer.call_or(&arg.name, "member");
768 let index = fake_members.len() as u32;
769 fake_members.push(EpStructMember {
770 name: member_name,
771 ty: arg.ty,
772 binding: arg.binding.clone(),
773 index,
774 });
775 }
776 }
777 }
778
779 self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780 }
781
782 fn write_ep_output_struct(
786 &mut self,
787 module: &Module,
788 result: &crate::FunctionResult,
789 stage: ShaderStage,
790 entry_point_name: &str,
791 frag_ep: Option<&FragmentEntryPoint<'_>>,
792 ) -> Result<EntryPointBinding, Error> {
793 let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795 let empty = [];
796 let members = match module.types[result.ty].inner {
797 TypeInner::Struct { ref members, .. } => members,
798 ref other => {
799 log::error!("Unexpected {other:?} output type without a binding");
800 &empty[..]
801 }
802 };
803
804 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809 let mut fs_input_locs = Vec::new();
810 for arg in frag_ep.func.arguments.iter() {
811 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813 Some(crate::Binding::BuiltIn(_)) | None => {}
814 };
815
816 match frag_ep.module.types[arg.ty].inner {
819 TypeInner::Struct { ref members, .. } => {
820 for member in members.iter() {
821 push_if_location(&member.binding);
822 }
823 }
824 _ => push_if_location(&arg.binding),
825 }
826 }
827 fs_input_locs.sort();
828 Some(fs_input_locs)
829 } else {
830 None
831 };
832
833 let mut fake_members = Vec::new();
834 for (index, member) in members.iter().enumerate() {
835 if let Some(ref fs_input_locs) = fs_input_locs {
836 match member.binding {
837 Some(crate::Binding::Location { location, .. }) => {
838 if fs_input_locs.binary_search(&location).is_err() {
839 continue;
840 }
841 }
842 Some(crate::Binding::BuiltIn(_)) | None => {}
843 }
844 }
845
846 let member_name = self.namer.call_or(&member.name, "member");
847 fake_members.push(EpStructMember {
848 name: member_name,
849 ty: member.ty,
850 binding: member.binding.clone(),
851 index: index as u32,
852 });
853 }
854
855 self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856 }
857
858 fn write_ep_interface(
862 &mut self,
863 module: &Module,
864 ep: &crate::EntryPoint,
865 ep_name: &str,
866 frag_ep: Option<&FragmentEntryPoint<'_>>,
867 ) -> Result<EntryPointInterface, Error> {
868 let func = &ep.function;
869 let stage = ep.stage;
870 Ok(EntryPointInterface {
871 input: if !func.arguments.is_empty()
872 && (stage == ShaderStage::Fragment
873 || func
874 .arguments
875 .iter()
876 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877 {
878 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879 } else {
880 None
881 },
882 output: match func.result {
883 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885 }
886 _ => None,
887 },
888 mesh_vertices: if let Some(ref info) = ep.mesh_info {
889 Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890 } else {
891 None
892 },
893 mesh_primitives: if let Some(ref info) = ep.mesh_info {
894 Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895 } else {
896 None
897 },
898 mesh_indices: if let Some(ref info) = ep.mesh_info {
899 Some(self.write_ep_mesh_output_indices(info.topology)?)
900 } else {
901 None
902 },
903 })
904 }
905
906 fn write_ep_argument_initialization(
907 &mut self,
908 ep: &crate::EntryPoint,
909 ep_input: &EntryPointBinding,
910 fake_member: &EpStructMember,
911 ) -> BackendResult {
912 match fake_member.binding {
913 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914 write!(self.out, "WaveGetLaneCount()")?
915 }
916 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917 write!(self.out, "WaveGetLaneIndex()")?
918 }
919 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920 self.out,
921 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923 )?,
924 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925 write!(
926 self.out,
927 "{}.{} / WaveGetLaneCount()",
928 ep_input.arg_name,
929 ep_input.local_invocation_index_name.as_ref().unwrap()
931 )?;
932 }
933 Some(crate::Binding::Location {
934 interpolation: Some(crate::Interpolation::PerVertex),
935 ..
936 }) => {
937 if self.options.shader_model < ShaderModel::V6_1 {
938 return Err(Error::ShaderModelTooLow(
939 "per_vertex fragment inputs".to_string(),
940 ShaderModel::V6_1,
941 ));
942 }
943 write!(
944 self.out,
945 "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946 ep_input.arg_name,
947 fake_member.name,
948 )?;
949 }
950 _ => {
951 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952 }
953 }
954 Ok(())
955 }
956
957 fn write_ep_arguments_initialization(
959 &mut self,
960 module: &Module,
961 func: &crate::Function,
962 ep_index: u16,
963 ) -> BackendResult {
964 let ep = &module.entry_points[ep_index as usize];
965 let ep_input = match self
966 .entry_point_io
967 .get_mut(&(ep_index as usize))
968 .unwrap()
969 .input
970 .take()
971 {
972 Some(ep_input) => ep_input,
973 None => return Ok(()),
974 };
975 let mut fake_iter = ep_input.members.iter();
976 for (arg_index, arg) in func.arguments.iter().enumerate() {
977 write!(self.out, "{}", back::INDENT)?;
978 self.write_type(module, arg.ty)?;
979 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980 write!(self.out, " {arg_name}")?;
981 match module.types[arg.ty].inner {
982 TypeInner::Array { base, size, .. } => {
983 self.write_array_size(module, base, size)?;
984 write!(self.out, " = ")?;
985 self.write_ep_argument_initialization(
986 ep,
987 &ep_input,
988 fake_iter.next().unwrap(),
989 )?;
990 writeln!(self.out, ";")?;
991 }
992 TypeInner::Struct { ref members, .. } => {
993 write!(self.out, " = {{ ")?;
994 for index in 0..members.len() {
995 if index != 0 {
996 write!(self.out, ", ")?;
997 }
998 self.write_ep_argument_initialization(
999 ep,
1000 &ep_input,
1001 fake_iter.next().unwrap(),
1002 )?;
1003 }
1004 writeln!(self.out, " }};")?;
1005 }
1006 _ => {
1007 write!(self.out, " = ")?;
1008 self.write_ep_argument_initialization(
1009 ep,
1010 &ep_input,
1011 fake_iter.next().unwrap(),
1012 )?;
1013 writeln!(self.out, ";")?;
1014 }
1015 }
1016 }
1017 assert!(fake_iter.next().is_none());
1018 Ok(())
1019 }
1020
    /// Declares the global variable `handle`.
    ///
    /// * External textures (including binding arrays of them) are delegated to
    ///   `write_global_external_texture`.
    /// * A global whose binding cannot be resolved is skipped, with a debug log.
    /// * Samplers are delegated to `write_global_sampler`.
    /// * Everything else is declared according to its address space:
    ///   * `Private` → `static T name = init;` (or a default initializer);
    ///   * `WorkGroup` / `TaskPayload` → `groupshared T name;`;
    ///   * `Uniform` → `cbuffer name : register(bN) { T name; }`;
    ///   * `Storage` → `[RW]ByteAddressBuffer`, with a `u` register if writable and `t`
    ///     otherwise, plus `globallycoherent` when decorated `COHERENT`;
    ///   * `Handle` → the resource type, with a `u` register for storage images and
    ///     `t` otherwise;
    ///   * `Immediate` → `ConstantBuffer<T> name: register(...)` at
    ///     `options.immediates_target`.
    ///
    /// # Errors
    /// `Error::Unimplemented` for an `Immediate` global of non-struct type.
    ///
    /// # Panics
    /// For the `Function` or ray-payload address spaces, or when an `Immediate` global
    /// exists but `options.immediates_target` is `None`.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify by the element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are checked before the binding-resolution check below;
        // their bindings are resolved separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Writes the declaration prefix and yields the register class letter
        // ("" for globals that are not bound to a register).
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // Close the `ConstantBuffer<...>` template argument for immediates.
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        if global.space == crate::AddressSpace::Immediate {
            // Only struct-typed immediates are supported (see the linked issue).
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Cannot fail: unresolvable bindings returned early above.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // Binding arrays use the size override from the bind target when present.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals are always initialized, explicitly or by default.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        // A cbuffer body holds a single member named after the global itself.
        if global.space == crate::AddressSpace::Uniform {
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1223
1224 fn write_global_sampler(
1225 &mut self,
1226 module: &Module,
1227 handle: Handle<crate::GlobalVariable>,
1228 global: &crate::GlobalVariable,
1229 ) -> BackendResult {
1230 let binding = *global.binding.as_ref().unwrap();
1231
1232 let key = super::SamplerIndexBufferKey {
1233 group: binding.group,
1234 };
1235 self.write_wrapped_sampler_buffer(key)?;
1236
1237 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240 match module.types[global.ty].inner {
1241 TypeInner::Sampler { comparison } => {
1242 write!(self.out, "static const ")?;
1249 self.write_type(module, global.ty)?;
1250
1251 let heap_var = if comparison {
1252 COMPARISON_SAMPLER_HEAP_VAR
1253 } else {
1254 SAMPLER_HEAP_VAR
1255 };
1256
1257 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258 let name = &self.names[&NameKey::GlobalVariable(handle)];
1259 writeln!(
1260 self.out,
1261 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262 register = bt.register
1263 )?;
1264 }
1265 TypeInner::BindingArray { .. } => {
1266 let name = &self.names[&NameKey::GlobalVariable(handle)];
1272 writeln!(
1273 self.out,
1274 "static const uint {name} = {register};",
1275 register = bt.register
1276 )?;
1277 }
1278 _ => unreachable!(),
1279 };
1280
1281 Ok(())
1282 }
1283
1284 fn write_global_external_texture(
1288 &mut self,
1289 module: &Module,
1290 handle: Handle<crate::GlobalVariable>,
1291 global: &crate::GlobalVariable,
1292 ) -> BackendResult {
1293 let res_binding = global
1294 .binding
1295 .as_ref()
1296 .expect("External texture global variables must have a resource binding");
1297 let ext_tex_bindings = match self
1298 .options
1299 .resolve_external_texture_resource_binding(res_binding)
1300 {
1301 Ok(bindings) => bindings,
1302 Err(err) => {
1303 log::debug!(
1304 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305 handle,
1306 global.name,
1307 err,
1308 );
1309 return Ok(());
1310 }
1311 };
1312
1313 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314 write!(
1315 self.out,
1316 "Texture2D<float4> {}: register(t{}",
1317 name, bt.register
1318 )?;
1319 if bt.space != 0 {
1320 write!(self.out, ", space{}", bt.space)?;
1321 }
1322 writeln!(self.out, ");")?;
1323 Ok(())
1324 };
1325 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326 let plane_name = &self.names
1327 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328 write_plane(bt, plane_name)?;
1329 }
1330
1331 let params_name = &self.names
1332 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333 let params_ty_name =
1334 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335 write!(
1336 self.out,
1337 "cbuffer {}: register(b{}",
1338 params_name, ext_tex_bindings.params.register
1339 )?;
1340 if ext_tex_bindings.params.space != 0 {
1341 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342 }
1343 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345 Ok(())
1346 }
1347
1348 fn write_global_constant(
1353 &mut self,
1354 module: &Module,
1355 handle: Handle<crate::Constant>,
1356 ) -> BackendResult {
1357 write!(self.out, "static const ")?;
1358 let constant = &module.constants[handle];
1359 self.write_type(module, constant.ty)?;
1360 let name = &self.names[&NameKey::Constant(handle)];
1361 write!(self.out, " {name}")?;
1362 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364 self.write_array_size(module, base, size)?;
1365 }
1366 write!(self.out, " = ")?;
1367 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368 writeln!(self.out, ";")?;
1369 Ok(())
1370 }
1371
1372 pub(super) fn write_array_size(
1373 &mut self,
1374 module: &Module,
1375 base: Handle<crate::Type>,
1376 size: crate::ArraySize,
1377 ) -> BackendResult {
1378 write!(self.out, "[")?;
1379
1380 match size.resolve(module.to_ctx())? {
1381 proc::IndexableLength::Known(size) => {
1382 write!(self.out, "{size}")?;
1383 }
1384 proc::IndexableLength::Dynamic => unreachable!(),
1385 }
1386
1387 write!(self.out, "]")?;
1388
1389 if let TypeInner::Array {
1390 base: next_base,
1391 size: next_size,
1392 ..
1393 } = module.types[base].inner
1394 {
1395 self.write_array_size(module, next_base, next_size)?;
1396 }
1397
1398 Ok(())
1399 }
1400
1401 fn write_struct(
1406 &mut self,
1407 module: &Module,
1408 handle: Handle<crate::Type>,
1409 members: &[crate::StructMember],
1410 span: u32,
1411 shader_stage: Option<(ShaderStage, Io)>,
1412 ) -> BackendResult {
1413 let struct_name = &self.names[&NameKey::Type(handle)];
1415 writeln!(self.out, "struct {struct_name} {{")?;
1416
1417 let mut last_offset = 0;
1418 for (index, member) in members.iter().enumerate() {
1419 if member.binding.is_none() && member.offset > last_offset {
1420 let padding = (member.offset - last_offset) / 4;
1424 for i in 0..padding {
1425 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426 }
1427 }
1428 let ty_inner = &module.types[member.ty].inner;
1429 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431 write!(self.out, "{}", back::INDENT)?;
1433
1434 match module.types[member.ty].inner {
1435 TypeInner::Array { base, size, .. } => {
1436 self.write_global_type(module, member.ty)?;
1439
1440 write!(
1442 self.out,
1443 " {}",
1444 &self.names[&NameKey::StructMember(handle, index as u32)]
1445 )?;
1446 self.write_array_size(module, base, size)?;
1448 }
1449 TypeInner::Matrix {
1452 rows,
1453 columns,
1454 scalar,
1455 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456 let vec_ty = TypeInner::Vector { size: rows, scalar };
1457 let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459 for i in 0..columns as u8 {
1460 if i != 0 {
1461 write!(self.out, "; ")?;
1462 }
1463 self.write_value_type(module, &vec_ty)?;
1464 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465 }
1466 }
1467 _ => {
1468 if let Some(ref binding) = member.binding {
1470 self.write_modifier(binding)?;
1471 }
1472
1473 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477 write!(self.out, "row_major ")?;
1478 }
1479
1480 self.write_type(module, member.ty)?;
1482 write!(
1483 self.out,
1484 " {}",
1485 &self.names[&NameKey::StructMember(handle, index as u32)]
1486 )?;
1487 }
1488 }
1489
1490 self.write_semantic(&member.binding, shader_stage)?;
1491 writeln!(self.out, ";")?;
1492 }
1493
1494 if members.last().unwrap().binding.is_none() && span > last_offset {
1496 let padding = (span - last_offset) / 4;
1497 for i in 0..padding {
1498 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499 }
1500 }
1501
1502 writeln!(self.out, "}};")?;
1503 Ok(())
1504 }
1505
1506 pub(super) fn write_global_type(
1511 &mut self,
1512 module: &Module,
1513 ty: Handle<crate::Type>,
1514 ) -> BackendResult {
1515 let matrix_data = get_inner_matrix_data(module, ty);
1516
1517 if let Some(MatrixType {
1520 columns,
1521 rows: crate::VectorSize::Bi,
1522 width,
1523 }) = matrix_data
1524 {
1525 write!(self.out, "__mat{}x2_f{}", columns as u8, width * 8)?;
1526 } else {
1527 if matrix_data.is_some() {
1531 write!(self.out, "row_major ")?;
1532 }
1533
1534 self.write_type(module, ty)?;
1535 }
1536
1537 Ok(())
1538 }
1539
1540 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545 let inner = &module.types[ty].inner;
1546 match *inner {
1547 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550 self.write_type(module, base)?
1551 }
1552 ref other => self.write_value_type(module, other)?,
1553 }
1554
1555 Ok(())
1556 }
1557
1558 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563 match *inner {
1564 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566 }
1567 TypeInner::Vector { size, scalar } => {
1568 write!(
1569 self.out,
1570 "{}{}",
1571 scalar.to_hlsl_str()?,
1572 common::vector_size_str(size)
1573 )?;
1574 }
1575 TypeInner::Matrix {
1576 columns,
1577 rows,
1578 scalar,
1579 } => {
1580 write!(
1585 self.out,
1586 "{}{}x{}",
1587 scalar.to_hlsl_str()?,
1588 common::vector_size_str(columns),
1589 common::vector_size_str(rows),
1590 )?;
1591 }
1592 TypeInner::Image {
1593 dim,
1594 arrayed,
1595 class,
1596 } => {
1597 self.write_image_type(dim, arrayed, class)?;
1598 }
1599 TypeInner::Sampler { comparison } => {
1600 let sampler = if comparison {
1601 "SamplerComparisonState"
1602 } else {
1603 "SamplerState"
1604 };
1605 write!(self.out, "{sampler}")?;
1606 }
1607 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611 self.write_array_size(module, base, size)?;
1612 }
1613 TypeInner::AccelerationStructure { .. } => {
1614 write!(self.out, "RaytracingAccelerationStructure")?;
1615 }
1616 TypeInner::RayQuery { .. } => {
1617 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619 }
1620 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621 }
1622
1623 Ok(())
1624 }
1625
/// Writes a complete HLSL function definition: signature, local variables and body.
///
/// Handles both regular functions and entry points. `header` is caller-provided text
/// written before the return type.
///
/// Task and mesh entry points are "nested": the body is emitted as an inner function
/// named via `self.namer` from `_{name}`, `header` is not written here, and
/// `write_nested_function_outer` is called afterwards. That call receives `header`, the
/// forwarded argument names and the local-invocation-index name.
///
/// `self.named_expressions` is cleared when the function is finished.
fn write_function(
    &mut self,
    module: &Module,
    name: &str,
    func: &crate::Function,
    func_ctx: &back::FunctionCtx<'_>,
    info: &valid::FunctionInfo,
    header: String,
) -> BackendResult {
    // Decide which expressions get baked into temporaries before emitting the body.
    self.update_expressions_to_bake(module, func, info);
    let ep = match func_ctx.ty {
        back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
        back::FunctionType::Function(_) => None,
    };

    // Task/mesh entry points are written as an inner function plus an outer wrapper.
    let nested = matches!(
        ep,
        Some(crate::EntryPoint {
            stage: ShaderStage::Task | ShaderStage::Mesh,
            ..
        })
    );
    if !nested {
        write!(self.out, "{header}")?;
    }

    if let Some(ref result) = func.result {
        // Array results are given a `typedef` alias (named from `ret_{name}`), and that
        // alias is used as the return type of a regular function.
        let array_return_type = match module.types[result.ty].inner {
            TypeInner::Array { base, size, .. } => {
                let array_return_type = self.namer.call(&format!("ret_{name}"));
                write!(self.out, "typedef ")?;
                self.write_type(module, result.ty)?;
                write!(self.out, " {array_return_type}")?;
                self.write_array_size(module, base, size)?;
                writeln!(self.out, ";")?;
                Some(array_return_type)
            }
            _ => None,
        };

        // Only an invariant `Position` result gets its modifier in front of the return type.
        if let Some(
            ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
        ) = result.binding
        {
            self.write_modifier(binding)?;
        }

        match func_ctx.ty {
            back::FunctionType::Function(_) => {
                if let Some(array_return_type) = array_return_type {
                    write!(self.out, "{array_return_type}")?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
            back::FunctionType::EntryPoint(index) => {
                // Entry points with a generated output struct return that struct's type.
                if let Some(ref ep_output) =
                    self.entry_point_io.get(&(index as usize)).unwrap().output
                {
                    write!(self.out, "{}", ep_output.ty_name)?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
        }
    } else {
        write!(self.out, "void")?;
    }

    let nested_name = if nested {
        self.namer.call(&format!("_{name}"))
    } else {
        name.to_string()
    };

    write!(self.out, " {nested_name}(")?;

    let need_workgroup_variables_initialization =
        self.need_workgroup_variables_initialization(func_ctx, module);

    // Returns "" the first time it is called and ", " on every later call.
    let mut any_args_written = false;
    let mut separator = || {
        if any_args_written {
            ", "
        } else {
            any_args_written = true;
            ""
        }
    };

    // Workgroup zero-initialisation and nested entry points both need the
    // local invocation index; it is added as an extra parameter below if absent.
    let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
    let mut local_invocation_index_name = None;
    // Argument names forwarded to `write_nested_function_outer` for nested entry points.
    let mut nested_wgsl_args: Vec<String> = Vec::new();
    let mut nested_task_payload_name: Option<String> = None;
    match func_ctx.ty {
        back::FunctionType::Function(handle) => {
            for (index, arg) in func.arguments.iter().enumerate() {
                write!(self.out, "{}", separator())?;
                self.write_function_argument(module, handle, arg, index)?;
            }
            // The first task-payload global this function only reads is passed in as an
            // extra `in` parameter and recorded in `function_task_payload_var`.
            for (var_handle, var) in module.global_variables.iter() {
                let uses = info[var_handle];
                if uses.contains(valid::GlobalUse::READ)
                    && !uses.contains(valid::GlobalUse::WRITE)
                    && var.space == crate::AddressSpace::TaskPayload
                {
                    self.function_task_payload_var.insert(handle, var_handle);
                    write!(self.out, "{}in ", separator())?;

                    self.write_type(module, var.ty)?;
                    let name = &self.names[&NameKey::GlobalVariable(var_handle)];
                    write!(self.out, " {name}")?;
                    break;
                }
            }
        }
        back::FunctionType::EntryPoint(ep_index) => {
            let ep = &module.entry_points[ep_index as usize];
            if let Some(ref ep_input) =
                self.entry_point_io.get(&(ep_index as usize)).unwrap().input
            {
                // Inputs were gathered into one generated struct parameter. The separator
                // is called only to mark that an argument has been written.
                write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                separator();
                nested_wgsl_args.push(ep_input.arg_name.clone());
            } else {
                let stage = ep.stage;
                for (index, arg) in func.arguments.iter().enumerate() {
                    write!(self.out, "{}", separator())?;
                    self.write_type(module, arg.ty)?;

                    let argument_name =
                        &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                    // Reuse a user-declared local invocation index instead of adding one.
                    if arg.binding
                        == Some(crate::Binding::BuiltIn(
                            crate::BuiltIn::LocalInvocationIndex,
                        ))
                    {
                        local_invocation_index_name = Some(argument_name.clone());
                    }

                    nested_wgsl_args.push(argument_name.clone());
                    write!(self.out, " {argument_name}")?;
                    if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }

                    self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                }
            }
            // Mesh entry points receive their task payload global as an `in` parameter.
            if ep.stage == ShaderStage::Mesh {
                if let Some(var_handle) = ep.task_payload {
                    let var = &module.global_variables[var_handle];
                    write!(self.out, "{}in ", separator())?;
                    self.write_type(module, var.ty)?;
                    let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                    write!(self.out, " {arg_name}")?;
                    nested_task_payload_name = Some(arg_name.clone());
                    if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }
                }
            }
            if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
                let name = self.namer.call("local_invocation_index");
                write!(self.out, "{}uint {name}", separator())?;
                write!(self.out, " : SV_GroupIndex")?;
                local_invocation_index_name = Some(name);
            }
        }
    }
    write!(self.out, ")")?;

    // Entry-point results carry an output semantic after the parameter list.
    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        let stage = module.entry_points[index as usize].stage;
        if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
            self.write_semantic(binding, Some((stage, Io::Output)))?;
        }
    }

    writeln!(self.out)?;
    writeln!(self.out, "{{")?;

    // Invocation 0 zero-initialises workgroup memory, followed by a workgroup barrier.
    // Nested entry points skip this here; the flag is passed to
    // `write_nested_function_outer` instead.
    if need_workgroup_variables_initialization && !nested {
        let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
            unreachable!();
        };
        writeln!(
            self.out,
            "{}if ({} == 0) {{",
            back::INDENT,
            local_invocation_index_name.as_ref().unwrap(),
        )?;
        self.write_workgroup_variables_initialization(
            func_ctx,
            module,
            module.entry_points[index as usize].stage,
        )?;

        writeln!(self.out, "{}}}", back::INDENT)?;
        self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
    }

    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        self.write_ep_arguments_initialization(module, func, index)?;
    }

    // Local variables are always initialised: from their init expression if present,
    // otherwise with a default value. Ray queries are the exception: they get no
    // initialiser, but do get a companion `uint` tracker variable set to 0.
    for (handle, local) in func.local_variables.iter() {
        write!(self.out, "{}", back::INDENT)?;

        self.write_type(module, local.ty)?;
        write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
        if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        let is_ray_query = match module.types[local.ty].inner {
            TypeInner::RayQuery { .. } => true,
            _ => {
                write!(self.out, " = ")?;
                if let Some(init) = local.init {
                    self.write_expr(module, init, func_ctx)?;
                } else {
                    self.write_default_init(module, local.ty)?;
                }
                false
            }
        };
        writeln!(self.out, ";")?;
        if is_ray_query {
            write!(self.out, "{}", back::INDENT)?;
            self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
            writeln!(
                self.out,
                " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
                self.names[&func_ctx.name_key(handle)]
            )?;
        }
    }

    if !func.local_variables.is_empty() {
        writeln!(self.out)?;
    }

    for sta in func.body.iter() {
        self.write_stmt(module, sta, func_ctx, back::Level(1))?;
    }

    writeln!(self.out, "}}")?;

    if nested {
        self.write_nested_function_outer(
            module,
            func_ctx,
            &header,
            name,
            need_workgroup_variables_initialization,
            &nested_name,
            ep.unwrap(),
            NestedEntryPointArgs {
                user_args: nested_wgsl_args,
                task_payload: nested_task_payload_name,
                local_invocation_index: local_invocation_index_name.unwrap(),
            },
        )?;
    }

    // Named expressions are per-function; reset for the next function.
    self.named_expressions.clear();

    Ok(())
}
1928
1929 fn write_function_argument(
1930 &mut self,
1931 module: &Module,
1932 handle: Handle<crate::Function>,
1933 arg: &crate::FunctionArgument,
1934 index: usize,
1935 ) -> BackendResult {
1936 if let TypeInner::Image {
1939 class: crate::ImageClass::External,
1940 ..
1941 } = module.types[arg.ty].inner
1942 {
1943 return self.write_function_external_texture_argument(module, handle, index);
1944 }
1945
1946 let arg_ty = match module.types[arg.ty].inner {
1948 TypeInner::Pointer { base, .. } => {
1950 write!(self.out, "inout ")?;
1952 base
1953 }
1954 _ => arg.ty,
1955 };
1956 self.write_type(module, arg_ty)?;
1957
1958 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960 write!(self.out, " {argument_name}")?;
1962 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963 self.write_array_size(module, base, size)?;
1964 }
1965
1966 Ok(())
1967 }
1968
1969 fn write_function_external_texture_argument(
1970 &mut self,
1971 module: &Module,
1972 handle: Handle<crate::Function>,
1973 index: usize,
1974 ) -> BackendResult {
1975 let plane_names = [0, 1, 2].map(|i| {
1976 &self.names[&NameKey::ExternalTextureFunctionArgument(
1977 handle,
1978 index as u32,
1979 ExternalTextureNameKey::Plane(i),
1980 )]
1981 });
1982 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983 handle,
1984 index as u32,
1985 ExternalTextureNameKey::Params,
1986 )];
1987 let params_ty_name =
1988 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989 write!(
1990 self.out,
1991 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992 plane_names[0], plane_names[1], plane_names[2],
1993 )?;
1994 Ok(())
1995 }
1996
1997 fn need_workgroup_variables_initialization(
1998 &mut self,
1999 func_ctx: &back::FunctionCtx,
2000 module: &Module,
2001 ) -> bool {
2002 self.options.zero_initialize_workgroup_memory
2003 && func_ctx.ty.is_compute_like_entry_point(module)
2004 && module.global_variables.iter().any(|(handle, var)| {
2005 !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006 })
2007 }
2008
2009 pub(super) fn write_workgroup_variables_initialization(
2010 &mut self,
2011 func_ctx: &back::FunctionCtx,
2012 module: &Module,
2013 stage: ShaderStage,
2014 ) -> BackendResult {
2015 let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016 let task_needs_zero =
2018 (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019 !func_ctx.info[handle].is_empty()
2020 && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021 });
2022
2023 for (handle, var) in vars {
2024 let name = &self.names[&NameKey::GlobalVariable(handle)];
2025 write!(self.out, "{}{} = ", back::Level(2), name)?;
2026 self.write_default_init(module, var.ty)?;
2027 writeln!(self.out, ";")?;
2028 }
2029 Ok(())
2030 }
2031
/// Writes a `Switch` statement at indentation `level`.
///
/// - When every case except the last is an empty fall-through, the switch is emitted as
///   `do { <last case body> } while(false);`. The selector is not written in that case.
/// - Otherwise a real `switch` is emitted. For a fall-through case with a non-empty
///   body, that body and the bodies of all following cases, up to and including the
///   first non-fall-through case, are copied inline, each in its own `{ }` scope.
///
/// `continue_ctx` may supply a flag variable on entry. On exit it tells this function
/// whether to emit an `if (flag) { continue; }` or `if (flag) { break; }` after the switch.
fn write_switch(
    &mut self,
    module: &Module,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
    selector: Handle<crate::Expression>,
    cases: &[crate::SwitchCase],
) -> BackendResult {
    let indent_level_1 = level.next();
    let indent_level_2 = indent_level_1.next();

    // Declare the control-flow forwarding flag, if the continue context asks for one.
    if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
        writeln!(self.out, "{level}bool {variable} = false;",)?;
    };

    // True when only the last case has a body (all earlier cases are empty fall-throughs),
    // including the case where there are no cases at all.
    let one_body = cases
        .iter()
        .rev()
        .skip(1)
        .all(|case| case.fall_through && case.body.is_empty());
    if one_body {
        writeln!(self.out, "{level}do {{")?;
        if let Some(case) = cases.last() {
            for sta in case.body.iter() {
                self.write_stmt(module, sta, func_ctx, indent_level_1)?;
            }
        }
        writeln!(self.out, "{level}}} while(false);")?;
    } else {
        write!(self.out, "{level}")?;
        write!(self.out, "switch(")?;
        self.write_expr(module, selector, func_ctx)?;
        writeln!(self.out, ") {{")?;

        for (i, case) in cases.iter().enumerate() {
            match case.value {
                crate::SwitchValue::I32(value) => {
                    write!(self.out, "{indent_level_1}case {value}:")?
                }
                crate::SwitchValue::U32(value) => {
                    write!(self.out, "{indent_level_1}case {value}u:")?
                }
                crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
            }

            // Empty fall-through cases are written as bare labels, without braces.
            let write_block_braces = !(case.fall_through && case.body.is_empty());
            if write_block_braces {
                writeln!(self.out, " {{")?;
            } else {
                writeln!(self.out)?;
            }

            if case.fall_through && !case.body.is_empty() {
                // Locate the first following case that does not fall through; the
                // `unwrap` assumes such a case exists.
                let curr_len = i + 1;
                let end_case_idx = curr_len
                    + cases
                        .iter()
                        .skip(curr_len)
                        .position(|case| !case.fall_through)
                        .unwrap();
                let indent_level_3 = indent_level_2.next();
                // Copy each body in range into its own scope. Named expressions added
                // while writing a copy are dropped afterwards.
                for case in &cases[i..=end_case_idx] {
                    writeln!(self.out, "{indent_level_2}{{")?;
                    let prev_len = self.named_expressions.len();
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                    }
                    self.named_expressions.truncate(prev_len);
                    writeln!(self.out, "{indent_level_2}}}")?;
                }

                // Add a `break` unless the final copied body already ends in a terminator.
                let last_case = &cases[end_case_idx];
                if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            } else {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                }
                if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            }

            if write_block_braces {
                writeln!(self.out, "{indent_level_1}}}")?;
            }
        }

        writeln!(self.out, "{level}}}")?;
    }

    // Re-issue a forwarded `continue`/`break` after the switch, guarded by the flag.
    use back::continue_forward::ExitControlFlow;
    let op = match self.continue_ctx.exit_switch() {
        ExitControlFlow::None => None,
        ExitControlFlow::Continue { variable } => Some(("continue", variable)),
        ExitControlFlow::Break { variable } => Some(("break", variable)),
    };
    if let Some((control_flow, variable)) = op {
        writeln!(self.out, "{level}if ({variable}) {{")?;
        writeln!(self.out, "{indent_level_1}{control_flow};")?;
        writeln!(self.out, "{level}}}")?;
    }

    Ok(())
}
2176
2177 fn write_index(
2178 &mut self,
2179 module: &Module,
2180 index: Index,
2181 func_ctx: &back::FunctionCtx<'_>,
2182 ) -> BackendResult {
2183 match index {
2184 Index::Static(index) => {
2185 write!(self.out, "{index}")?;
2186 }
2187 Index::Expression(index) => {
2188 self.write_expr(module, index, func_ctx)?;
2189 }
2190 }
2191 Ok(())
2192 }
2193
2194 fn write_stmt(
2199 &mut self,
2200 module: &Module,
2201 stmt: &crate::Statement,
2202 func_ctx: &back::FunctionCtx<'_>,
2203 level: back::Level,
2204 ) -> BackendResult {
2205 use crate::Statement;
2206
2207 match *stmt {
2208 Statement::Emit(ref range) => {
2209 for handle in range.clone() {
2210 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211 let expr_name = if ptr_class.is_some() {
2212 None
2216 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217 Some(self.namer.call(name))
2222 } else if self.need_bake_expressions.contains(&handle) {
2223 Some(Baked(handle).to_string())
2224 } else {
2225 None
2226 };
2227
2228 if let Some(name) = expr_name {
2229 write!(self.out, "{level}")?;
2230 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231 }
2232 }
2233 }
2234 Statement::Block(ref block) => {
2236 write!(self.out, "{level}")?;
2237 writeln!(self.out, "{{")?;
2238 for sta in block.iter() {
2239 self.write_stmt(module, sta, func_ctx, level.next())?
2241 }
2242 writeln!(self.out, "{level}}}")?
2243 }
2244 Statement::If {
2246 condition,
2247 ref accept,
2248 ref reject,
2249 } => {
2250 write!(self.out, "{level}")?;
2251 write!(self.out, "if (")?;
2252 self.write_expr(module, condition, func_ctx)?;
2253 writeln!(self.out, ") {{")?;
2254
2255 let l2 = level.next();
2256 for sta in accept {
2257 self.write_stmt(module, sta, func_ctx, l2)?;
2259 }
2260
2261 if !reject.is_empty() {
2264 writeln!(self.out, "{level}}} else {{")?;
2265
2266 for sta in reject {
2267 self.write_stmt(module, sta, func_ctx, l2)?;
2269 }
2270 }
2271
2272 writeln!(self.out, "{level}}}")?
2273 }
2274 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276 Statement::Return { value: None } => {
2277 writeln!(self.out, "{level}return;")?;
2278 }
2279 Statement::Return { value: Some(expr) } => {
2280 let base_ty_res = &func_ctx.info[expr].ty;
2281 let mut resolved = base_ty_res.inner_with(&module.types);
2282 if let TypeInner::Pointer { base, space: _ } = *resolved {
2283 resolved = &module.types[base].inner;
2284 }
2285
2286 if let TypeInner::Struct { .. } = *resolved {
2287 let ty = base_ty_res.handle().unwrap();
2289 let struct_name = &self.names[&NameKey::Type(ty)];
2290 let variable_name = self.namer.call(&struct_name.to_lowercase());
2291 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292 self.write_expr(module, expr, func_ctx)?;
2293 writeln!(self.out, ";")?;
2294
2295 let ep_output = match func_ctx.ty {
2297 back::FunctionType::Function(_) => None,
2298 back::FunctionType::EntryPoint(index) => self
2299 .entry_point_io
2300 .get(&(index as usize))
2301 .unwrap()
2302 .output
2303 .as_ref(),
2304 };
2305 let final_name = match ep_output {
2306 Some(ep_output) => {
2307 let final_name = self.namer.call(&variable_name);
2308 write!(
2309 self.out,
2310 "{}const {} {} = {{ ",
2311 level, ep_output.ty_name, final_name,
2312 )?;
2313 for (index, m) in ep_output.members.iter().enumerate() {
2314 if index != 0 {
2315 write!(self.out, ", ")?;
2316 }
2317 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318 write!(self.out, "{variable_name}.{member_name}")?;
2319 }
2320 writeln!(self.out, " }};")?;
2321 final_name
2322 }
2323 None => variable_name,
2324 };
2325 writeln!(self.out, "{level}return {final_name};")?;
2326 } else {
2327 write!(self.out, "{level}return ")?;
2328 self.write_expr(module, expr, func_ctx)?;
2329 writeln!(self.out, ";")?
2330 }
2331 }
2332 Statement::Store { pointer, value } => {
2333 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334 if ty_inner.is_atomic_pointer(&module.types) {
2335 let pointer_space = ty_inner.pointer_space().unwrap();
2336 let dummy = self.namer.call("dummy");
2337 write!(self.out, "{level}{{ ")?;
2338 if let TypeInner::Pointer { base, .. } = *ty_inner {
2339 self.write_value_type(module, &module.types[base].inner)?;
2340 }
2341 write!(self.out, " {dummy} = 0; ")?;
2342 match pointer_space {
2343 crate::AddressSpace::WorkGroup => {
2344 write!(self.out, "InterlockedExchange(")?;
2345 self.write_expr(module, pointer, func_ctx)?;
2346 }
2347 crate::AddressSpace::Storage { .. } => {
2348 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350 write!(self.out, "{var_name}.InterlockedExchange(")?;
2351 let chain = mem::take(&mut self.temp_access_chain);
2352 self.write_storage_address(module, &chain, func_ctx)?;
2353 self.temp_access_chain = chain;
2354 }
2355 _ => unreachable!(),
2356 }
2357 write!(self.out, ", ")?;
2358 self.write_expr(module, value, func_ctx)?;
2359 writeln!(self.out, ", {dummy}); }}")?;
2360 } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362 self.write_storage_store(
2363 module,
2364 var_handle,
2365 StoreValue::Expression(value),
2366 func_ctx,
2367 level,
2368 None,
2369 )?;
2370 } else {
2371 enum MatrixAccess {
2377 Direct {
2378 base: Handle<crate::Expression>,
2379 index: u32,
2380 },
2381 Struct {
2382 columns: crate::VectorSize,
2383 width: u8,
2384 base: Handle<crate::Expression>,
2385 },
2386 }
2387
2388 let get_members = |expr: Handle<crate::Expression>| {
2389 let resolved = func_ctx.resolve_type(expr, &module.types);
2390 match *resolved {
2391 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2392 TypeInner::Struct { ref members, .. } => Some(members),
2393 _ => None,
2394 },
2395 _ => None,
2396 }
2397 };
2398
2399 write!(self.out, "{level}")?;
2400
2401 let matrix_access_on_lhs =
2402 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2403 |(matrix_expr, vector, scalar)| match (
2404 func_ctx.resolve_type(matrix_expr, &module.types),
2405 &func_ctx.expressions[matrix_expr],
2406 ) {
2407 (
2408 &TypeInner::Pointer { base: ty, .. },
2409 &crate::Expression::AccessIndex { base, index },
2410 ) if matches!(
2411 module.types[ty].inner,
2412 TypeInner::Matrix {
2413 rows: crate::VectorSize::Bi,
2414 ..
2415 }
2416 ) && get_members(base)
2417 .map(|members| members[index as usize].binding.is_none())
2418 == Some(true) =>
2419 {
2420 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2421 }
2422 _ => {
2423 if let Some(MatrixType {
2424 columns,
2425 rows: crate::VectorSize::Bi,
2426 width,
2427 }) = get_inner_matrix_of_struct_array_member(
2428 module,
2429 matrix_expr,
2430 func_ctx,
2431 true,
2432 ) {
2433 Some((
2434 MatrixAccess::Struct {
2435 columns,
2436 width,
2437 base: matrix_expr,
2438 },
2439 vector,
2440 scalar,
2441 ))
2442 } else {
2443 None
2444 }
2445 }
2446 },
2447 );
2448
2449 match matrix_access_on_lhs {
2450 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2451 let base_ty_res = &func_ctx.info[base].ty;
2452 let resolved = base_ty_res.inner_with(&module.types);
2453 let ty = match *resolved {
2454 TypeInner::Pointer { base, .. } => base,
2455 _ => base_ty_res.handle().unwrap(),
2456 };
2457
2458 if let Some(Index::Static(vec_index)) = vector {
2459 self.write_expr(module, base, func_ctx)?;
2460 write!(
2461 self.out,
2462 ".{}_{}",
2463 &self.names[&NameKey::StructMember(ty, index)],
2464 vec_index
2465 )?;
2466
2467 if let Some(scalar_index) = scalar {
2468 write!(self.out, "[")?;
2469 self.write_index(module, scalar_index, func_ctx)?;
2470 write!(self.out, "]")?;
2471 }
2472
2473 write!(self.out, " = ")?;
2474 self.write_expr(module, value, func_ctx)?;
2475 writeln!(self.out, ";")?;
2476 } else {
2477 let access = WrappedStructMatrixAccess { ty, index };
2478 match (&vector, &scalar) {
2479 (&Some(_), &Some(_)) => {
2480 self.write_wrapped_struct_matrix_set_scalar_function_name(
2481 access,
2482 )?;
2483 }
2484 (&Some(_), &None) => {
2485 self.write_wrapped_struct_matrix_set_vec_function_name(
2486 access,
2487 )?;
2488 }
2489 (&None, _) => {
2490 self.write_wrapped_struct_matrix_set_function_name(access)?;
2491 }
2492 }
2493
2494 write!(self.out, "(")?;
2495 self.write_expr(module, base, func_ctx)?;
2496 write!(self.out, ", ")?;
2497 self.write_expr(module, value, func_ctx)?;
2498
2499 if let Some(Index::Expression(vec_index)) = vector {
2500 write!(self.out, ", ")?;
2501 self.write_expr(module, vec_index, func_ctx)?;
2502
2503 if let Some(scalar_index) = scalar {
2504 write!(self.out, ", ")?;
2505 self.write_index(module, scalar_index, func_ctx)?;
2506 }
2507 }
2508 writeln!(self.out, ");")?;
2509 }
2510 }
2511 Some((
2512 MatrixAccess::Struct {
2513 columns,
2514 width,
2515 base,
2516 },
2517 Some(Index::Expression(vec_index)),
2518 scalar,
2519 )) => {
2520 if scalar.is_some() {
2524 write!(
2525 self.out,
2526 "__set_el_of_mat{}x2_f{}",
2527 columns as u8,
2528 width * 8
2529 )?;
2530 } else {
2531 write!(
2532 self.out,
2533 "__set_col_of_mat{}x2_f{}",
2534 columns as u8,
2535 width * 8
2536 )?;
2537 }
2538 write!(self.out, "(")?;
2539 self.write_expr(module, base, func_ctx)?;
2540 write!(self.out, ", ")?;
2541 self.write_expr(module, vec_index, func_ctx)?;
2542
2543 if let Some(scalar_index) = scalar {
2544 write!(self.out, ", ")?;
2545 self.write_index(module, scalar_index, func_ctx)?;
2546 }
2547
2548 write!(self.out, ", ")?;
2549 self.write_expr(module, value, func_ctx)?;
2550
2551 writeln!(self.out, ");")?;
2552 }
2553 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2554 | Some((MatrixAccess::Struct { .. }, None, _))
2555 | None => {
2556 self.write_expr(module, pointer, func_ctx)?;
2557 write!(self.out, " = ")?;
2558
2559 if let Some(MatrixType {
2564 columns,
2565 rows: crate::VectorSize::Bi,
2566 width,
2567 }) = get_inner_matrix_of_struct_array_member(
2568 module, pointer, func_ctx, false,
2569 ) {
2570 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2571 if let TypeInner::Pointer { base, .. } = *resolved {
2572 resolved = &module.types[base].inner;
2573 }
2574
2575 write!(self.out, "(__mat{}x2_f{}", columns as u8, width * 8)?;
2576 if let TypeInner::Array { base, size, .. } = *resolved {
2577 self.write_array_size(module, base, size)?;
2578 }
2579 write!(self.out, ")")?;
2580 }
2581
2582 self.write_expr(module, value, func_ctx)?;
2583 writeln!(self.out, ";")?
2584 }
2585 }
2586 }
2587 }
2588 Statement::Loop {
2589 ref body,
2590 ref continuing,
2591 break_if,
2592 } => {
2593 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2594 let gate_name = (!continuing.is_empty() || break_if.is_some())
2595 .then(|| self.namer.call("loop_init"));
2596
2597 if let Some((ref decl, _)) = force_loop_bound_statements {
2598 writeln!(self.out, "{decl}")?;
2599 }
2600 if let Some(ref gate_name) = gate_name {
2601 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2602 }
2603
2604 self.continue_ctx.enter_loop();
2605 writeln!(self.out, "{level}while(true) {{")?;
2606 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2607 writeln!(self.out, "{break_and_inc}")?;
2608 }
2609 let l2 = level.next();
2610 if let Some(gate_name) = gate_name {
2611 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2612 let l3 = l2.next();
2613 for sta in continuing.iter() {
2614 self.write_stmt(module, sta, func_ctx, l3)?;
2615 }
2616 if let Some(condition) = break_if {
2617 write!(self.out, "{l3}if (")?;
2618 self.write_expr(module, condition, func_ctx)?;
2619 writeln!(self.out, ") {{")?;
2620 writeln!(self.out, "{}break;", l3.next())?;
2621 writeln!(self.out, "{l3}}}")?;
2622 }
2623 writeln!(self.out, "{l2}}}")?;
2624 writeln!(self.out, "{l2}{gate_name} = false;")?;
2625 }
2626
2627 for sta in body.iter() {
2628 self.write_stmt(module, sta, func_ctx, l2)?;
2629 }
2630
2631 writeln!(self.out, "{level}}}")?;
2632 self.continue_ctx.exit_loop();
2633 }
2634 Statement::Break => writeln!(self.out, "{level}break;")?,
2635 Statement::Continue => {
2636 if let Some(variable) = self.continue_ctx.continue_encountered() {
2637 writeln!(self.out, "{level}{variable} = true;")?;
2638 writeln!(self.out, "{level}break;")?
2639 } else {
2640 writeln!(self.out, "{level}continue;")?
2641 }
2642 }
2643 Statement::ControlBarrier(barrier) => {
2644 self.write_control_barrier(barrier, level)?;
2645 }
2646 Statement::MemoryBarrier(barrier) => {
2647 self.write_memory_barrier(barrier, level)?;
2648 }
2649 Statement::ImageStore {
2650 image,
2651 coordinate,
2652 array_index,
2653 value,
2654 } => {
2655 write!(self.out, "{level}")?;
2656 self.write_expr(module, image, func_ctx)?;
2657
2658 write!(self.out, "[")?;
2659 if let Some(index) = array_index {
2660 write!(self.out, "int3(")?;
2662 self.write_expr(module, coordinate, func_ctx)?;
2663 write!(self.out, ", ")?;
2664 self.write_expr(module, index, func_ctx)?;
2665 write!(self.out, ")")?;
2666 } else {
2667 self.write_expr(module, coordinate, func_ctx)?;
2668 }
2669 write!(self.out, "]")?;
2670
2671 write!(self.out, " = ")?;
2672 self.write_expr(module, value, func_ctx)?;
2673 writeln!(self.out, ";")?;
2674 }
2675 Statement::Call {
2676 function,
2677 ref arguments,
2678 result,
2679 } => {
2680 write!(self.out, "{level}")?;
2681
2682 if let Some(expr) = result {
2683 write!(self.out, "const ")?;
2684 let name = Baked(expr).to_string();
2685 let expr_ty = &func_ctx.info[expr].ty;
2686 let ty_inner = match *expr_ty {
2687 proc::TypeResolution::Handle(handle) => {
2688 self.write_type(module, handle)?;
2689 &module.types[handle].inner
2690 }
2691 proc::TypeResolution::Value(ref value) => {
2692 self.write_value_type(module, value)?;
2693 value
2694 }
2695 };
2696 write!(self.out, " {name}")?;
2697 if let TypeInner::Array { base, size, .. } = *ty_inner {
2698 self.write_array_size(module, base, size)?;
2699 }
2700 write!(self.out, " = ")?;
2701 self.named_expressions.insert(expr, name);
2702 }
2703 let func_name = &self.names[&NameKey::Function(function)];
2704 write!(self.out, "{func_name}(")?;
2705 let mut any_args_written = false;
2706 let mut separator = || {
2707 if any_args_written {
2708 ", "
2709 } else {
2710 any_args_written = true;
2711 ""
2712 }
2713 };
2714 for argument in arguments {
2715 write!(self.out, "{}", separator())?;
2716 self.write_expr(module, *argument, func_ctx)?;
2717 }
2718 if let Some(&var) = self.function_task_payload_var.get(&function) {
2719 let name = &self.names[&NameKey::GlobalVariable(var)];
2720 write!(self.out, "{}{name}", separator())?;
2722 }
2723 writeln!(self.out, ");")?;
2724 }
2725 Statement::Atomic {
2726 pointer,
2727 ref fun,
2728 value,
2729 result,
2730 } => {
2731 write!(self.out, "{level}")?;
2732 let res_var_info = if let Some(res_handle) = result {
2733 let name = Baked(res_handle).to_string();
2734 match func_ctx.info[res_handle].ty {
2735 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2736 proc::TypeResolution::Value(ref value) => {
2737 self.write_value_type(module, value)?
2738 }
2739 };
2740 write!(self.out, " {name}; ")?;
2741 self.named_expressions.insert(res_handle, name.clone());
2742 Some((res_handle, name))
2743 } else {
2744 None
2745 };
2746 let pointer_space = func_ctx
2747 .resolve_type(pointer, &module.types)
2748 .pointer_space()
2749 .unwrap();
2750 let fun_str = fun.to_hlsl_suffix();
2751 let compare_expr = match *fun {
2752 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2753 _ => None,
2754 };
2755 match pointer_space {
2756 crate::AddressSpace::WorkGroup => {
2757 write!(self.out, "Interlocked{fun_str}(")?;
2758 self.write_expr(module, pointer, func_ctx)?;
2759 self.emit_hlsl_atomic_tail(
2760 module,
2761 func_ctx,
2762 fun,
2763 compare_expr,
2764 value,
2765 &res_var_info,
2766 )?;
2767 }
2768 crate::AddressSpace::Storage { .. } => {
2769 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2770 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2771 let width = match func_ctx.resolve_type(value, &module.types) {
2772 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2773 _ => "",
2774 };
2775 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2776 let chain = mem::take(&mut self.temp_access_chain);
2777 self.write_storage_address(module, &chain, func_ctx)?;
2778 self.temp_access_chain = chain;
2779 self.emit_hlsl_atomic_tail(
2780 module,
2781 func_ctx,
2782 fun,
2783 compare_expr,
2784 value,
2785 &res_var_info,
2786 )?;
2787 }
2788 ref other => {
2789 return Err(Error::Custom(format!(
2790 "invalid address space {other:?} for atomic statement"
2791 )))
2792 }
2793 }
2794 if let Some(cmp) = compare_expr {
2795 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2796 write!(
2797 self.out,
2798 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2799 )?;
2800 self.write_expr(module, cmp, func_ctx)?;
2801 writeln!(self.out, ");")?;
2802 }
2803 }
2804 }
2805 Statement::ImageAtomic {
2806 image,
2807 coordinate,
2808 array_index,
2809 fun,
2810 value,
2811 } => {
2812 write!(self.out, "{level}")?;
2813
2814 let fun_str = fun.to_hlsl_suffix();
2815 write!(self.out, "Interlocked{fun_str}(")?;
2816 self.write_expr(module, image, func_ctx)?;
2817 write!(self.out, "[")?;
2818 self.write_texture_coordinates(
2819 "int",
2820 coordinate,
2821 array_index,
2822 None,
2823 module,
2824 func_ctx,
2825 )?;
2826 write!(self.out, "],")?;
2827
2828 self.write_expr(module, value, func_ctx)?;
2829 writeln!(self.out, ");")?;
2830 }
2831 Statement::WorkGroupUniformLoad { pointer, result } => {
2832 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2833 write!(self.out, "{level}")?;
2834 let name = Baked(result).to_string();
2835 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2836
2837 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2838 }
2839 Statement::Switch {
2840 selector,
2841 ref cases,
2842 } => {
2843 self.write_switch(module, func_ctx, level, selector, cases)?;
2844 }
2845 Statement::RayQuery { query, ref fun } => {
2846 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2858 else {
2859 unreachable!()
2860 };
2861
2862 let tracker_expr_name = format!(
2863 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2864 self.names[&func_ctx.name_key(query_var)]
2865 );
2866
2867 match *fun {
2868 RayQueryFunction::Initialize {
2869 acceleration_structure,
2870 descriptor,
2871 } => {
2872 self.write_initialize_function(
2873 module,
2874 level,
2875 query,
2876 acceleration_structure,
2877 descriptor,
2878 &tracker_expr_name,
2879 func_ctx,
2880 )?;
2881 }
2882 RayQueryFunction::Proceed { result } => {
2883 self.write_proceed(
2884 module,
2885 level,
2886 query,
2887 result,
2888 &tracker_expr_name,
2889 func_ctx,
2890 )?;
2891 }
2892 RayQueryFunction::GenerateIntersection { hit_t } => {
2893 self.write_generate_intersection(
2894 module,
2895 level,
2896 query,
2897 hit_t,
2898 &tracker_expr_name,
2899 func_ctx,
2900 )?;
2901 }
2902 RayQueryFunction::ConfirmIntersection => {
2903 self.write_confirm_intersection(
2904 module,
2905 level,
2906 query,
2907 &tracker_expr_name,
2908 func_ctx,
2909 )?;
2910 }
2911 RayQueryFunction::Terminate => {
2912 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2913 }
2914 }
2915 }
2916 Statement::SubgroupBallot { result, predicate } => {
2917 write!(self.out, "{level}")?;
2918 let name = Baked(result).to_string();
2919 write!(self.out, "const uint4 {name} = ")?;
2920 self.named_expressions.insert(result, name);
2921
2922 write!(self.out, "WaveActiveBallot(")?;
2923 match predicate {
2924 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2925 None => write!(self.out, "true")?,
2926 }
2927 writeln!(self.out, ");")?;
2928 }
2929 Statement::SubgroupCollectiveOperation {
2930 op,
2931 collective_op,
2932 argument,
2933 result,
2934 } => {
2935 write!(self.out, "{level}")?;
2936 write!(self.out, "const ")?;
2937 let name = Baked(result).to_string();
2938 match func_ctx.info[result].ty {
2939 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2940 proc::TypeResolution::Value(ref value) => {
2941 self.write_value_type(module, value)?
2942 }
2943 };
2944 write!(self.out, " {name} = ")?;
2945 self.named_expressions.insert(result, name);
2946
2947 match (collective_op, op) {
2948 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2949 write!(self.out, "WaveActiveAllTrue(")?
2950 }
2951 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2952 write!(self.out, "WaveActiveAnyTrue(")?
2953 }
2954 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2955 write!(self.out, "WaveActiveSum(")?
2956 }
2957 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2958 write!(self.out, "WaveActiveProduct(")?
2959 }
2960 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2961 write!(self.out, "WaveActiveMax(")?
2962 }
2963 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2964 write!(self.out, "WaveActiveMin(")?
2965 }
2966 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2967 write!(self.out, "WaveActiveBitAnd(")?
2968 }
2969 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2970 write!(self.out, "WaveActiveBitOr(")?
2971 }
2972 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2973 write!(self.out, "WaveActiveBitXor(")?
2974 }
2975 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2976 write!(self.out, "WavePrefixSum(")?
2977 }
2978 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2979 write!(self.out, "WavePrefixProduct(")?
2980 }
2981 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2982 self.write_expr(module, argument, func_ctx)?;
2983 write!(self.out, " + WavePrefixSum(")?;
2984 }
2985 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2986 self.write_expr(module, argument, func_ctx)?;
2987 write!(self.out, " * WavePrefixProduct(")?;
2988 }
2989 _ => unimplemented!(),
2990 }
2991 self.write_expr(module, argument, func_ctx)?;
2992 writeln!(self.out, ");")?;
2993 }
2994 Statement::SubgroupGather {
2995 mode,
2996 argument,
2997 result,
2998 } => {
2999 write!(self.out, "{level}")?;
3000 write!(self.out, "const ")?;
3001 let name = Baked(result).to_string();
3002 match func_ctx.info[result].ty {
3003 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
3004 proc::TypeResolution::Value(ref value) => {
3005 self.write_value_type(module, value)?
3006 }
3007 };
3008 write!(self.out, " {name} = ")?;
3009 self.named_expressions.insert(result, name);
3010 match mode {
3011 crate::GatherMode::BroadcastFirst => {
3012 write!(self.out, "WaveReadLaneFirst(")?;
3013 self.write_expr(module, argument, func_ctx)?;
3014 }
3015 crate::GatherMode::QuadBroadcast(index) => {
3016 write!(self.out, "QuadReadLaneAt(")?;
3017 self.write_expr(module, argument, func_ctx)?;
3018 write!(self.out, ", ")?;
3019 self.write_expr(module, index, func_ctx)?;
3020 }
3021 crate::GatherMode::QuadSwap(direction) => {
3022 match direction {
3023 crate::Direction::X => {
3024 write!(self.out, "QuadReadAcrossX(")?;
3025 }
3026 crate::Direction::Y => {
3027 write!(self.out, "QuadReadAcrossY(")?;
3028 }
3029 crate::Direction::Diagonal => {
3030 write!(self.out, "QuadReadAcrossDiagonal(")?;
3031 }
3032 }
3033 self.write_expr(module, argument, func_ctx)?;
3034 }
3035 _ => {
3036 write!(self.out, "WaveReadLaneAt(")?;
3037 self.write_expr(module, argument, func_ctx)?;
3038 write!(self.out, ", ")?;
3039 match mode {
3040 crate::GatherMode::BroadcastFirst => unreachable!(),
3041 crate::GatherMode::Broadcast(index)
3042 | crate::GatherMode::Shuffle(index) => {
3043 self.write_expr(module, index, func_ctx)?;
3044 }
3045 crate::GatherMode::ShuffleDown(index) => {
3046 write!(self.out, "WaveGetLaneIndex() + ")?;
3047 self.write_expr(module, index, func_ctx)?;
3048 }
3049 crate::GatherMode::ShuffleUp(index) => {
3050 write!(self.out, "WaveGetLaneIndex() - ")?;
3051 self.write_expr(module, index, func_ctx)?;
3052 }
3053 crate::GatherMode::ShuffleXor(index) => {
3054 write!(self.out, "WaveGetLaneIndex() ^ ")?;
3055 self.write_expr(module, index, func_ctx)?;
3056 }
3057 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3058 crate::GatherMode::QuadSwap(_) => unreachable!(),
3059 }
3060 }
3061 }
3062 writeln!(self.out, ");")?;
3063 }
3064 Statement::CooperativeStore { .. } => unimplemented!(),
3065 Statement::RayPipelineFunction(_) => unreachable!(),
3066 }
3067
3068 Ok(())
3069 }
3070
3071 #[allow(clippy::too_many_arguments)]
3072 fn write_math_expression(
3073 &mut self,
3074 module: &Module,
3075 fun: crate::MathFunction,
3076 arg: Handle<crate::Expression>,
3077 arg1: Option<Handle<crate::Expression>>,
3078 arg2: Option<Handle<crate::Expression>>,
3079 arg3: Option<Handle<crate::Expression>>,
3080 func_ctx: &back::FunctionCtx<'_>,
3081 ) -> BackendResult {
3082 use crate::MathFunction as Mf;
3083
// How each `MathFunction` is lowered to HLSL text. The `match fun` that
// follows this declaration turns every variant into emitted source.
enum Function {
    // Plain intrinsic call, emitted as `name(arg[, arg1[, arg2[, arg3]]])`.
    Regular(&'static str),
    // Intrinsic that lacks an `i32` overload. For an `i32` argument it is
    // emitted as `asint(name(asuint(arg)))`; otherwise as `name(arg)`.
    MissingIntOverload(&'static str),
    // Intrinsic whose result must be reinterpreted. For an `i32` argument it
    // is emitted as `asint(name(arg))`; otherwise as `name(arg)`.
    MissingIntReturnType(&'static str),

    // Inverse hyperbolic sine/cosine, expanded as
    // `log(x + sqrt(x * x + 1.0))` (is_sin) or `log(x + sqrt(x * x - 1.0))`.
    Asincosh { is_sin: bool },
    // Inverse hyperbolic tangent, expanded as `0.5 * log((1.0 + x) / (1.0 - x))`.
    Atanh,

    // Float-vector packing into a single `uint`. `Pack2x16float` uses
    // `f32tof16` per component; the norm forms clamp, scale, round and shift.
    Pack2x16float,
    Pack2x16snorm,
    Pack2x16unorm,
    Pack4x8snorm,
    Pack4x8unorm,
    // Integer-vector packing, one byte per component (`& 0xFF` then shifted).
    // The `Clamp` forms clamp each component to the 8-bit range first.
    Pack4xI8,
    Pack4xU8,
    Pack4xI8Clamp,
    Pack4xU8Clamp,

    // Unpacking a `uint` into a float vector via `f16tof32` or shift/scale.
    Unpack2x16float,
    Unpack2x16snorm,
    Unpack2x16unorm,
    Unpack4x8snorm,
    Unpack4x8unorm,
    // Unpacking a `uint` into `int4`/`uint4` by shifting each byte into place
    // and applying `<< 24 >> 24`.
    Unpack4xI8,
    Unpack4xU8,

    // Packed 4x8-bit dot products. On shader model 6.4+ these become
    // `dot4add_{i,u}8packed(a, b, 0)`; otherwise a `dot` of unpacked vectors.
    Dot4I8Packed,
    Dot4U8Packed,

    // Round-trip through half precision: `f16tof32(f32tof16(x))`.
    QuantizeToF16,

    // Bit counting built on `firstbitlow`/`firstbithigh`; see the match below
    // for the width clamping and signed handling.
    CountTrailingZeros,
    CountLeadingZeros,
}
3112
3113 let fun = match fun {
3114 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3116 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3117 _ => Function::Regular("abs"),
3118 },
3119 Mf::Min => Function::Regular("min"),
3120 Mf::Max => Function::Regular("max"),
3121 Mf::Clamp => Function::Regular("clamp"),
3122 Mf::Saturate => Function::Regular("saturate"),
3123 Mf::Cos => Function::Regular("cos"),
3125 Mf::Cosh => Function::Regular("cosh"),
3126 Mf::Sin => Function::Regular("sin"),
3127 Mf::Sinh => Function::Regular("sinh"),
3128 Mf::Tan => Function::Regular("tan"),
3129 Mf::Tanh => Function::Regular("tanh"),
3130 Mf::Acos => Function::Regular("acos"),
3131 Mf::Asin => Function::Regular("asin"),
3132 Mf::Atan => Function::Regular("atan"),
3133 Mf::Atan2 => Function::Regular("atan2"),
3134 Mf::Asinh => Function::Asincosh { is_sin: true },
3135 Mf::Acosh => Function::Asincosh { is_sin: false },
3136 Mf::Atanh => Function::Atanh,
3137 Mf::Radians => Function::Regular("radians"),
3138 Mf::Degrees => Function::Regular("degrees"),
3139 Mf::Ceil => Function::Regular("ceil"),
3141 Mf::Floor => Function::Regular("floor"),
3142 Mf::Round => Function::Regular("round"),
3143 Mf::Fract => Function::Regular("frac"),
3144 Mf::Trunc => Function::Regular("trunc"),
3145 Mf::Modf => Function::Regular(MODF_FUNCTION),
3146 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3147 Mf::Ldexp => Function::Regular("ldexp"),
3148 Mf::Exp => Function::Regular("exp"),
3150 Mf::Exp2 => Function::Regular("exp2"),
3151 Mf::Log => Function::Regular("log"),
3152 Mf::Log2 => Function::Regular("log2"),
3153 Mf::Pow => Function::Regular("pow"),
3154 Mf::Dot => Function::Regular("dot"),
3156 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3157 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3158 Mf::Cross => Function::Regular("cross"),
3160 Mf::Distance => Function::Regular("distance"),
3161 Mf::Length => Function::Regular("length"),
3162 Mf::Normalize => Function::Regular("normalize"),
3163 Mf::FaceForward => Function::Regular("faceforward"),
3164 Mf::Reflect => Function::Regular("reflect"),
3165 Mf::Refract => Function::Regular("refract"),
3166 Mf::Sign => Function::Regular("sign"),
3168 Mf::Fma => Function::Regular("mad"),
3169 Mf::Mix => Function::Regular("lerp"),
3170 Mf::Step => Function::Regular("step"),
3171 Mf::SmoothStep => Function::Regular("smoothstep"),
3172 Mf::Sqrt => Function::Regular("sqrt"),
3173 Mf::InverseSqrt => Function::Regular("rsqrt"),
3174 Mf::Transpose => Function::Regular("transpose"),
3176 Mf::Determinant => Function::Regular("determinant"),
3177 Mf::QuantizeToF16 => Function::QuantizeToF16,
3178 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3180 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3181 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3182 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3183 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3184 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3185 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3186 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3187 Mf::Pack2x16float => Function::Pack2x16float,
3189 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3190 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3191 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3192 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3193 Mf::Pack4xI8 => Function::Pack4xI8,
3194 Mf::Pack4xU8 => Function::Pack4xU8,
3195 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3196 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3197 Mf::Unpack2x16float => Function::Unpack2x16float,
3199 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3200 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3201 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3202 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3203 Mf::Unpack4xI8 => Function::Unpack4xI8,
3204 Mf::Unpack4xU8 => Function::Unpack4xU8,
3205 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3206 };
3207
3208 match fun {
3209 Function::Asincosh { is_sin } => {
3210 write!(self.out, "log(")?;
3211 self.write_expr(module, arg, func_ctx)?;
3212 write!(self.out, " + sqrt(")?;
3213 self.write_expr(module, arg, func_ctx)?;
3214 write!(self.out, " * ")?;
3215 self.write_expr(module, arg, func_ctx)?;
3216 match is_sin {
3217 true => write!(self.out, " + 1.0))")?,
3218 false => write!(self.out, " - 1.0))")?,
3219 }
3220 }
3221 Function::Atanh => {
3222 write!(self.out, "0.5 * log((1.0 + ")?;
3223 self.write_expr(module, arg, func_ctx)?;
3224 write!(self.out, ") / (1.0 - ")?;
3225 self.write_expr(module, arg, func_ctx)?;
3226 write!(self.out, "))")?;
3227 }
3228 Function::Pack2x16float => {
3229 write!(self.out, "(f32tof16(")?;
3230 self.write_expr(module, arg, func_ctx)?;
3231 write!(self.out, "[0]) | f32tof16(")?;
3232 self.write_expr(module, arg, func_ctx)?;
3233 write!(self.out, "[1]) << 16)")?;
3234 }
3235 Function::Pack2x16snorm => {
3236 let scale = 32767;
3237
3238 write!(self.out, "uint((int(round(clamp(")?;
3239 self.write_expr(module, arg, func_ctx)?;
3240 write!(
3241 self.out,
3242 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3243 )?;
3244 self.write_expr(module, arg, func_ctx)?;
3245 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3246 }
3247 Function::Pack2x16unorm => {
3248 let scale = 65535;
3249
3250 write!(self.out, "(uint(round(clamp(")?;
3251 self.write_expr(module, arg, func_ctx)?;
3252 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3253 self.write_expr(module, arg, func_ctx)?;
3254 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3255 }
3256 Function::Pack4x8snorm => {
3257 let scale = 127;
3258
3259 write!(self.out, "uint((int(round(clamp(")?;
3260 self.write_expr(module, arg, func_ctx)?;
3261 write!(
3262 self.out,
3263 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3264 )?;
3265 self.write_expr(module, arg, func_ctx)?;
3266 write!(
3267 self.out,
3268 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3269 )?;
3270 self.write_expr(module, arg, func_ctx)?;
3271 write!(
3272 self.out,
3273 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3274 )?;
3275 self.write_expr(module, arg, func_ctx)?;
3276 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3277 }
3278 Function::Pack4x8unorm => {
3279 let scale = 255;
3280
3281 write!(self.out, "(uint(round(clamp(")?;
3282 self.write_expr(module, arg, func_ctx)?;
3283 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3284 self.write_expr(module, arg, func_ctx)?;
3285 write!(
3286 self.out,
3287 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3288 )?;
3289 self.write_expr(module, arg, func_ctx)?;
3290 write!(
3291 self.out,
3292 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3293 )?;
3294 self.write_expr(module, arg, func_ctx)?;
3295 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3296 }
3297 fun @ (Function::Pack4xI8
3298 | Function::Pack4xU8
3299 | Function::Pack4xI8Clamp
3300 | Function::Pack4xU8Clamp) => {
3301 let was_signed = matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3302 let clamp_bounds = match fun {
3303 Function::Pack4xI8Clamp => Some(("-128", "127")),
3304 Function::Pack4xU8Clamp => Some(("0", "255")),
3305 _ => None,
3306 };
3307 if was_signed {
3308 write!(self.out, "uint(")?;
3309 }
3310 let write_arg = |this: &mut Self| -> BackendResult {
3311 if let Some((min, max)) = clamp_bounds {
3312 write!(this.out, "clamp(")?;
3313 this.write_expr(module, arg, func_ctx)?;
3314 write!(this.out, ", {min}, {max})")?;
3315 } else {
3316 this.write_expr(module, arg, func_ctx)?;
3317 }
3318 Ok(())
3319 };
3320 write!(self.out, "(")?;
3321 write_arg(self)?;
3322 write!(self.out, "[0] & 0xFF) | ((")?;
3323 write_arg(self)?;
3324 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3325 write_arg(self)?;
3326 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3327 write_arg(self)?;
3328 write!(self.out, "[3] & 0xFF) << 24)")?;
3329 if was_signed {
3330 write!(self.out, ")")?;
3331 }
3332 }
3333
3334 Function::Unpack2x16float => {
3335 write!(self.out, "float2(f16tof32(")?;
3336 self.write_expr(module, arg, func_ctx)?;
3337 write!(self.out, "), f16tof32((")?;
3338 self.write_expr(module, arg, func_ctx)?;
3339 write!(self.out, ") >> 16))")?;
3340 }
3341 Function::Unpack2x16snorm => {
3342 let scale = 32767;
3343
3344 write!(self.out, "(float2(int2(")?;
3345 self.write_expr(module, arg, func_ctx)?;
3346 write!(self.out, " << 16, ")?;
3347 self.write_expr(module, arg, func_ctx)?;
3348 write!(self.out, ") >> 16) / {scale}.0)")?;
3349 }
3350 Function::Unpack2x16unorm => {
3351 let scale = 65535;
3352
3353 write!(self.out, "(float2(")?;
3354 self.write_expr(module, arg, func_ctx)?;
3355 write!(self.out, " & 0xFFFF, ")?;
3356 self.write_expr(module, arg, func_ctx)?;
3357 write!(self.out, " >> 16) / {scale}.0)")?;
3358 }
3359 Function::Unpack4x8snorm => {
3360 let scale = 127;
3361
3362 write!(self.out, "(float4(int4(")?;
3363 self.write_expr(module, arg, func_ctx)?;
3364 write!(self.out, " << 24, ")?;
3365 self.write_expr(module, arg, func_ctx)?;
3366 write!(self.out, " << 16, ")?;
3367 self.write_expr(module, arg, func_ctx)?;
3368 write!(self.out, " << 8, ")?;
3369 self.write_expr(module, arg, func_ctx)?;
3370 write!(self.out, ") >> 24) / {scale}.0)")?;
3371 }
3372 Function::Unpack4x8unorm => {
3373 let scale = 255;
3374
3375 write!(self.out, "(float4(")?;
3376 self.write_expr(module, arg, func_ctx)?;
3377 write!(self.out, " & 0xFF, ")?;
3378 self.write_expr(module, arg, func_ctx)?;
3379 write!(self.out, " >> 8 & 0xFF, ")?;
3380 self.write_expr(module, arg, func_ctx)?;
3381 write!(self.out, " >> 16 & 0xFF, ")?;
3382 self.write_expr(module, arg, func_ctx)?;
3383 write!(self.out, " >> 24) / {scale}.0)")?;
3384 }
3385 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3386 write!(self.out, "(")?;
3387 if matches!(fun, Function::Unpack4xU8) {
3388 write!(self.out, "u")?;
3389 }
3390 write!(self.out, "int4(")?;
3391 self.write_expr(module, arg, func_ctx)?;
3392 write!(self.out, ", ")?;
3393 self.write_expr(module, arg, func_ctx)?;
3394 write!(self.out, " >> 8, ")?;
3395 self.write_expr(module, arg, func_ctx)?;
3396 write!(self.out, " >> 16, ")?;
3397 self.write_expr(module, arg, func_ctx)?;
3398 write!(self.out, " >> 24) << 24 >> 24)")?;
3399 }
3400 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3401 let arg1 = arg1.unwrap();
3402
3403 if self.options.shader_model >= ShaderModel::V6_4 {
3404 let function_name = match fun {
3406 Function::Dot4I8Packed => "dot4add_i8packed",
3407 Function::Dot4U8Packed => "dot4add_u8packed",
3408 _ => unreachable!(),
3409 };
3410 write!(self.out, "{function_name}(")?;
3411 self.write_expr(module, arg, func_ctx)?;
3412 write!(self.out, ", ")?;
3413 self.write_expr(module, arg1, func_ctx)?;
3414 write!(self.out, ", 0)")?;
3415 } else {
3416 write!(self.out, "dot(")?;
3418
3419 if matches!(fun, Function::Dot4U8Packed) {
3420 write!(self.out, "u")?;
3421 }
3422 write!(self.out, "int4(")?;
3423 self.write_expr(module, arg, func_ctx)?;
3424 write!(self.out, ", ")?;
3425 self.write_expr(module, arg, func_ctx)?;
3426 write!(self.out, " >> 8, ")?;
3427 self.write_expr(module, arg, func_ctx)?;
3428 write!(self.out, " >> 16, ")?;
3429 self.write_expr(module, arg, func_ctx)?;
3430 write!(self.out, " >> 24) << 24 >> 24, ")?;
3431
3432 if matches!(fun, Function::Dot4U8Packed) {
3433 write!(self.out, "u")?;
3434 }
3435 write!(self.out, "int4(")?;
3436 self.write_expr(module, arg1, func_ctx)?;
3437 write!(self.out, ", ")?;
3438 self.write_expr(module, arg1, func_ctx)?;
3439 write!(self.out, " >> 8, ")?;
3440 self.write_expr(module, arg1, func_ctx)?;
3441 write!(self.out, " >> 16, ")?;
3442 self.write_expr(module, arg1, func_ctx)?;
3443 write!(self.out, " >> 24) << 24 >> 24)")?;
3444 }
3445 }
3446 Function::QuantizeToF16 => {
3447 write!(self.out, "f16tof32(f32tof16(")?;
3448 self.write_expr(module, arg, func_ctx)?;
3449 write!(self.out, "))")?;
3450 }
3451 Function::Regular(fun_name) => {
3452 write!(self.out, "{fun_name}(")?;
3453 self.write_expr(module, arg, func_ctx)?;
3454 if let Some(arg) = arg1 {
3455 write!(self.out, ", ")?;
3456 self.write_expr(module, arg, func_ctx)?;
3457 }
3458 if let Some(arg) = arg2 {
3459 write!(self.out, ", ")?;
3460 self.write_expr(module, arg, func_ctx)?;
3461 }
3462 if let Some(arg) = arg3 {
3463 write!(self.out, ", ")?;
3464 self.write_expr(module, arg, func_ctx)?;
3465 }
3466 write!(self.out, ")")?
3467 }
3468 Function::MissingIntOverload(fun_name) => {
3471 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3472 if let Some(Scalar::I32) = scalar_kind {
3473 write!(self.out, "asint({fun_name}(asuint(")?;
3474 self.write_expr(module, arg, func_ctx)?;
3475 write!(self.out, ")))")?;
3476 } else {
3477 write!(self.out, "{fun_name}(")?;
3478 self.write_expr(module, arg, func_ctx)?;
3479 write!(self.out, ")")?;
3480 }
3481 }
3482 Function::MissingIntReturnType(fun_name) => {
3485 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3486 if let Some(Scalar::I32) = scalar_kind {
3487 write!(self.out, "asint({fun_name}(")?;
3488 self.write_expr(module, arg, func_ctx)?;
3489 write!(self.out, "))")?;
3490 } else {
3491 write!(self.out, "{fun_name}(")?;
3492 self.write_expr(module, arg, func_ctx)?;
3493 write!(self.out, ")")?;
3494 }
3495 }
3496 Function::CountTrailingZeros => {
3497 match *func_ctx.resolve_type(arg, &module.types) {
3498 TypeInner::Vector { size, scalar } => {
3499 let s = match size {
3500 crate::VectorSize::Bi => ".xx",
3501 crate::VectorSize::Tri => ".xxx",
3502 crate::VectorSize::Quad => ".xxxx",
3503 };
3504
3505 let scalar_width_bits = scalar.width * 8;
3506
3507 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3508 write!(self.out, "min(({scalar_width_bits}u){s}, firstbitlow(")?;
3509 self.write_expr(module, arg, func_ctx)?;
3510 write!(self.out, "))")?;
3511 } else {
3512 write!(
3514 self.out,
3515 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3516 )?;
3517 self.write_expr(module, arg, func_ctx)?;
3518 write!(self.out, ")))")?;
3519 }
3520 }
3521 TypeInner::Scalar(scalar) => {
3522 let scalar_width_bits = scalar.width * 8;
3523
3524 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3525 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3526 self.write_expr(module, arg, func_ctx)?;
3527 write!(self.out, "))")?;
3528 } else {
3529 write!(self.out, "asint(min({scalar_width_bits}u, firstbitlow(")?;
3531 self.write_expr(module, arg, func_ctx)?;
3532 write!(self.out, ")))")?;
3533 }
3534 }
3535 _ => unreachable!(),
3536 }
3537
3538 return Ok(());
3539 }
3540 Function::CountLeadingZeros => {
3541 match *func_ctx.resolve_type(arg, &module.types) {
3542 TypeInner::Vector { size, scalar } => {
3543 let s = match size {
3544 crate::VectorSize::Bi => ".xx",
3545 crate::VectorSize::Tri => ".xxx",
3546 crate::VectorSize::Quad => ".xxxx",
3547 };
3548
3549 let constant = scalar.width * 8 - 1;
3551
3552 if scalar.kind == ScalarKind::Uint {
3553 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3554 self.write_expr(module, arg, func_ctx)?;
3555 write!(self.out, "))")?;
3556 } else {
3557 let conversion_func = match scalar.width {
3558 4 => "asint",
3559 _ => "",
3560 };
3561 write!(self.out, "(")?;
3562 self.write_expr(module, arg, func_ctx)?;
3563 write!(
3564 self.out,
3565 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3566 )?;
3567 self.write_expr(module, arg, func_ctx)?;
3568 write!(self.out, ")))")?;
3569 }
3570 }
3571 TypeInner::Scalar(scalar) => {
3572 let constant = scalar.width * 8 - 1;
3574
3575 if let ScalarKind::Uint = scalar.kind {
3576 write!(self.out, "({constant}u - firstbithigh(")?;
3577 self.write_expr(module, arg, func_ctx)?;
3578 write!(self.out, "))")?;
3579 } else {
3580 let conversion_func = match scalar.width {
3581 4 => "asint",
3582 _ => "",
3583 };
3584 write!(self.out, "(")?;
3585 self.write_expr(module, arg, func_ctx)?;
3586 write!(
3587 self.out,
3588 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
3589 )?;
3590 self.write_expr(module, arg, func_ctx)?;
3591 write!(self.out, ")))")?;
3592 }
3593 }
3594 _ => unreachable!(),
3595 }
3596 }
3597 }
3598
3599 Ok(())
3600 }
3601
3602 fn write_const_expression(
3603 &mut self,
3604 module: &Module,
3605 expr: Handle<crate::Expression>,
3606 arena: &crate::Arena<crate::Expression>,
3607 ) -> BackendResult {
3608 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3609 writer.write_const_expression(module, expr, arena)
3610 })
3611 }
3612
3613 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3614 match literal {
3615 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3616 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3617 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3618 crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3619 crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3620 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3621 crate::Literal::I32(value) if value == i32::MIN => {
3627 write!(self.out, "int({} - 1)", value + 1)?
3628 }
3629 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3633 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3634 crate::Literal::I64(value) if value == i64::MIN => {
3636 write!(self.out, "({}L - 1L)", value + 1)?;
3637 }
3638 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3639 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3640 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3641 return Err(Error::Custom(
3642 "Abstract types should not appear in IR presented to backends".into(),
3643 ));
3644 }
3645 }
3646 Ok(())
3647 }
3648
3649 fn write_possibly_const_expression<E>(
3650 &mut self,
3651 module: &Module,
3652 expr: Handle<crate::Expression>,
3653 expressions: &crate::Arena<crate::Expression>,
3654 write_expression: E,
3655 ) -> BackendResult
3656 where
3657 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3658 {
3659 use crate::Expression;
3660
3661 match expressions[expr] {
3662 Expression::Literal(literal) => {
3663 self.write_literal(literal)?;
3664 }
3665 Expression::Constant(handle) => {
3666 let constant = &module.constants[handle];
3667 if constant.name.is_some() {
3668 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3669 } else {
3670 self.write_const_expression(module, constant.init, &module.global_expressions)?;
3671 }
3672 }
3673 Expression::ZeroValue(ty) => {
3674 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3675 write!(self.out, "()")?;
3676 }
3677 Expression::Compose { ty, ref components } => {
3678 match module.types[ty].inner {
3679 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3680 self.write_wrapped_constructor_function_name(
3681 module,
3682 WrappedConstructor { ty },
3683 )?;
3684 }
3685 _ => {
3686 self.write_type(module, ty)?;
3687 }
3688 };
3689 write!(self.out, "(")?;
3690 for (index, component) in components.iter().enumerate() {
3691 if index != 0 {
3692 write!(self.out, ", ")?;
3693 }
3694 write_expression(self, *component)?;
3695 }
3696 write!(self.out, ")")?;
3697 }
3698 Expression::Splat { size, value } => {
3699 let number_of_components = match size {
3703 crate::VectorSize::Bi => "xx",
3704 crate::VectorSize::Tri => "xxx",
3705 crate::VectorSize::Quad => "xxxx",
3706 };
3707 write!(self.out, "(")?;
3708 write_expression(self, value)?;
3709 write!(self.out, ").{number_of_components}")?
3710 }
3711 _ => {
3712 return Err(Error::Override);
3713 }
3714 }
3715
3716 Ok(())
3717 }
3718
3719 pub(super) fn write_expr(
3724 &mut self,
3725 module: &Module,
3726 expr: Handle<crate::Expression>,
3727 func_ctx: &back::FunctionCtx<'_>,
3728 ) -> BackendResult {
3729 use crate::Expression;
3730
3731 let ff_input = if self.options.special_constants_binding.is_some() {
3733 func_ctx.is_fixed_function_input(expr, module)
3734 } else {
3735 None
3736 };
3737 let closing_bracket = match ff_input {
3738 Some(crate::BuiltIn::VertexIndex) => {
3739 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3740 ")"
3741 }
3742 Some(crate::BuiltIn::InstanceIndex) => {
3743 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3744 ")"
3745 }
3746 Some(crate::BuiltIn::NumWorkGroups) => {
3747 write!(
3751 self.out,
3752 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3753 )?;
3754 return Ok(());
3755 }
3756 _ => "",
3757 };
3758
3759 if let Some(name) = self.named_expressions.get(&expr) {
3760 write!(self.out, "{name}{closing_bracket}")?;
3761 return Ok(());
3762 }
3763
3764 let expression = &func_ctx.expressions[expr];
3765
3766 match *expression {
3767 Expression::Literal(_)
3768 | Expression::Constant(_)
3769 | Expression::ZeroValue(_)
3770 | Expression::Compose { .. }
3771 | Expression::Splat { .. } => {
3772 self.write_possibly_const_expression(
3773 module,
3774 expr,
3775 func_ctx.expressions,
3776 |writer, expr| writer.write_expr(module, expr, func_ctx),
3777 )?;
3778 }
3779 Expression::Override(_) => return Err(Error::Override),
3780 Expression::Binary {
3787 op:
3788 op @ crate::BinaryOperator::Add
3789 | op @ crate::BinaryOperator::Subtract
3790 | op @ crate::BinaryOperator::Multiply,
3791 left,
3792 right,
3793 } if matches!(
3794 func_ctx.resolve_type(expr, &module.types).scalar(),
3795 Some(Scalar::I32)
3796 ) =>
3797 {
3798 write!(self.out, "asint(asuint(",)?;
3799 self.write_expr(module, left, func_ctx)?;
3800 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3801 self.write_expr(module, right, func_ctx)?;
3802 write!(self.out, "))")?;
3803 }
3804 Expression::Binary {
3807 op: crate::BinaryOperator::Multiply,
3808 left,
3809 right,
3810 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3811 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3812 {
3813 write!(self.out, "mul(")?;
3815 self.write_expr(module, right, func_ctx)?;
3816 write!(self.out, ", ")?;
3817 self.write_expr(module, left, func_ctx)?;
3818 write!(self.out, ")")?;
3819 }
3820
3821 Expression::Binary {
3833 op: crate::BinaryOperator::Divide,
3834 left,
3835 right,
3836 } if matches!(
3837 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3838 Some(ScalarKind::Sint | ScalarKind::Uint)
3839 ) =>
3840 {
3841 write!(self.out, "{DIV_FUNCTION}(")?;
3842 self.write_expr(module, left, func_ctx)?;
3843 write!(self.out, ", ")?;
3844 self.write_expr(module, right, func_ctx)?;
3845 write!(self.out, ")")?;
3846 }
3847
3848 Expression::Binary {
3849 op: crate::BinaryOperator::Modulo,
3850 left,
3851 right,
3852 } if matches!(
3853 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3854 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3855 ) =>
3856 {
3857 write!(self.out, "{MOD_FUNCTION}(")?;
3858 self.write_expr(module, left, func_ctx)?;
3859 write!(self.out, ", ")?;
3860 self.write_expr(module, right, func_ctx)?;
3861 write!(self.out, ")")?;
3862 }
3863
3864 Expression::Binary { op, left, right } => {
3865 write!(self.out, "(")?;
3866 self.write_expr(module, left, func_ctx)?;
3867 write!(self.out, " {} ", back::binary_operation_str(op))?;
3868 self.write_expr(module, right, func_ctx)?;
3869 write!(self.out, ")")?;
3870 }
3871 Expression::Access { base, index } => {
3872 if let Some(crate::AddressSpace::Storage { .. }) =
3873 func_ctx.resolve_type(expr, &module.types).pointer_space()
3874 {
3875 } else {
3877 if let Some(MatrixType {
3884 columns,
3885 rows: crate::VectorSize::Bi,
3886 width,
3887 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3888 .or_else(|| {
3889 get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3890 })
3891 {
3892 write!(
3893 self.out,
3894 "__get_col_of_mat{}x2_f{}(",
3895 columns as u8,
3896 width * 8
3897 )?;
3898 self.write_expr(module, base, func_ctx)?;
3899 write!(self.out, ", ")?;
3900 self.write_expr(module, index, func_ctx)?;
3901 write!(self.out, ")")?;
3902 return Ok(());
3903 }
3904
3905 let resolved = func_ctx.resolve_type(base, &module.types);
3906
3907 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3908 TypeInner::BindingArray { .. } => {
3909 let uniformity = &func_ctx.info[index].uniformity;
3910
3911 (true, uniformity.non_uniform_result.is_some())
3912 }
3913 _ => (false, false),
3914 };
3915
3916 self.write_expr(module, base, func_ctx)?;
3917
3918 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3919 module, func_ctx, base, resolved,
3920 );
3921
3922 if let Some(ref info) = array_sampler_info {
3923 write!(self.out, "{}[", info.sampler_heap_name)?;
3924 } else {
3925 write!(self.out, "[")?;
3926 }
3927
3928 let needs_bound_check = self.options.restrict_indexing
3929 && !indexing_binding_array
3930 && match resolved.pointer_space() {
3931 Some(
3932 crate::AddressSpace::Function
3933 | crate::AddressSpace::Private
3934 | crate::AddressSpace::WorkGroup
3935 | crate::AddressSpace::Immediate
3936 | crate::AddressSpace::TaskPayload
3937 | crate::AddressSpace::RayPayload
3938 | crate::AddressSpace::IncomingRayPayload,
3939 )
3940 | None => true,
3941 Some(crate::AddressSpace::Uniform) => {
3942 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3944 let bind_target = self
3945 .options
3946 .resolve_resource_binding(
3947 module.global_variables[var_handle]
3948 .binding
3949 .as_ref()
3950 .unwrap(),
3951 )
3952 .unwrap();
3953 bind_target.restrict_indexing
3954 }
3955 Some(
3956 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3957 ) => unreachable!(),
3958 };
3959 let restriction_needed = if needs_bound_check {
3961 index::access_needs_check(
3962 base,
3963 index::GuardedIndex::Expression(index),
3964 module,
3965 func_ctx.expressions,
3966 func_ctx.info,
3967 )
3968 } else {
3969 None
3970 };
3971 if let Some(limit) = restriction_needed {
3972 write!(self.out, "min(uint(")?;
3973 self.write_expr(module, index, func_ctx)?;
3974 write!(self.out, "), ")?;
3975 match limit {
3976 index::IndexableLength::Known(limit) => {
3977 write!(self.out, "{}u", limit - 1)?;
3978 }
3979 index::IndexableLength::Dynamic => unreachable!(),
3980 }
3981 write!(self.out, ")")?;
3982 } else {
3983 if non_uniform_qualifier {
3984 write!(self.out, "NonUniformResourceIndex(")?;
3985 }
3986 if let Some(ref info) = array_sampler_info {
3987 write!(
3988 self.out,
3989 "{}[{} + ",
3990 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3991 )?;
3992 }
3993 self.write_expr(module, index, func_ctx)?;
3994 if array_sampler_info.is_some() {
3995 write!(self.out, "]")?;
3996 }
3997 if non_uniform_qualifier {
3998 write!(self.out, ")")?;
3999 }
4000 }
4001
4002 write!(self.out, "]")?;
4003 }
4004 }
4005 Expression::AccessIndex { base, index } => {
4006 if let Some(crate::AddressSpace::Storage { .. }) =
4007 func_ctx.resolve_type(expr, &module.types).pointer_space()
4008 {
4009 } else {
4011 if let Some(MatrixType {
4015 rows: crate::VectorSize::Bi,
4016 ..
4017 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
4018 .or_else(|| {
4019 get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
4020 })
4021 {
4022 self.write_expr(module, base, func_ctx)?;
4023 write!(self.out, "._{index}")?;
4024 return Ok(());
4025 }
4026
4027 let base_ty_res = &func_ctx.info[base].ty;
4028 let mut resolved = base_ty_res.inner_with(&module.types);
4029 let base_ty_handle = match *resolved {
4030 TypeInner::Pointer { base, .. } => {
4031 resolved = &module.types[base].inner;
4032 Some(base)
4033 }
4034 _ => base_ty_res.handle(),
4035 };
4036
4037 if let TypeInner::Struct { ref members, .. } = *resolved {
4043 let member = &members[index as usize];
4044
4045 match module.types[member.ty].inner {
4046 TypeInner::Matrix {
4047 rows: crate::VectorSize::Bi,
4048 ..
4049 } if member.binding.is_none() => {
4050 let ty = base_ty_handle.unwrap();
4051 self.write_wrapped_struct_matrix_get_function_name(
4052 WrappedStructMatrixAccess { ty, index },
4053 )?;
4054 write!(self.out, "(")?;
4055 self.write_expr(module, base, func_ctx)?;
4056 write!(self.out, ")")?;
4057 return Ok(());
4058 }
4059 _ => {}
4060 }
4061 }
4062
4063 let array_sampler_info = self.sampler_binding_array_info_from_expression(
4064 module, func_ctx, base, resolved,
4065 );
4066
4067 if let Some(ref info) = array_sampler_info {
4068 write!(
4069 self.out,
4070 "{}[{}",
4071 info.sampler_heap_name, info.sampler_index_buffer_name
4072 )?;
4073 }
4074
4075 self.write_expr(module, base, func_ctx)?;
4076
4077 match *resolved {
4078 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
4084 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
4086 }
4087 TypeInner::Matrix { .. }
4088 | TypeInner::Array { .. }
4089 | TypeInner::BindingArray { .. } => {
4090 if let Some(ref info) = array_sampler_info {
4091 write!(
4092 self.out,
4093 "[{} + {index}]",
4094 info.binding_array_base_index_name
4095 )?;
4096 } else {
4097 write!(self.out, "[{index}]")?;
4098 }
4099 }
4100 TypeInner::Struct { .. } => {
4101 let ty = base_ty_handle.unwrap();
4104
4105 write!(
4106 self.out,
4107 ".{}",
4108 &self.names[&NameKey::StructMember(ty, index)]
4109 )?
4110 }
4111 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
4112 }
4113
4114 if array_sampler_info.is_some() {
4115 write!(self.out, "]")?;
4116 }
4117 }
4118 }
4119 Expression::FunctionArgument(pos) => {
4120 let ty = func_ctx.resolve_type(expr, &module.types);
4121
4122 if let TypeInner::Image {
4128 class: crate::ImageClass::External,
4129 ..
4130 } = *ty
4131 {
4132 let plane_names = [0, 1, 2].map(|i| {
4133 &self.names[&func_ctx
4134 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
4135 });
4136 let params_name = &self.names[&func_ctx
4137 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
4138 write!(
4139 self.out,
4140 "{}, {}, {}, {}",
4141 plane_names[0], plane_names[1], plane_names[2], params_name
4142 )?;
4143 } else {
4144 let key = func_ctx.argument_key(pos);
4145 let name = &self.names[&key];
4146 write!(self.out, "{name}")?;
4147 }
4148 }
4149 Expression::ImageSample {
4150 coordinate,
4151 image,
4152 sampler,
4153 clamp_to_edge: true,
4154 gather: None,
4155 array_index: None,
4156 offset: None,
4157 level: crate::SampleLevel::Zero,
4158 depth_ref: None,
4159 } => {
4160 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
4161 self.write_expr(module, image, func_ctx)?;
4162 write!(self.out, ", ")?;
4163 self.write_expr(module, sampler, func_ctx)?;
4164 write!(self.out, ", ")?;
4165 self.write_expr(module, coordinate, func_ctx)?;
4166 write!(self.out, ")")?;
4167 }
4168 Expression::ImageSample {
4169 image,
4170 sampler,
4171 gather,
4172 coordinate,
4173 array_index,
4174 offset,
4175 level,
4176 depth_ref,
4177 clamp_to_edge,
4178 } => {
4179 if clamp_to_edge {
4180 return Err(Error::Custom(
4181 "ImageSample::clamp_to_edge should have been validated out".to_string(),
4182 ));
4183 }
4184
4185 use crate::SampleLevel as Sl;
4186 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
4187
4188 let (base_str, component_str) = match gather {
4189 Some(component) => ("Gather", COMPONENTS[component as usize]),
4190 None => ("Sample", ""),
4191 };
4192 let cmp_str = match depth_ref {
4193 Some(_) => "Cmp",
4194 None => "",
4195 };
4196 let level_str = match level {
4197 Sl::Zero if gather.is_none() => "LevelZero",
4198 Sl::Auto | Sl::Zero => "",
4199 Sl::Exact(_) => "Level",
4200 Sl::Bias(_) => "Bias",
4201 Sl::Gradient { .. } => "Grad",
4202 };
4203
4204 self.write_expr(module, image, func_ctx)?;
4205 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
4206 self.write_expr(module, sampler, func_ctx)?;
4207 write!(self.out, ", ")?;
4208 self.write_texture_coordinates(
4209 "float",
4210 coordinate,
4211 array_index,
4212 None,
4213 module,
4214 func_ctx,
4215 )?;
4216
4217 if let Some(depth_ref) = depth_ref {
4218 write!(self.out, ", ")?;
4219 self.write_expr(module, depth_ref, func_ctx)?;
4220 }
4221
4222 match level {
4223 Sl::Auto | Sl::Zero => {}
4224 Sl::Exact(expr) => {
4225 write!(self.out, ", ")?;
4226 self.write_expr(module, expr, func_ctx)?;
4227 }
4228 Sl::Bias(expr) => {
4229 write!(self.out, ", ")?;
4230 self.write_expr(module, expr, func_ctx)?;
4231 }
4232 Sl::Gradient { x, y } => {
4233 write!(self.out, ", ")?;
4234 self.write_expr(module, x, func_ctx)?;
4235 write!(self.out, ", ")?;
4236 self.write_expr(module, y, func_ctx)?;
4237 }
4238 }
4239
4240 if let Some(offset) = offset {
4241 write!(self.out, ", ")?;
4242 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
4244 write!(self.out, ")")?;
4245 }
4246
4247 write!(self.out, ")")?;
4248 }
4249 Expression::ImageQuery { image, query } => {
4250 if let TypeInner::Image {
4252 dim,
4253 arrayed,
4254 class,
4255 } = *func_ctx.resolve_type(image, &module.types)
4256 {
4257 let wrapped_image_query = WrappedImageQuery {
4258 dim,
4259 arrayed,
4260 class,
4261 query: query.into(),
4262 };
4263
4264 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
4265 write!(self.out, "(")?;
4266 self.write_expr(module, image, func_ctx)?;
4268 if let crate::ImageQuery::Size { level: Some(level) } = query {
4269 write!(self.out, ", ")?;
4270 self.write_expr(module, level, func_ctx)?;
4271 }
4272 write!(self.out, ")")?;
4273 }
4274 }
4275 Expression::ImageLoad {
4276 image,
4277 coordinate,
4278 array_index,
4279 sample,
4280 level,
4281 } => self.write_image_load(
4282 &module,
4283 expr,
4284 func_ctx,
4285 image,
4286 coordinate,
4287 array_index,
4288 sample,
4289 level,
4290 )?,
4291 Expression::GlobalVariable(handle) => {
4292 let global_variable = &module.global_variables[handle];
4293 let ty = &module.types[global_variable.ty].inner;
4294
4295 let is_binding_array_of_samplers = match *ty {
4300 TypeInner::BindingArray { base, .. } => {
4301 let base_ty = &module.types[base].inner;
4302 matches!(*base_ty, TypeInner::Sampler { .. })
4303 }
4304 _ => false,
4305 };
4306
4307 let is_storage_space =
4308 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
4309
4310 if let TypeInner::Image {
4318 class: crate::ImageClass::External,
4319 ..
4320 } = *ty
4321 {
4322 let plane_names = [0, 1, 2].map(|i| {
4323 &self.names[&NameKey::ExternalTextureGlobalVariable(
4324 handle,
4325 ExternalTextureNameKey::Plane(i),
4326 )]
4327 });
4328 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
4329 handle,
4330 ExternalTextureNameKey::Params,
4331 )];
4332 write!(
4333 self.out,
4334 "{}, {}, {}, {}",
4335 plane_names[0], plane_names[1], plane_names[2], params_name
4336 )?;
4337 } else if !is_binding_array_of_samplers && !is_storage_space {
4338 let name = &self.names[&NameKey::GlobalVariable(handle)];
4339 write!(self.out, "{name}")?;
4340 }
4341 }
4342 Expression::LocalVariable(handle) => {
4343 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
4344 }
4345 Expression::Load { pointer } => {
4346 match func_ctx
4347 .resolve_type(pointer, &module.types)
4348 .pointer_space()
4349 {
4350 Some(crate::AddressSpace::Storage { .. }) => {
4351 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4352 let result_ty = func_ctx.info[expr].ty.clone();
4353 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
4354 }
4355 _ => {
4356 let mut close_paren = false;
4357
4358 if let Some(MatrixType {
4363 rows: crate::VectorSize::Bi,
4364 ..
4365 }) = get_inner_matrix_of_struct_array_member(
4366 module, pointer, func_ctx, false,
4367 )
4368 .or_else(|| {
4369 get_inner_matrix_of_global_uniform(module, pointer, func_ctx, false)
4370 }) {
4371 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
4372 let ptr_tr = resolved.pointer_base_type();
4373 if let Some(ptr_ty) =
4374 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
4375 {
4376 resolved = ptr_ty;
4377 }
4378
4379 write!(self.out, "((")?;
4380 if let TypeInner::Array { base, size, .. } = *resolved {
4381 self.write_type(module, base)?;
4382 self.write_array_size(module, base, size)?;
4383 } else {
4384 self.write_value_type(module, resolved)?;
4385 }
4386 write!(self.out, ")")?;
4387 close_paren = true;
4388 }
4389
4390 self.write_expr(module, pointer, func_ctx)?;
4391
4392 if close_paren {
4393 write!(self.out, ")")?;
4394 }
4395 }
4396 }
4397 }
4398 Expression::Unary { op, expr } => {
4399 let op_str = match op {
4401 crate::UnaryOperator::Negate => {
4402 match func_ctx.resolve_type(expr, &module.types).scalar() {
4403 Some(Scalar::I32) => NEG_FUNCTION,
4404 _ => "-",
4405 }
4406 }
4407 crate::UnaryOperator::LogicalNot => "!",
4408 crate::UnaryOperator::BitwiseNot => "~",
4409 };
4410 write!(self.out, "{op_str}(")?;
4411 self.write_expr(module, expr, func_ctx)?;
4412 write!(self.out, ")")?;
4413 }
4414 Expression::As {
4415 expr,
4416 kind,
4417 convert,
4418 } => {
4419 let inner = func_ctx.resolve_type(expr, &module.types);
4420 if inner.scalar_kind() == Some(ScalarKind::Float)
4421 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
4422 && convert.is_some()
4423 && matches!(convert, Some(4) | Some(8))
4424 {
4425 let fun_name = match (kind, convert) {
4429 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
4430 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
4431 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
4432 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
4433 _ => unreachable!(),
4434 };
4435 write!(self.out, "{fun_name}(")?;
4436 self.write_expr(module, expr, func_ctx)?;
4437 write!(self.out, ")")?;
4438 } else {
4439 let close_paren = match convert {
4440 Some(dst_width) => {
4441 let scalar = Scalar {
4442 kind,
4443 width: dst_width,
4444 };
4445 match *inner {
4446 TypeInner::Vector { size, .. } => {
4447 write!(
4448 self.out,
4449 "{}{}(",
4450 scalar.to_hlsl_str()?,
4451 common::vector_size_str(size)
4452 )?;
4453 }
4454 TypeInner::Scalar(_) => {
4455 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
4456 }
4457 TypeInner::Matrix { columns, rows, .. } => {
4458 write!(
4459 self.out,
4460 "{}{}x{}(",
4461 scalar.to_hlsl_str()?,
4462 common::vector_size_str(columns),
4463 common::vector_size_str(rows)
4464 )?;
4465 }
4466 _ => {
4467 return Err(Error::Unimplemented(format!(
4468 "write_expr expression::as {inner:?}"
4469 )));
4470 }
4471 };
4472 true
4473 }
4474 None => {
4475 if inner.scalar_width() == Some(8) {
4476 false
4477 } else if inner.scalar_width() == Some(2) {
4478 let dst_scalar = Scalar { kind, width: 2 };
4481 match *inner {
4482 TypeInner::Vector { size, .. } => {
4483 write!(
4484 self.out,
4485 "{}{}(",
4486 dst_scalar.to_hlsl_str()?,
4487 common::vector_size_str(size)
4488 )?;
4489 }
4490 _ => {
4491 write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
4492 }
4493 };
4494 true
4495 } else {
4496 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
4497 true
4498 }
4499 }
4500 };
4501 self.write_expr(module, expr, func_ctx)?;
4502 if close_paren {
4503 write!(self.out, ")")?;
4504 }
4505 }
4506 }
4507 Expression::Math {
4508 fun,
4509 arg,
4510 arg1,
4511 arg2,
4512 arg3,
4513 } => {
4514 return self.write_math_expression(module, fun, arg, arg1, arg2, arg3, func_ctx);
4515 }
4516 Expression::Swizzle {
4517 size,
4518 vector,
4519 pattern,
4520 } => {
4521 self.write_expr(module, vector, func_ctx)?;
4522 write!(self.out, ".")?;
4523 for &sc in pattern[..size as usize].iter() {
4524 self.out.write_char(back::COMPONENTS[sc as usize])?;
4525 }
4526 }
4527 Expression::ArrayLength(expr) => {
4528 let var_handle = match func_ctx.expressions[expr] {
4529 Expression::AccessIndex { base, index: _ } => {
4530 match func_ctx.expressions[base] {
4531 Expression::GlobalVariable(handle) => handle,
4532 _ => unreachable!(),
4533 }
4534 }
4535 Expression::GlobalVariable(handle) => handle,
4536 _ => unreachable!(),
4537 };
4538
4539 let var = &module.global_variables[var_handle];
4540 let (offset, stride) = match module.types[var.ty].inner {
4541 TypeInner::Array { stride, .. } => (0, stride),
4542 TypeInner::Struct { ref members, .. } => {
4543 let last = members.last().unwrap();
4544 let stride = match module.types[last.ty].inner {
4545 TypeInner::Array { stride, .. } => stride,
4546 _ => unreachable!(),
4547 };
4548 (last.offset, stride)
4549 }
4550 _ => unreachable!(),
4551 };
4552
4553 let storage_access = match var.space {
4554 crate::AddressSpace::Storage { access } => access,
4555 _ => crate::StorageAccess::default(),
4556 };
4557 let wrapped_array_length = WrappedArrayLength {
4558 writable: storage_access.contains(crate::StorageAccess::STORE),
4559 };
4560
4561 write!(self.out, "((")?;
4562 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4563 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4564 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4565 }
4566 Expression::Derivative { axis, ctrl, expr } => {
4567 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4568 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4569 let tail = match ctrl {
4570 Ctrl::Coarse => "coarse",
4571 Ctrl::Fine => "fine",
4572 Ctrl::None => unreachable!(),
4573 };
4574 write!(self.out, "abs(ddx_{tail}(")?;
4575 self.write_expr(module, expr, func_ctx)?;
4576 write!(self.out, ")) + abs(ddy_{tail}(")?;
4577 self.write_expr(module, expr, func_ctx)?;
4578 write!(self.out, "))")?
4579 } else {
4580 let fun_str = match (axis, ctrl) {
4581 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4582 (Axis::X, Ctrl::Fine) => "ddx_fine",
4583 (Axis::X, Ctrl::None) => "ddx",
4584 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4585 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4586 (Axis::Y, Ctrl::None) => "ddy",
4587 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4588 (Axis::Width, Ctrl::None) => "fwidth",
4589 };
4590 write!(self.out, "{fun_str}(")?;
4591 self.write_expr(module, expr, func_ctx)?;
4592 write!(self.out, ")")?
4593 }
4594 }
4595 Expression::Relational { fun, argument } => {
4596 use crate::RelationalFunction as Rf;
4597
4598 let fun_str = match fun {
4599 Rf::All => "all",
4600 Rf::Any => "any",
4601 Rf::IsNan => "isnan",
4602 Rf::IsInf => "isinf",
4603 };
4604 write!(self.out, "{fun_str}(")?;
4605 self.write_expr(module, argument, func_ctx)?;
4606 write!(self.out, ")")?
4607 }
4608 Expression::Select {
4609 condition,
4610 accept,
4611 reject,
4612 } => {
4613 write!(self.out, "(")?;
4614 self.write_expr(module, condition, func_ctx)?;
4615 write!(self.out, " ? ")?;
4616 self.write_expr(module, accept, func_ctx)?;
4617 write!(self.out, " : ")?;
4618 self.write_expr(module, reject, func_ctx)?;
4619 write!(self.out, ")")?
4620 }
4621 Expression::RayQueryGetIntersection { query, committed } => {
4622 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4624 unreachable!()
4625 };
4626
4627 let tracker_expr_name = format!(
4628 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4629 self.names[&func_ctx.name_key(query_var)]
4630 );
4631
4632 if committed {
4633 write!(self.out, "GetCommittedIntersection(")?;
4634 self.write_expr(module, query, func_ctx)?;
4635 write!(self.out, ", {tracker_expr_name})")?;
4636 } else {
4637 write!(self.out, "GetCandidateIntersection(")?;
4638 self.write_expr(module, query, func_ctx)?;
4639 write!(self.out, ", {tracker_expr_name})")?;
4640 }
4641 }
4642 Expression::RayQueryVertexPositions { .. }
4644 | Expression::CooperativeLoad { .. }
4645 | Expression::CooperativeMultiplyAdd { .. } => {
4646 unreachable!()
4647 }
4648 Expression::CallResult(_)
4650 | Expression::AtomicResult { .. }
4651 | Expression::WorkGroupUniformLoadResult { .. }
4652 | Expression::RayQueryProceedResult
4653 | Expression::SubgroupBallotResult
4654 | Expression::SubgroupOperationResult { .. } => {}
4655 }
4656
4657 if !closing_bracket.is_empty() {
4658 write!(self.out, "{closing_bracket}")?;
4659 }
4660 Ok(())
4661 }
4662
4663 #[allow(clippy::too_many_arguments)]
4664 fn write_image_load(
4665 &mut self,
4666 module: &&Module,
4667 expr: Handle<crate::Expression>,
4668 func_ctx: &back::FunctionCtx,
4669 image: Handle<crate::Expression>,
4670 coordinate: Handle<crate::Expression>,
4671 array_index: Option<Handle<crate::Expression>>,
4672 sample: Option<Handle<crate::Expression>>,
4673 level: Option<Handle<crate::Expression>>,
4674 ) -> Result<(), Error> {
4675 let mut wrapping_type = None;
4676 match *func_ctx.resolve_type(image, &module.types) {
4677 TypeInner::Image {
4678 class: crate::ImageClass::External,
4679 ..
4680 } => {
4681 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4682 self.write_expr(module, image, func_ctx)?;
4683 write!(self.out, ", ")?;
4684 self.write_expr(module, coordinate, func_ctx)?;
4685 write!(self.out, ")")?;
4686 return Ok(());
4687 }
4688 TypeInner::Image {
4689 class: crate::ImageClass::Storage { format, .. },
4690 ..
4691 } => {
4692 if format.single_component() {
4693 wrapping_type = Some(Scalar::from(format));
4694 }
4695 }
4696 _ => {}
4697 }
4698 if let Some(scalar) = wrapping_type {
4699 write!(
4700 self.out,
4701 "{}{}(",
4702 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4703 scalar.to_hlsl_str()?
4704 )?;
4705 }
4706 self.write_expr(module, image, func_ctx)?;
4708 write!(self.out, ".Load(")?;
4709
4710 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4711
4712 if let Some(sample) = sample {
4713 write!(self.out, ", ")?;
4714 self.write_expr(module, sample, func_ctx)?;
4715 }
4716
4717 write!(self.out, ")")?;
4719
4720 if wrapping_type.is_some() {
4721 write!(self.out, ")")?;
4722 }
4723
4724 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4726 write!(self.out, ".x")?;
4727 }
4728 Ok(())
4729 }
4730
4731 fn sampler_binding_array_info_from_expression(
4734 &mut self,
4735 module: &Module,
4736 func_ctx: &back::FunctionCtx<'_>,
4737 base: Handle<crate::Expression>,
4738 resolved: &TypeInner,
4739 ) -> Option<BindingArraySamplerInfo> {
4740 if let TypeInner::BindingArray {
4741 base: base_ty_handle,
4742 ..
4743 } = *resolved
4744 {
4745 let base_ty = &module.types[base_ty_handle].inner;
4746 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4747 let base = &func_ctx.expressions[base];
4748
4749 if let crate::Expression::GlobalVariable(handle) = *base {
4750 let variable = &module.global_variables[handle];
4751
4752 let sampler_heap_name = match comparison {
4753 true => COMPARISON_SAMPLER_HEAP_VAR,
4754 false => SAMPLER_HEAP_VAR,
4755 };
4756
4757 return Some(BindingArraySamplerInfo {
4758 sampler_heap_name,
4759 sampler_index_buffer_name: self
4760 .wrapped
4761 .sampler_index_buffers
4762 .get(&super::SamplerIndexBufferKey {
4763 group: variable.binding.unwrap().group,
4764 })
4765 .unwrap()
4766 .clone(),
4767 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4768 .clone(),
4769 });
4770 }
4771 }
4772 }
4773
4774 None
4775 }
4776
4777 fn write_named_expr(
4778 &mut self,
4779 module: &Module,
4780 handle: Handle<crate::Expression>,
4781 name: String,
4782 expr: Handle<crate::Expression>,
4785 func_ctx: &back::FunctionCtx,
4786 ) -> BackendResult {
4787 if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4788 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4789 if ty_inner.is_atomic_pointer(&module.types) {
4790 let pointer_space = ty_inner.pointer_space().unwrap();
4791 self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4792 write!(self.out, " {name}; ")?;
4793 match pointer_space {
4794 crate::AddressSpace::WorkGroup => {
4795 write!(self.out, "InterlockedOr(")?;
4796 self.write_expr(module, pointer, func_ctx)?;
4797 }
4798 crate::AddressSpace::Storage { .. } => {
4799 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4800 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4801 write!(self.out, "{var_name}.InterlockedOr(")?;
4802 let chain = mem::take(&mut self.temp_access_chain);
4803 self.write_storage_address(module, &chain, func_ctx)?;
4804 self.temp_access_chain = chain;
4805 }
4806 _ => unreachable!(),
4807 }
4808 writeln!(self.out, ", 0, {name});")?;
4809 self.named_expressions.insert(expr, name);
4810 return Ok(());
4811 }
4812 }
4813 match func_ctx.info[expr].ty {
4814 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4815 TypeInner::Struct { .. } => {
4816 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4817 write!(self.out, "{ty_name}")?;
4818 }
4819 _ => {
4820 self.write_type(module, ty_handle)?;
4821 }
4822 },
4823 proc::TypeResolution::Value(ref inner) => {
4824 self.write_value_type(module, inner)?;
4825 }
4826 }
4827
4828 let resolved = func_ctx.resolve_type(expr, &module.types);
4829
4830 write!(self.out, " {name}")?;
4831 if let TypeInner::Array { base, size, .. } = *resolved {
4833 self.write_array_size(module, base, size)?;
4834 }
4835 write!(self.out, " = ")?;
4836 self.write_expr(module, handle, func_ctx)?;
4837 writeln!(self.out, ";")?;
4838 self.named_expressions.insert(expr, name);
4839
4840 Ok(())
4841 }
4842
4843 pub(super) fn write_default_init(
4845 &mut self,
4846 module: &Module,
4847 ty: Handle<crate::Type>,
4848 ) -> BackendResult {
4849 write!(self.out, "(")?;
4850 self.write_type(module, ty)?;
4851 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4852 self.write_array_size(module, base, size)?;
4853 }
4854 write!(self.out, ")0")?;
4855 Ok(())
4856 }
4857
4858 pub(super) fn write_control_barrier(
4859 &mut self,
4860 barrier: crate::Barrier,
4861 level: back::Level,
4862 ) -> BackendResult {
4863 if barrier.contains(crate::Barrier::STORAGE) {
4864 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4865 }
4866 if barrier.contains(crate::Barrier::WORK_GROUP) {
4867 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4868 }
4869 if barrier.contains(crate::Barrier::SUB_GROUP) {
4870 }
4872 if barrier.contains(crate::Barrier::TEXTURE) {
4873 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4874 }
4875 Ok(())
4876 }
4877
4878 fn write_memory_barrier(
4879 &mut self,
4880 barrier: crate::Barrier,
4881 level: back::Level,
4882 ) -> BackendResult {
4883 if barrier.contains(crate::Barrier::STORAGE) {
4884 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4885 }
4886 if barrier.contains(crate::Barrier::WORK_GROUP) {
4887 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4888 }
4889 if barrier.contains(crate::Barrier::SUB_GROUP) {
4890 }
4892 if barrier.contains(crate::Barrier::TEXTURE) {
4893 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4894 }
4895 Ok(())
4896 }
4897
4898 fn emit_hlsl_atomic_tail(
4900 &mut self,
4901 module: &Module,
4902 func_ctx: &back::FunctionCtx<'_>,
4903 fun: &crate::AtomicFunction,
4904 compare_expr: Option<Handle<crate::Expression>>,
4905 value: Handle<crate::Expression>,
4906 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4907 ) -> BackendResult {
4908 if let Some(cmp) = compare_expr {
4909 write!(self.out, ", ")?;
4910 self.write_expr(module, cmp, func_ctx)?;
4911 }
4912 write!(self.out, ", ")?;
4913 if let crate::AtomicFunction::Subtract = *fun {
4914 write!(self.out, "-")?;
4916 }
4917 self.write_expr(module, value, func_ctx)?;
4918 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4919 write!(self.out, ", ")?;
4920 if compare_expr.is_some() {
4921 write!(self.out, "{res_name}.old_value")?;
4922 } else {
4923 write!(self.out, "{res_name}")?;
4924 }
4925 }
4926 writeln!(self.out, ");")?;
4927 Ok(())
4928 }
4929}
4930
4931pub(super) struct MatrixType {
4932 pub(super) columns: crate::VectorSize,
4933 pub(super) rows: crate::VectorSize,
4934 pub(super) width: crate::Bytes,
4935}
4936
4937pub(super) fn get_inner_matrix_data(
4938 module: &Module,
4939 handle: Handle<crate::Type>,
4940) -> Option<MatrixType> {
4941 match module.types[handle].inner {
4942 TypeInner::Matrix {
4943 columns,
4944 rows,
4945 scalar,
4946 } => Some(MatrixType {
4947 columns,
4948 rows,
4949 width: scalar.width,
4950 }),
4951 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4952 _ => None,
4953 }
4954}
4955
4956fn find_matrix_in_access_chain(
4960 module: &Module,
4961 base: Handle<crate::Expression>,
4962 func_ctx: &back::FunctionCtx<'_>,
4963) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4964 let mut current_base = base;
4965 let mut vector = None;
4966 let mut scalar = None;
4967 loop {
4968 let resolved_tr = func_ctx
4969 .resolve_type(current_base, &module.types)
4970 .pointer_base_type();
4971 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4972
4973 match *resolved {
4974 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4975 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4976 _ => return None,
4977 }
4978
4979 let index;
4980 (current_base, index) = match func_ctx.expressions[current_base] {
4981 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4982 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4983 _ => return None,
4984 };
4985
4986 match *resolved {
4987 TypeInner::Scalar(_) => scalar = Some(index),
4988 TypeInner::Vector { .. } => vector = Some(index),
4989 _ => unreachable!(),
4990 }
4991 }
4992}
4993
4994pub(super) fn get_inner_matrix_of_struct_array_member(
4999 module: &Module,
5000 base: Handle<crate::Expression>,
5001 func_ctx: &back::FunctionCtx<'_>,
5002 direct: bool,
5003) -> Option<MatrixType> {
5004 let mut mat_data = None;
5005 let mut array_base = None;
5006
5007 let mut current_base = base;
5008 loop {
5009 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5010 if let TypeInner::Pointer { base, .. } = *resolved {
5011 resolved = &module.types[base].inner;
5012 };
5013
5014 match *resolved {
5015 TypeInner::Matrix {
5016 columns,
5017 rows,
5018 scalar,
5019 } => {
5020 mat_data = Some(MatrixType {
5021 columns,
5022 rows,
5023 width: scalar.width,
5024 })
5025 }
5026 TypeInner::Array { base, .. } => {
5027 array_base = Some(base);
5028 }
5029 TypeInner::Struct { .. } => {
5030 if let Some(array_base) = array_base {
5031 if direct {
5032 return mat_data;
5033 } else {
5034 return get_inner_matrix_data(module, array_base);
5035 }
5036 }
5037
5038 break;
5039 }
5040 _ => break,
5041 }
5042
5043 current_base = match func_ctx.expressions[current_base] {
5044 crate::Expression::Access { base, .. } => base,
5045 crate::Expression::AccessIndex { base, .. } => base,
5046 _ => break,
5047 };
5048 }
5049 None
5050}
5051
5052fn get_inner_matrix_of_global_uniform(
5057 module: &Module,
5058 base: Handle<crate::Expression>,
5059 func_ctx: &back::FunctionCtx<'_>,
5060 direct: bool,
5061) -> Option<MatrixType> {
5062 let mut mat_data = None;
5063 let mut array_base = None;
5064
5065 let mut current_base = base;
5066 loop {
5067 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5068 if let TypeInner::Pointer { base, .. } = *resolved {
5069 resolved = &module.types[base].inner;
5070 };
5071
5072 match *resolved {
5073 TypeInner::Matrix {
5074 columns,
5075 rows,
5076 scalar,
5077 } => {
5078 mat_data = Some(MatrixType {
5079 columns,
5080 rows,
5081 width: scalar.width,
5082 })
5083 }
5084 TypeInner::Array { base, .. } => {
5085 if !direct {
5086 array_base = Some(base);
5087 }
5088 }
5089 _ => break,
5090 }
5091
5092 current_base = match func_ctx.expressions[current_base] {
5093 crate::Expression::Access { base, .. } => base,
5094 crate::Expression::AccessIndex { base, .. } => base,
5095 crate::Expression::GlobalVariable(handle)
5096 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5097 {
5098 return mat_data.or_else(|| {
5099 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5100 })
5101 }
5102 _ => break,
5103 };
5104 }
5105 None
5106}