1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Write as _},
8 mem,
9};
10
11use super::{
12 help,
13 help::{
14 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15 WrappedZeroValue,
16 },
17 storage::StoreValue,
18 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21 back::{self, get_entry_points, Baked},
22 common,
23 proc::{self, index, ExternalTextureNameKey, NameKey},
24 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix for user-defined locations; `write_semantic` appends the
/// location number to it (producing e.g. `LOC0`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Struct type name of the special-constants buffer emitted by `Writer::write`
/// when `Options::special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special-constants `ConstantBuffer`.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the first `int` field of the special-constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the second `int` field of the special-constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the trailing `uint` field of the special-constants struct.
const SPECIAL_OTHER: &str = "other";

// The identifiers below are shared with the rest of the crate. Apart from the
// two sampler heap names (used by `write_global_sampler`), they are not
// referenced in this part of the file.
// NOTE(review): judging by their names these are the names of generated HLSL
// helper functions/variables — confirm at their use sites.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Heap variable indexed when initializing a non-comparison sampler global.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Heap variable indexed when initializing a comparison sampler global.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index that is either computed by an IR expression or known statically.
enum Index {
    /// The index is the value of the given expression.
    Expression(Handle<crate::Expression>),
    /// The index is a constant.
    Static(u32),
}
60
/// One member of a generated entry-point interface struct.
pub(super) struct EpStructMember {
    /// Identifier of the member in the generated struct.
    pub(super) name: String,
    /// Type of the member.
    pub(super) ty: Handle<crate::Type>,
    /// Binding of the member; `write_interface_struct` sorts members by it
    /// (through `InterfaceKey`) and derives the HLSL semantic from it.
    pub(super) binding: Option<crate::Binding>,
    /// Position of the member before sorting. `write_interface_struct` sorts
    /// input members by this value again to restore their original order.
    pub(super) index: u32,
}
69
/// Description of an interface struct produced by `write_interface_struct`.
pub(super) struct EntryPointBinding {
    /// Name generated for a variable of the struct type: derived from the
    /// caller-supplied variable name or, failing that, the lowercased struct name.
    pub(super) arg_name: String,
    /// Name of the generated struct type.
    pub(super) ty_name: String,
    /// Members of the struct (see `write_interface_struct` for their order).
    pub(super) members: Vec<EpStructMember>,
    /// Name of the member bound to `LocalInvocationIndex`, or of the extra
    /// `SV_GroupIndex` member added when `SubgroupId` is used without one.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// The interface structs generated for one entry point by `write_ep_interface`.
pub(super) struct EntryPointInterface {
    /// Input struct; generated for fragment entry points with arguments and for
    /// entry points taking a subgroup built-in. It is taken (left as `None`) by
    /// `write_ep_arguments_initialization`.
    pub(crate) input: Option<EntryPointBinding>,
    /// Output struct; generated for vertex entry points whose result has no
    /// binding of its own.
    pub(crate) output: Option<EntryPointBinding>,
    /// Present when the entry point has `mesh_info`.
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    /// Present when the entry point has `mesh_info`.
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    /// Present when the entry point has `mesh_info`.
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100 Location(u32),
101 BuiltIn(crate::BuiltIn),
102 Other,
103}
104
105impl InterfaceKey {
106 const fn new(binding: Option<&crate::Binding>) -> Self {
107 match binding {
108 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110 None => Self::Other,
111 }
112 }
113}
114
/// Which kind of interface struct is being written; paired with a
/// `ShaderStage` when passed to `write_semantic` and `write_interface_struct`.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    /// Entry-point inputs. `write_interface_struct` restores the original
    /// member order for this kind only.
    Input,
    /// Entry-point outputs.
    Output,
    /// NOTE(review): presumably the per-vertex outputs of a mesh shader — confirm.
    MeshVertices,
    /// NOTE(review): presumably the per-primitive outputs of a mesh shader — confirm.
    MeshPrimitives,
}
122
// NOTE(review): this struct is not used in the visible part of the file; the
// field descriptions below are inferred from names and types only — confirm
// against the code that constructs it.
pub(super) struct NestedEntryPointArgs {
    // Presumably the user-declared arguments, already rendered as strings.
    pub user_args: Vec<String>,
    // Presumably the task payload argument, when there is one.
    pub task_payload: Option<String>,
    // Presumably the name/expression of the local invocation index.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133 return false;
134 };
135 matches!(
136 builtin,
137 crate::BuiltIn::SubgroupSize
138 | crate::BuiltIn::SubgroupInvocationId
139 | crate::BuiltIn::NumSubgroups
140 | crate::BuiltIn::SubgroupId
141 )
142}
143
// NOTE(review): not used in the visible part of the file; field descriptions
// are inferred from names and types only — confirm at the use sites.
struct BindingArraySamplerInfo {
    // Presumably one of the sampler heap variable names.
    sampler_heap_name: &'static str,
    // Presumably the name of the sampler index buffer for the binding's group.
    sampler_index_buffer_name: String,
    // Presumably the name of the variable holding the array's base index.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156 Self {
157 out,
158 names: crate::FastHashMap::default(),
159 namer: proc::Namer::default(),
160 options,
161 pipeline_options,
162 entry_point_io: crate::FastHashMap::default(),
163 named_expressions: crate::NamedExpressions::default(),
164 wrapped: super::Wrapped::default(),
165 written_committed_intersection: false,
166 written_candidate_intersection: false,
167 continue_ctx: back::continue_forward::ContinueCtx::default(),
168 temp_access_chain: Vec::new(),
169 need_bake_expressions: Default::default(),
170 function_task_payload_var: Default::default(),
171 }
172 }
173
174 fn reset(&mut self, module: &Module) {
175 self.names.clear();
176 self.namer.reset(
177 module,
178 &super::keywords::RESERVED_SET,
179 proc::KeywordSet::empty(),
180 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181 super::keywords::RESERVED_PREFIXES,
182 &mut self.names,
183 );
184 self.entry_point_io.clear();
185 self.named_expressions.clear();
186 self.wrapped.clear();
187 self.written_committed_intersection = false;
188 self.written_candidate_intersection = false;
189 self.continue_ctx.clear();
190 self.need_bake_expressions.clear();
191 self.function_task_payload_var.clear();
192 }
193
194 fn gen_force_bounded_loop_statements(
202 &mut self,
203 level: back::Level,
204 ) -> Option<(String, String)> {
205 if !self.options.force_loop_bounding {
206 return None;
207 }
208
209 let loop_bound_name = self.namer.call("loop_bound");
210 let max = u32::MAX;
211 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214 let level = level.next();
215 let break_and_inc = format!(
216 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218 );
219
220 Some((decl, break_and_inc))
221 }
222
223 fn update_expressions_to_bake(
228 &mut self,
229 module: &Module,
230 func: &crate::Function,
231 info: &valid::FunctionInfo,
232 ) {
233 use crate::Expression;
234 self.need_bake_expressions.clear();
235 for (exp_handle, expr) in func.expressions.iter() {
236 let expr_info = &info[exp_handle];
237 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238 if min_ref_count <= expr_info.ref_count {
239 self.need_bake_expressions.insert(exp_handle);
240 }
241 if let Expression::Load { pointer } = *expr {
242 if info[pointer]
243 .ty
244 .inner_with(&module.types)
245 .is_atomic_pointer(&module.types)
246 {
247 self.need_bake_expressions.insert(exp_handle);
248 }
249 }
250
251 if let Expression::Math { fun, arg, arg1, .. } = *expr {
252 match fun {
253 crate::MathFunction::Asinh
254 | crate::MathFunction::Acosh
255 | crate::MathFunction::Atanh
256 | crate::MathFunction::Unpack2x16float
257 | crate::MathFunction::Unpack2x16snorm
258 | crate::MathFunction::Unpack2x16unorm
259 | crate::MathFunction::Unpack4x8snorm
260 | crate::MathFunction::Unpack4x8unorm
261 | crate::MathFunction::Unpack4xI8
262 | crate::MathFunction::Unpack4xU8
263 | crate::MathFunction::Pack2x16float
264 | crate::MathFunction::Pack2x16snorm
265 | crate::MathFunction::Pack2x16unorm
266 | crate::MathFunction::Pack4x8snorm
267 | crate::MathFunction::Pack4x8unorm
268 | crate::MathFunction::Pack4xI8
269 | crate::MathFunction::Pack4xU8
270 | crate::MathFunction::Pack4xI8Clamp
271 | crate::MathFunction::Pack4xU8Clamp => {
272 self.need_bake_expressions.insert(arg);
273 }
274 crate::MathFunction::CountLeadingZeros => {
275 let inner = info[exp_handle].ty.inner_with(&module.types);
276 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277 self.need_bake_expressions.insert(arg);
278 }
279 }
280 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281 self.need_bake_expressions.insert(arg);
282 self.need_bake_expressions.insert(arg1.unwrap());
283 }
284 _ => {}
285 }
286 }
287
288 if let Expression::Derivative { axis, ctrl, expr } = *expr {
289 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291 self.need_bake_expressions.insert(expr);
292 }
293 }
294
295 if let Expression::GlobalVariable(_) = *expr {
296 let inner = info[exp_handle].ty.inner_with(&module.types);
297
298 if let TypeInner::Sampler { .. } = *inner {
299 self.need_bake_expressions.insert(exp_handle);
300 }
301 }
302 }
303 for statement in func.body.iter() {
304 match *statement {
305 crate::Statement::SubgroupCollectiveOperation {
306 op: _,
307 collective_op: crate::CollectiveOperation::InclusiveScan,
308 argument,
309 result: _,
310 } => {
311 self.need_bake_expressions.insert(argument);
312 }
313 crate::Statement::Atomic {
314 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315 ..
316 } => {
317 self.need_bake_expressions.insert(cmp);
318 }
319 _ => {}
320 }
321 }
322 }
323
    /// Translates `module` to HLSL, writing the result to `self.out`.
    ///
    /// The output is produced in this order: the special-constants buffer (if
    /// configured), the dynamic storage buffer offset buffers, matCx2 typedefs
    /// and helpers, struct declarations, special and wrapped helper functions,
    /// named constants, global variables, entry-point interface structs,
    /// regular functions and finally the entry points themselves.
    ///
    /// Returns a `ReflectionInfo` holding, for every entry point in the selected
    /// range, either its translated name or the binding-resolution error that
    /// prevented it from being written.
    ///
    /// # Errors
    ///
    /// - `Error::ShaderModelTooLow` if the module uses mesh shaders and the
    ///   configured shader model is below 6.5.
    /// - `Error::EntryPointNotFound` if `get_entry_points` cannot find the
    ///   entry point requested in the pipeline options.
    /// - Any error propagated from the helper writers or from the output sink.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Declare the special-constants struct and bind it as a constant buffer.
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // The register space is only spelled out when it is not space 0.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Blank line after the declaration.
            writeln!(self.out)?;
        }

        // One struct of `bt.size` uints plus a constant buffer per configured group.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Blank line after the declaration.
            writeln!(self.out)?;
        }

        // Stage and result of every entry point, used below to recognise
        // structs that are entry-point outputs.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write every struct type, except those ending in a dynamically sized member.
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // NOTE(review): `unwrap` assumes a struct always has at least
                // one member — confirm this is guaranteed by validation.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    continue;
                }

                // First entry point (if any) whose result type is this struct.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Only named constants are written.
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Blank line after the last constant.
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Blank line after the globals.
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write the interface structs of the selected entry points and remember
        // them for when the entry-point functions are written.
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions.
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            if !self.options.fake_missing_bindings {
                // Skip the function if it uses a global whose binding resolves
                // neither as a regular resource nor as an external texture.
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write the selected entry points.
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            if !self.options.fake_missing_bindings {
                // Same binding check as for regular functions, but the failure
                // is reported through the reflection info instead of skipped.
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Attributes are collected in a string and handed to `write_function`.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                // The workgroup size becomes the `numthreads` attribute.
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                let topology_str = match info.topology {
                    // NOTE(review): assumes point topology has been rejected
                    // before reaching this backend — confirm.
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // Separate entry points with a blank line, except after the last
            // entry point of the module.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
578
579 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580 match *binding {
581 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582 write!(self.out, "precise ")?;
583 }
584 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585 write!(self.out, "noperspective ")?;
586 }
587 crate::Binding::Location {
588 interpolation,
589 sampling,
590 ..
591 } => {
592 if let Some(interpolation) = interpolation {
593 if let Some(string) = interpolation.to_hlsl_str() {
594 write!(self.out, "{string} ")?
595 }
596 }
597
598 if let Some(sampling) = sampling {
599 if let Some(string) = sampling.to_hlsl_str() {
600 write!(self.out, "{string} ")?
601 }
602 }
603 }
604 crate::Binding::BuiltIn(_) => {}
605 }
606
607 Ok(())
608 }
609
610 pub(super) fn write_semantic(
613 &mut self,
614 binding: &Option<crate::Binding>,
615 stage: Option<(ShaderStage, Io)>,
616 ) -> BackendResult {
617 let is_per_primitive = match *binding {
618 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619 if builtin == crate::BuiltIn::ViewIndex
620 && self.options.shader_model < ShaderModel::V6_1
621 {
622 return Err(Error::ShaderModelTooLow(
623 "used @builtin(view_index) or SV_ViewID".to_string(),
624 ShaderModel::V6_1,
625 ));
626 }
627 if let Some(builtin_str) = builtin.to_hlsl_str()? {
628 write!(self.out, " : {builtin_str}")?;
629 }
630 false
631 }
632 Some(crate::Binding::Location {
633 blend_src: Some(1),
634 per_primitive,
635 ..
636 }) => {
637 write!(self.out, " : SV_Target1")?;
638 per_primitive
639 }
640 Some(crate::Binding::Location {
641 location,
642 per_primitive,
643 ..
644 }) => {
645 if stage == Some((ShaderStage::Fragment, Io::Output)) {
646 write!(self.out, " : SV_Target{location}")?;
647 } else {
648 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649 }
650 per_primitive
651 }
652 _ => false,
653 };
654 if is_per_primitive {
655 write!(self.out, " : primitive")?;
656 }
657
658 Ok(())
659 }
660
661 pub(super) fn write_interface_struct(
662 &mut self,
663 module: &Module,
664 shader_stage: (ShaderStage, Io),
665 struct_name: String,
666 var_name: Option<&str>,
667 mut members: Vec<EpStructMember>,
668 ) -> Result<EntryPointBinding, Error> {
669 let struct_name = self.namer.call(&struct_name);
670 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675 write!(self.out, "struct {struct_name}")?;
676 writeln!(self.out, " {{")?;
677 let mut local_invocation_index_name = None;
678 let mut subgroup_id_used = false;
679 for m in members.iter() {
680 debug_assert!(m.binding.is_some());
683
684 match m.binding {
685 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686 subgroup_id_used = true;
687 }
688 Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689 local_invocation_index_name = Some(m.name.clone());
690 }
691 _ => (),
692 }
693
694 if is_subgroup_builtin_binding(&m.binding) {
695 continue;
696 }
697 write!(self.out, "{}", back::INDENT)?;
698 if let Some(ref binding) = m.binding {
699 self.write_modifier(binding)?;
700 }
701 self.write_type(module, m.ty)?;
702 write!(self.out, " {}", &m.name)?;
703 self.write_semantic(&m.binding, Some(shader_stage))?;
704 writeln!(self.out, ";")?;
705 }
706 if subgroup_id_used && local_invocation_index_name.is_none() {
707 let name = self.namer.call("local_invocation_index");
708 writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709 local_invocation_index_name = Some(name);
710 }
711 writeln!(self.out, "}};")?;
712 writeln!(self.out)?;
713
714 match shader_stage.1 {
716 Io::Input => {
717 members.sort_by_key(|m| m.index);
719 }
720 Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721 }
723 }
724
725 Ok(EntryPointBinding {
726 arg_name: self
727 .namer
728 .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729 ty_name: struct_name,
730 members,
731 local_invocation_index_name,
732 })
733 }
734
735 fn write_ep_input_struct(
739 &mut self,
740 module: &Module,
741 func: &crate::Function,
742 stage: ShaderStage,
743 entry_point_name: &str,
744 ) -> Result<EntryPointBinding, Error> {
745 let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747 let mut fake_members = Vec::new();
748 for arg in func.arguments.iter() {
749 match module.types[arg.ty].inner {
754 TypeInner::Struct { ref members, .. } => {
755 for member in members.iter() {
756 let name = self.namer.call_or(&member.name, "member");
757 let index = fake_members.len() as u32;
758 fake_members.push(EpStructMember {
759 name,
760 ty: member.ty,
761 binding: member.binding.clone(),
762 index,
763 });
764 }
765 }
766 _ => {
767 let member_name = self.namer.call_or(&arg.name, "member");
768 let index = fake_members.len() as u32;
769 fake_members.push(EpStructMember {
770 name: member_name,
771 ty: arg.ty,
772 binding: arg.binding.clone(),
773 index,
774 });
775 }
776 }
777 }
778
779 self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780 }
781
782 fn write_ep_output_struct(
786 &mut self,
787 module: &Module,
788 result: &crate::FunctionResult,
789 stage: ShaderStage,
790 entry_point_name: &str,
791 frag_ep: Option<&FragmentEntryPoint<'_>>,
792 ) -> Result<EntryPointBinding, Error> {
793 let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795 let empty = [];
796 let members = match module.types[result.ty].inner {
797 TypeInner::Struct { ref members, .. } => members,
798 ref other => {
799 log::error!("Unexpected {other:?} output type without a binding");
800 &empty[..]
801 }
802 };
803
804 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809 let mut fs_input_locs = Vec::new();
810 for arg in frag_ep.func.arguments.iter() {
811 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813 Some(crate::Binding::BuiltIn(_)) | None => {}
814 };
815
816 match frag_ep.module.types[arg.ty].inner {
819 TypeInner::Struct { ref members, .. } => {
820 for member in members.iter() {
821 push_if_location(&member.binding);
822 }
823 }
824 _ => push_if_location(&arg.binding),
825 }
826 }
827 fs_input_locs.sort();
828 Some(fs_input_locs)
829 } else {
830 None
831 };
832
833 let mut fake_members = Vec::new();
834 for (index, member) in members.iter().enumerate() {
835 if let Some(ref fs_input_locs) = fs_input_locs {
836 match member.binding {
837 Some(crate::Binding::Location { location, .. }) => {
838 if fs_input_locs.binary_search(&location).is_err() {
839 continue;
840 }
841 }
842 Some(crate::Binding::BuiltIn(_)) | None => {}
843 }
844 }
845
846 let member_name = self.namer.call_or(&member.name, "member");
847 fake_members.push(EpStructMember {
848 name: member_name,
849 ty: member.ty,
850 binding: member.binding.clone(),
851 index: index as u32,
852 });
853 }
854
855 self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856 }
857
858 fn write_ep_interface(
862 &mut self,
863 module: &Module,
864 ep: &crate::EntryPoint,
865 ep_name: &str,
866 frag_ep: Option<&FragmentEntryPoint<'_>>,
867 ) -> Result<EntryPointInterface, Error> {
868 let func = &ep.function;
869 let stage = ep.stage;
870 Ok(EntryPointInterface {
871 input: if !func.arguments.is_empty()
872 && (stage == ShaderStage::Fragment
873 || func
874 .arguments
875 .iter()
876 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877 {
878 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879 } else {
880 None
881 },
882 output: match func.result {
883 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885 }
886 _ => None,
887 },
888 mesh_vertices: if let Some(ref info) = ep.mesh_info {
889 Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890 } else {
891 None
892 },
893 mesh_primitives: if let Some(ref info) = ep.mesh_info {
894 Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895 } else {
896 None
897 },
898 mesh_indices: if let Some(ref info) = ep.mesh_info {
899 Some(self.write_ep_mesh_output_indices(info.topology)?)
900 } else {
901 None
902 },
903 })
904 }
905
906 fn write_ep_argument_initialization(
907 &mut self,
908 ep: &crate::EntryPoint,
909 ep_input: &EntryPointBinding,
910 fake_member: &EpStructMember,
911 ) -> BackendResult {
912 match fake_member.binding {
913 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914 write!(self.out, "WaveGetLaneCount()")?
915 }
916 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917 write!(self.out, "WaveGetLaneIndex()")?
918 }
919 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920 self.out,
921 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923 )?,
924 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925 write!(
926 self.out,
927 "{}.{} / WaveGetLaneCount()",
928 ep_input.arg_name,
929 ep_input.local_invocation_index_name.as_ref().unwrap()
931 )?;
932 }
933 Some(crate::Binding::Location {
934 interpolation: Some(crate::Interpolation::PerVertex),
935 ..
936 }) => {
937 if self.options.shader_model < ShaderModel::V6_1 {
938 return Err(Error::ShaderModelTooLow(
939 "per_vertex fragment inputs".to_string(),
940 ShaderModel::V6_1,
941 ));
942 }
943 write!(
944 self.out,
945 "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946 ep_input.arg_name,
947 fake_member.name,
948 )?;
949 }
950 _ => {
951 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952 }
953 }
954 Ok(())
955 }
956
957 fn write_ep_arguments_initialization(
959 &mut self,
960 module: &Module,
961 func: &crate::Function,
962 ep_index: u16,
963 ) -> BackendResult {
964 let ep = &module.entry_points[ep_index as usize];
965 let ep_input = match self
966 .entry_point_io
967 .get_mut(&(ep_index as usize))
968 .unwrap()
969 .input
970 .take()
971 {
972 Some(ep_input) => ep_input,
973 None => return Ok(()),
974 };
975 let mut fake_iter = ep_input.members.iter();
976 for (arg_index, arg) in func.arguments.iter().enumerate() {
977 write!(self.out, "{}", back::INDENT)?;
978 self.write_type(module, arg.ty)?;
979 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980 write!(self.out, " {arg_name}")?;
981 match module.types[arg.ty].inner {
982 TypeInner::Array { base, size, .. } => {
983 self.write_array_size(module, base, size)?;
984 write!(self.out, " = ")?;
985 self.write_ep_argument_initialization(
986 ep,
987 &ep_input,
988 fake_iter.next().unwrap(),
989 )?;
990 writeln!(self.out, ";")?;
991 }
992 TypeInner::Struct { ref members, .. } => {
993 write!(self.out, " = {{ ")?;
994 for index in 0..members.len() {
995 if index != 0 {
996 write!(self.out, ", ")?;
997 }
998 self.write_ep_argument_initialization(
999 ep,
1000 &ep_input,
1001 fake_iter.next().unwrap(),
1002 )?;
1003 }
1004 writeln!(self.out, " }};")?;
1005 }
1006 _ => {
1007 write!(self.out, " = ")?;
1008 self.write_ep_argument_initialization(
1009 ep,
1010 &ep_input,
1011 fake_iter.next().unwrap(),
1012 )?;
1013 writeln!(self.out, ";")?;
1014 }
1015 }
1016 }
1017 assert!(fake_iter.next().is_none());
1018 Ok(())
1019 }
1020
    /// Writes the declaration of one global variable.
    ///
    /// External textures and samplers are delegated to
    /// `write_global_external_texture` and `write_global_sampler`. A global
    /// whose resource binding cannot be resolved is skipped with a debug log.
    /// For everything else the declaration is assembled piecewise: storage
    /// class and type, name, optional array size, `register(...)` binding or
    /// initializer, and the terminating `;` (or the inlined `cbuffer` body).
    ///
    /// # Errors
    ///
    /// Returns `Error::Unimplemented` for a global in the `Immediate` address
    /// space whose type is not a struct, and propagates write errors.
    ///
    /// # Panics
    ///
    /// Panics on the `Function` address space (unreachable), on the ray payload
    /// address spaces (unimplemented), and when an `Immediate` global is written
    /// without `Options::immediates_target` being set.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are written by a dedicated method. This happens
        // before the regular binding resolution below.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // Skip globals whose binding cannot be resolved.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        // Samplers are written by a dedicated method.
        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the storage class / type prefix and pick the register letter
        // used in `register(...)` (empty for unbound address spaces).
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // The type is written later, inside the braces of the `cbuffer`.
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                // Writable buffers use the `RW` variant and a `u` register.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                // Storage images use a `u` register, every other handle `t`.
                let register = match *handle_ty {
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type is written below as the generic argument.
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // Complete `ConstantBuffer<...>` for immediates.
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the generic argument list.
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Immediates take their register from `Options::immediates_target`
        // rather than from a binding on the global.
        if global.space == crate::AddressSpace::Immediate {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Cannot fail: the same resolution succeeded earlier in this function.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // Binding arrays get their size, which the bind target may override.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // Unbound globals: write the array size and, for private
            // variables, an initializer.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Inline the single member of the `cbuffer`; it has the same name
            // as the buffer itself.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1223
1224 fn write_global_sampler(
1225 &mut self,
1226 module: &Module,
1227 handle: Handle<crate::GlobalVariable>,
1228 global: &crate::GlobalVariable,
1229 ) -> BackendResult {
1230 let binding = *global.binding.as_ref().unwrap();
1231
1232 let key = super::SamplerIndexBufferKey {
1233 group: binding.group,
1234 };
1235 self.write_wrapped_sampler_buffer(key)?;
1236
1237 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240 match module.types[global.ty].inner {
1241 TypeInner::Sampler { comparison } => {
1242 write!(self.out, "static const ")?;
1249 self.write_type(module, global.ty)?;
1250
1251 let heap_var = if comparison {
1252 COMPARISON_SAMPLER_HEAP_VAR
1253 } else {
1254 SAMPLER_HEAP_VAR
1255 };
1256
1257 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258 let name = &self.names[&NameKey::GlobalVariable(handle)];
1259 writeln!(
1260 self.out,
1261 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262 register = bt.register
1263 )?;
1264 }
1265 TypeInner::BindingArray { .. } => {
1266 let name = &self.names[&NameKey::GlobalVariable(handle)];
1272 writeln!(
1273 self.out,
1274 "static const uint {name} = {register};",
1275 register = bt.register
1276 )?;
1277 }
1278 _ => unreachable!(),
1279 };
1280
1281 Ok(())
1282 }
1283
1284 fn write_global_external_texture(
1288 &mut self,
1289 module: &Module,
1290 handle: Handle<crate::GlobalVariable>,
1291 global: &crate::GlobalVariable,
1292 ) -> BackendResult {
1293 let res_binding = global
1294 .binding
1295 .as_ref()
1296 .expect("External texture global variables must have a resource binding");
1297 let ext_tex_bindings = match self
1298 .options
1299 .resolve_external_texture_resource_binding(res_binding)
1300 {
1301 Ok(bindings) => bindings,
1302 Err(err) => {
1303 log::debug!(
1304 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305 handle,
1306 global.name,
1307 err,
1308 );
1309 return Ok(());
1310 }
1311 };
1312
1313 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314 write!(
1315 self.out,
1316 "Texture2D<float4> {}: register(t{}",
1317 name, bt.register
1318 )?;
1319 if bt.space != 0 {
1320 write!(self.out, ", space{}", bt.space)?;
1321 }
1322 writeln!(self.out, ");")?;
1323 Ok(())
1324 };
1325 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326 let plane_name = &self.names
1327 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328 write_plane(bt, plane_name)?;
1329 }
1330
1331 let params_name = &self.names
1332 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333 let params_ty_name =
1334 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335 write!(
1336 self.out,
1337 "cbuffer {}: register(b{}",
1338 params_name, ext_tex_bindings.params.register
1339 )?;
1340 if ext_tex_bindings.params.space != 0 {
1341 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342 }
1343 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345 Ok(())
1346 }
1347
1348 fn write_global_constant(
1353 &mut self,
1354 module: &Module,
1355 handle: Handle<crate::Constant>,
1356 ) -> BackendResult {
1357 write!(self.out, "static const ")?;
1358 let constant = &module.constants[handle];
1359 self.write_type(module, constant.ty)?;
1360 let name = &self.names[&NameKey::Constant(handle)];
1361 write!(self.out, " {name}")?;
1362 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364 self.write_array_size(module, base, size)?;
1365 }
1366 write!(self.out, " = ")?;
1367 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368 writeln!(self.out, ";")?;
1369 Ok(())
1370 }
1371
1372 pub(super) fn write_array_size(
1373 &mut self,
1374 module: &Module,
1375 base: Handle<crate::Type>,
1376 size: crate::ArraySize,
1377 ) -> BackendResult {
1378 write!(self.out, "[")?;
1379
1380 match size.resolve(module.to_ctx())? {
1381 proc::IndexableLength::Known(size) => {
1382 write!(self.out, "{size}")?;
1383 }
1384 proc::IndexableLength::Dynamic => unreachable!(),
1385 }
1386
1387 write!(self.out, "]")?;
1388
1389 if let TypeInner::Array {
1390 base: next_base,
1391 size: next_size,
1392 ..
1393 } = module.types[base].inner
1394 {
1395 self.write_array_size(module, next_base, next_size)?;
1396 }
1397
1398 Ok(())
1399 }
1400
/// Writes the HLSL definition of a struct type: `struct <name> { ... };`.
///
/// * `handle` identifies the struct type; it is used to look up the struct
///   name and the member names.
/// * `members` are the struct members, written in order.
/// * `span` is compared against the end of the last member to decide how much
///   trailing padding to emit.
/// * `shader_stage` is forwarded to `write_semantic` for every member.
///
/// Gaps between members that have no binding are filled with `int` padding
/// fields, one per four bytes of gap.
///
/// # Panics
///
/// Panics if `members` is empty (`members.last().unwrap()`).
fn write_struct(
    &mut self,
    module: &Module,
    handle: Handle<crate::Type>,
    members: &[crate::StructMember],
    span: u32,
    shader_stage: Option<(ShaderStage, Io)>,
) -> BackendResult {
    let struct_name = &self.names[&NameKey::Type(handle)];
    writeln!(self.out, "struct {struct_name} {{")?;

    // Byte offset just past the previously written member.
    let mut last_offset = 0;
    for (index, member) in members.iter().enumerate() {
        // Fill any gap before this member with 4-byte `int` padding fields.
        // Members that carry a binding are never padded.
        if member.binding.is_none() && member.offset > last_offset {
            let padding = (member.offset - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
            }
        }
        let ty_inner = &module.types[member.ty].inner;
        last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

        // Indent the member declaration.
        write!(self.out, "{}", back::INDENT)?;

        match module.types[member.ty].inner {
            // Arrays are written as `<type> <name>[size]...`.
            TypeInner::Array { base, size, .. } => {
                self.write_global_type(module, member.ty)?;

                write!(
                    self.out,
                    " {}",
                    &self.names[&NameKey::StructMember(handle, index as u32)]
                )?;
                self.write_array_size(module, base, size)?;
            }
            // A matrix with two rows and no binding is split into one
            // vector field per column, named `<member>_<column>` and
            // separated by `; ` (the final `;` is written after the match).
            TypeInner::Matrix {
                rows,
                columns,
                scalar,
            } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                let vec_ty = TypeInner::Vector { size: rows, scalar };
                let field_name_key = NameKey::StructMember(handle, index as u32);

                for i in 0..columns as u8 {
                    if i != 0 {
                        write!(self.out, "; ")?;
                    }
                    self.write_value_type(module, &vec_ty)?;
                    write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                }
            }
            _ => {
                // Any binding modifier goes in front of the type.
                if let Some(ref binding) = member.binding {
                    self.write_modifier(binding)?;
                }

                // Remaining matrix members are declared `row_major`.
                if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
                    write!(self.out, "row_major ")?;
                }

                self.write_type(module, member.ty)?;
                write!(
                    self.out,
                    " {}",
                    &self.names[&NameKey::StructMember(handle, index as u32)]
                )?;
            }
        }

        self.write_semantic(&member.binding, shader_stage)?;
        writeln!(self.out, ";")?;
    }

    // Pad the tail of the struct up to `span`, again in 4-byte `int` steps,
    // unless the last member carries a binding.
    if members.last().unwrap().binding.is_none() && span > last_offset {
        let padding = (span - last_offset) / 4;
        for i in 0..padding {
            writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
        }
    }

    writeln!(self.out, "}};")?;
    Ok(())
}
1505
1506 pub(super) fn write_global_type(
1511 &mut self,
1512 module: &Module,
1513 ty: Handle<crate::Type>,
1514 ) -> BackendResult {
1515 let matrix_data = get_inner_matrix_data(module, ty);
1516
1517 if let Some(MatrixType {
1520 columns,
1521 rows: crate::VectorSize::Bi,
1522 width: 4,
1523 }) = matrix_data
1524 {
1525 write!(self.out, "__mat{}x2", columns as u8)?;
1526 } else {
1527 if matrix_data.is_some() {
1531 write!(self.out, "row_major ")?;
1532 }
1533
1534 self.write_type(module, ty)?;
1535 }
1536
1537 Ok(())
1538 }
1539
1540 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545 let inner = &module.types[ty].inner;
1546 match *inner {
1547 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550 self.write_type(module, base)?
1551 }
1552 ref other => self.write_value_type(module, other)?,
1553 }
1554
1555 Ok(())
1556 }
1557
1558 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563 match *inner {
1564 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566 }
1567 TypeInner::Vector { size, scalar } => {
1568 write!(
1569 self.out,
1570 "{}{}",
1571 scalar.to_hlsl_str()?,
1572 common::vector_size_str(size)
1573 )?;
1574 }
1575 TypeInner::Matrix {
1576 columns,
1577 rows,
1578 scalar,
1579 } => {
1580 write!(
1585 self.out,
1586 "{}{}x{}",
1587 scalar.to_hlsl_str()?,
1588 common::vector_size_str(columns),
1589 common::vector_size_str(rows),
1590 )?;
1591 }
1592 TypeInner::Image {
1593 dim,
1594 arrayed,
1595 class,
1596 } => {
1597 self.write_image_type(dim, arrayed, class)?;
1598 }
1599 TypeInner::Sampler { comparison } => {
1600 let sampler = if comparison {
1601 "SamplerComparisonState"
1602 } else {
1603 "SamplerState"
1604 };
1605 write!(self.out, "{sampler}")?;
1606 }
1607 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611 self.write_array_size(module, base, size)?;
1612 }
1613 TypeInner::AccelerationStructure { .. } => {
1614 write!(self.out, "RaytracingAccelerationStructure")?;
1615 }
1616 TypeInner::RayQuery { .. } => {
1617 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619 }
1620 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621 }
1622
1623 Ok(())
1624 }
1625
1626 fn write_function(
1630 &mut self,
1631 module: &Module,
1632 name: &str,
1633 func: &crate::Function,
1634 func_ctx: &back::FunctionCtx<'_>,
1635 info: &valid::FunctionInfo,
1636 header: String,
1637 ) -> BackendResult {
1638 self.update_expressions_to_bake(module, func, info);
1641 let ep = match func_ctx.ty {
1642 back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643 back::FunctionType::Function(_) => None,
1644 };
1645
1646 let nested = matches!(
1647 ep,
1648 Some(crate::EntryPoint {
1649 stage: ShaderStage::Task | ShaderStage::Mesh,
1650 ..
1651 })
1652 );
1653 if !nested {
1654 write!(self.out, "{header}")?;
1655 }
1656
1657 if let Some(ref result) = func.result {
1658 let array_return_type = match module.types[result.ty].inner {
1660 TypeInner::Array { base, size, .. } => {
1661 let array_return_type = self.namer.call(&format!("ret_{name}"));
1662 write!(self.out, "typedef ")?;
1663 self.write_type(module, result.ty)?;
1664 write!(self.out, " {array_return_type}")?;
1665 self.write_array_size(module, base, size)?;
1666 writeln!(self.out, ";")?;
1667 Some(array_return_type)
1668 }
1669 _ => None,
1670 };
1671
1672 if let Some(
1674 ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675 ) = result.binding
1676 {
1677 self.write_modifier(binding)?;
1678 }
1679
1680 match func_ctx.ty {
1682 back::FunctionType::Function(_) => {
1683 if let Some(array_return_type) = array_return_type {
1684 write!(self.out, "{array_return_type}")?;
1685 } else {
1686 self.write_type(module, result.ty)?;
1687 }
1688 }
1689 back::FunctionType::EntryPoint(index) => {
1690 if let Some(ref ep_output) =
1691 self.entry_point_io.get(&(index as usize)).unwrap().output
1692 {
1693 write!(self.out, "{}", ep_output.ty_name)?;
1694 } else {
1695 self.write_type(module, result.ty)?;
1696 }
1697 }
1698 }
1699 } else {
1700 write!(self.out, "void")?;
1701 }
1702
1703 let nested_name = if nested {
1704 self.namer.call(&format!("_{name}"))
1705 } else {
1706 name.to_string()
1707 };
1708
1709 write!(self.out, " {nested_name}(")?;
1711
1712 let need_workgroup_variables_initialization =
1713 self.need_workgroup_variables_initialization(func_ctx, module);
1714
1715 let mut any_args_written = false;
1716 let mut separator = || {
1717 if any_args_written {
1718 ", "
1719 } else {
1720 any_args_written = true;
1721 ""
1722 }
1723 };
1724
1725 let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726 let mut local_invocation_index_name = None;
1727 let mut nested_wgsl_args: Vec<String> = Vec::new();
1730 let mut nested_task_payload_name: Option<String> = None;
1731 match func_ctx.ty {
1733 back::FunctionType::Function(handle) => {
1734 for (index, arg) in func.arguments.iter().enumerate() {
1735 write!(self.out, "{}", separator())?;
1736 self.write_function_argument(module, handle, arg, index)?;
1737 }
1738 for (var_handle, var) in module.global_variables.iter() {
1740 let uses = info[var_handle];
1741 if uses.contains(valid::GlobalUse::READ)
1742 && !uses.contains(valid::GlobalUse::WRITE)
1743 && var.space == crate::AddressSpace::TaskPayload
1744 {
1745 self.function_task_payload_var.insert(handle, var_handle);
1746 write!(self.out, "{}in ", separator())?;
1747
1748 self.write_type(module, var.ty)?;
1749 let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750 write!(self.out, " {name}")?;
1751 break;
1752 }
1753 }
1754 }
1755 back::FunctionType::EntryPoint(ep_index) => {
1756 let ep = &module.entry_points[ep_index as usize];
1757 if let Some(ref ep_input) =
1758 self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759 {
1760 write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1761 separator();
1762 nested_wgsl_args.push(ep_input.arg_name.clone());
1763 } else {
1764 let stage = ep.stage;
1765 for (index, arg) in func.arguments.iter().enumerate() {
1766 write!(self.out, "{}", separator())?;
1767 self.write_type(module, arg.ty)?;
1768
1769 let argument_name =
1770 &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
1772 if arg.binding
1773 == Some(crate::Binding::BuiltIn(
1774 crate::BuiltIn::LocalInvocationIndex,
1775 ))
1776 {
1777 local_invocation_index_name = Some(argument_name.clone());
1778 }
1779
1780 nested_wgsl_args.push(argument_name.clone());
1781 write!(self.out, " {argument_name}")?;
1782 if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783 self.write_array_size(module, base, size)?;
1784 }
1785
1786 self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787 }
1788 }
1789 if ep.stage == ShaderStage::Mesh {
1790 if let Some(var_handle) = ep.task_payload {
1791 let var = &module.global_variables[var_handle];
1792 write!(self.out, "{}in ", separator())?;
1793 self.write_type(module, var.ty)?;
1794 let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795 write!(self.out, " {arg_name}")?;
1796 nested_task_payload_name = Some(arg_name.clone());
1797 if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798 self.write_array_size(module, base, size)?;
1799 }
1800 }
1801 }
1802 if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803 let name = self.namer.call("local_invocation_index");
1804 write!(self.out, "{}uint {name}", separator())?;
1805 write!(self.out, " : SV_GroupIndex")?;
1806 local_invocation_index_name = Some(name);
1807 }
1808 }
1809 }
1810 write!(self.out, ")")?;
1812
1813 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815 let stage = module.entry_points[index as usize].stage;
1816 if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817 self.write_semantic(binding, Some((stage, Io::Output)))?;
1818 }
1819 }
1820
1821 writeln!(self.out)?;
1823 writeln!(self.out, "{{")?;
1824
1825 if need_workgroup_variables_initialization && !nested {
1826 let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827 unreachable!();
1828 };
1829 writeln!(
1830 self.out,
1831 "{}if ({} == 0) {{",
1832 back::INDENT,
1833 local_invocation_index_name.as_ref().unwrap(),
1836 )?;
1837 self.write_workgroup_variables_initialization(
1838 func_ctx,
1839 module,
1840 module.entry_points[index as usize].stage,
1841 )?;
1842
1843 writeln!(self.out, "{}}}", back::INDENT)?;
1844 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845 }
1846
1847 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848 self.write_ep_arguments_initialization(module, func, index)?;
1849 }
1850
1851 for (handle, local) in func.local_variables.iter() {
1853 write!(self.out, "{}", back::INDENT)?;
1855
1856 self.write_type(module, local.ty)?;
1859 write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860 if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862 self.write_array_size(module, base, size)?;
1863 }
1864
1865 let is_ray_query = match module.types[local.ty].inner {
1866 TypeInner::RayQuery { .. } => true,
1868 _ => {
1869 write!(self.out, " = ")?;
1870 if let Some(init) = local.init {
1872 self.write_expr(module, init, func_ctx)?;
1873 } else {
1874 self.write_default_init(module, local.ty)?;
1876 }
1877 false
1878 }
1879 };
1880 writeln!(self.out, ";")?;
1882 if is_ray_query {
1884 write!(self.out, "{}", back::INDENT)?;
1885 self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886 writeln!(
1887 self.out,
1888 " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889 self.names[&func_ctx.name_key(handle)]
1890 )?;
1891 }
1892 }
1893
1894 if !func.local_variables.is_empty() {
1895 writeln!(self.out)?;
1896 }
1897
1898 for sta in func.body.iter() {
1900 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902 }
1903
1904 writeln!(self.out, "}}")?;
1905
1906 if nested {
1907 self.write_nested_function_outer(
1908 module,
1909 func_ctx,
1910 &header,
1911 name,
1912 need_workgroup_variables_initialization,
1913 &nested_name,
1914 ep.unwrap(),
1915 NestedEntryPointArgs {
1916 user_args: nested_wgsl_args,
1917 task_payload: nested_task_payload_name,
1918 local_invocation_index: local_invocation_index_name.unwrap(),
1920 },
1921 )?;
1922 }
1923
1924 self.named_expressions.clear();
1925
1926 Ok(())
1927 }
1928
1929 fn write_function_argument(
1930 &mut self,
1931 module: &Module,
1932 handle: Handle<crate::Function>,
1933 arg: &crate::FunctionArgument,
1934 index: usize,
1935 ) -> BackendResult {
1936 if let TypeInner::Image {
1939 class: crate::ImageClass::External,
1940 ..
1941 } = module.types[arg.ty].inner
1942 {
1943 return self.write_function_external_texture_argument(module, handle, index);
1944 }
1945
1946 let arg_ty = match module.types[arg.ty].inner {
1948 TypeInner::Pointer { base, .. } => {
1950 write!(self.out, "inout ")?;
1952 base
1953 }
1954 _ => arg.ty,
1955 };
1956 self.write_type(module, arg_ty)?;
1957
1958 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960 write!(self.out, " {argument_name}")?;
1962 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963 self.write_array_size(module, base, size)?;
1964 }
1965
1966 Ok(())
1967 }
1968
1969 fn write_function_external_texture_argument(
1970 &mut self,
1971 module: &Module,
1972 handle: Handle<crate::Function>,
1973 index: usize,
1974 ) -> BackendResult {
1975 let plane_names = [0, 1, 2].map(|i| {
1976 &self.names[&NameKey::ExternalTextureFunctionArgument(
1977 handle,
1978 index as u32,
1979 ExternalTextureNameKey::Plane(i),
1980 )]
1981 });
1982 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983 handle,
1984 index as u32,
1985 ExternalTextureNameKey::Params,
1986 )];
1987 let params_ty_name =
1988 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989 write!(
1990 self.out,
1991 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992 plane_names[0], plane_names[1], plane_names[2],
1993 )?;
1994 Ok(())
1995 }
1996
1997 fn need_workgroup_variables_initialization(
1998 &mut self,
1999 func_ctx: &back::FunctionCtx,
2000 module: &Module,
2001 ) -> bool {
2002 self.options.zero_initialize_workgroup_memory
2003 && func_ctx.ty.is_compute_like_entry_point(module)
2004 && module.global_variables.iter().any(|(handle, var)| {
2005 !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006 })
2007 }
2008
2009 pub(super) fn write_workgroup_variables_initialization(
2010 &mut self,
2011 func_ctx: &back::FunctionCtx,
2012 module: &Module,
2013 stage: ShaderStage,
2014 ) -> BackendResult {
2015 let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016 let task_needs_zero =
2018 (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019 !func_ctx.info[handle].is_empty()
2020 && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021 });
2022
2023 for (handle, var) in vars {
2024 let name = &self.names[&NameKey::GlobalVariable(handle)];
2025 write!(self.out, "{}{} = ", back::Level(2), name)?;
2026 self.write_default_init(module, var.ty)?;
2027 writeln!(self.out, ";")?;
2028 }
2029 Ok(())
2030 }
2031
/// Writes a switch statement over `selector` with the given `cases`,
/// starting at indentation `level`.
///
/// Two output shapes are produced:
///
/// * If every case except the last is an empty fall-through, only the last
///   case's body is written, wrapped in `do { ... } while(false);`, and the
///   selector expression is not emitted at all.
/// * Otherwise a regular `switch` is written. A case with a non-empty body
///   that falls through is emulated by replaying the bodies of the following
///   cases inside it, up to and including the first case that does not fall
///   through.
///
/// Before and after the statement, `self.continue_ctx` is consulted so that
/// a forwarded `continue` or `break` can be re-issued after the switch.
///
/// # Panics
///
/// Panics if a non-empty fall-through case is not followed by any case that
/// ends the fall-through chain (the `position(..).unwrap()` below).
fn write_switch(
    &mut self,
    module: &Module,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
    selector: Handle<crate::Expression>,
    cases: &[crate::SwitchCase],
) -> BackendResult {
    // Indentation for case labels and for case bodies, respectively.
    let indent_level_1 = level.next();
    let indent_level_2 = indent_level_1.next();

    // If the continue context hands back a variable on entering the switch,
    // declare it as a `bool` initialised to `false`.
    if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
        writeln!(self.out, "{level}bool {variable} = false;",)?;
    };

    // True when all cases but the last are empty fall-throughs, i.e. the
    // whole switch has a single body.
    let one_body = cases
        .iter()
        .rev()
        .skip(1)
        .all(|case| case.fall_through && case.body.is_empty());
    if one_body {
        // Start the do-while; the selector is not written in this shape.
        writeln!(self.out, "{level}do {{")?;
        if let Some(case) = cases.last() {
            for sta in case.body.iter() {
                self.write_stmt(module, sta, func_ctx, indent_level_1)?;
            }
        }
        // End the do-while.
        writeln!(self.out, "{level}}} while(false);")?;
    } else {
        // Start the switch.
        write!(self.out, "{level}")?;
        write!(self.out, "switch(")?;
        self.write_expr(module, selector, func_ctx)?;
        writeln!(self.out, ") {{")?;

        for (i, case) in cases.iter().enumerate() {
            // Case label; unsigned values carry a `u` suffix.
            match case.value {
                crate::SwitchValue::I32(value) => {
                    write!(self.out, "{indent_level_1}case {value}:")?
                }
                crate::SwitchValue::U32(value) => {
                    write!(self.out, "{indent_level_1}case {value}u:")?
                }
                crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
            }

            // Every case gets its own braced scope, except an empty
            // fall-through, which is just a bare label.
            let write_block_braces = !(case.fall_through && case.body.is_empty());
            if write_block_braces {
                writeln!(self.out, " {{")?;
            } else {
                writeln!(self.out)?;
            }

            if case.fall_through && !case.body.is_empty() {
                // Emulate fall-through out of a non-empty body: find the
                // first later case that stops falling through and replay
                // every body from this case up to and including that one.
                let curr_len = i + 1;
                let end_case_idx = curr_len
                    + cases
                        .iter()
                        .skip(curr_len)
                        .position(|case| !case.fall_through)
                        .unwrap();
                let indent_level_3 = indent_level_2.next();
                for case in &cases[i..=end_case_idx] {
                    writeln!(self.out, "{indent_level_2}{{")?;
                    let prev_len = self.named_expressions.len();
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                    }
                    // Drop the named expressions added while writing this
                    // body; the same body may be written again for another
                    // case.
                    self.named_expressions.truncate(prev_len);
                    writeln!(self.out, "{indent_level_2}}}")?;
                }

                // Add a `break` unless the replayed chain already ends in a
                // terminator statement.
                let last_case = &cases[end_case_idx];
                if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            } else {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                }
                // Non-fall-through cases that do not end in a terminator
                // need an explicit `break`.
                if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            }

            if write_block_braces {
                writeln!(self.out, "{indent_level_1}}}")?;
            }
        }

        // End the switch.
        writeln!(self.out, "{level}}}")?;
    }

    // Re-issue a `continue` or `break` that was forwarded out of the switch
    // body, guarded by the variable reported by the continue context.
    use back::continue_forward::ExitControlFlow;
    let op = match self.continue_ctx.exit_switch() {
        ExitControlFlow::None => None,
        ExitControlFlow::Continue { variable } => Some(("continue", variable)),
        ExitControlFlow::Break { variable } => Some(("break", variable)),
    };
    if let Some((control_flow, variable)) = op {
        writeln!(self.out, "{level}if ({variable}) {{")?;
        writeln!(self.out, "{indent_level_1}{control_flow};")?;
        writeln!(self.out, "{level}}}")?;
    }

    Ok(())
}
2176
2177 fn write_index(
2178 &mut self,
2179 module: &Module,
2180 index: Index,
2181 func_ctx: &back::FunctionCtx<'_>,
2182 ) -> BackendResult {
2183 match index {
2184 Index::Static(index) => {
2185 write!(self.out, "{index}")?;
2186 }
2187 Index::Expression(index) => {
2188 self.write_expr(module, index, func_ctx)?;
2189 }
2190 }
2191 Ok(())
2192 }
2193
2194 fn write_stmt(
2199 &mut self,
2200 module: &Module,
2201 stmt: &crate::Statement,
2202 func_ctx: &back::FunctionCtx<'_>,
2203 level: back::Level,
2204 ) -> BackendResult {
2205 use crate::Statement;
2206
2207 match *stmt {
2208 Statement::Emit(ref range) => {
2209 for handle in range.clone() {
2210 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211 let expr_name = if ptr_class.is_some() {
2212 None
2216 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217 Some(self.namer.call(name))
2222 } else if self.need_bake_expressions.contains(&handle) {
2223 Some(Baked(handle).to_string())
2224 } else {
2225 None
2226 };
2227
2228 if let Some(name) = expr_name {
2229 write!(self.out, "{level}")?;
2230 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231 }
2232 }
2233 }
2234 Statement::Block(ref block) => {
2236 write!(self.out, "{level}")?;
2237 writeln!(self.out, "{{")?;
2238 for sta in block.iter() {
2239 self.write_stmt(module, sta, func_ctx, level.next())?
2241 }
2242 writeln!(self.out, "{level}}}")?
2243 }
2244 Statement::If {
2246 condition,
2247 ref accept,
2248 ref reject,
2249 } => {
2250 write!(self.out, "{level}")?;
2251 write!(self.out, "if (")?;
2252 self.write_expr(module, condition, func_ctx)?;
2253 writeln!(self.out, ") {{")?;
2254
2255 let l2 = level.next();
2256 for sta in accept {
2257 self.write_stmt(module, sta, func_ctx, l2)?;
2259 }
2260
2261 if !reject.is_empty() {
2264 writeln!(self.out, "{level}}} else {{")?;
2265
2266 for sta in reject {
2267 self.write_stmt(module, sta, func_ctx, l2)?;
2269 }
2270 }
2271
2272 writeln!(self.out, "{level}}}")?
2273 }
2274 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276 Statement::Return { value: None } => {
2277 writeln!(self.out, "{level}return;")?;
2278 }
2279 Statement::Return { value: Some(expr) } => {
2280 let base_ty_res = &func_ctx.info[expr].ty;
2281 let mut resolved = base_ty_res.inner_with(&module.types);
2282 if let TypeInner::Pointer { base, space: _ } = *resolved {
2283 resolved = &module.types[base].inner;
2284 }
2285
2286 if let TypeInner::Struct { .. } = *resolved {
2287 let ty = base_ty_res.handle().unwrap();
2289 let struct_name = &self.names[&NameKey::Type(ty)];
2290 let variable_name = self.namer.call(&struct_name.to_lowercase());
2291 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292 self.write_expr(module, expr, func_ctx)?;
2293 writeln!(self.out, ";")?;
2294
2295 let ep_output = match func_ctx.ty {
2297 back::FunctionType::Function(_) => None,
2298 back::FunctionType::EntryPoint(index) => self
2299 .entry_point_io
2300 .get(&(index as usize))
2301 .unwrap()
2302 .output
2303 .as_ref(),
2304 };
2305 let final_name = match ep_output {
2306 Some(ep_output) => {
2307 let final_name = self.namer.call(&variable_name);
2308 write!(
2309 self.out,
2310 "{}const {} {} = {{ ",
2311 level, ep_output.ty_name, final_name,
2312 )?;
2313 for (index, m) in ep_output.members.iter().enumerate() {
2314 if index != 0 {
2315 write!(self.out, ", ")?;
2316 }
2317 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318 write!(self.out, "{variable_name}.{member_name}")?;
2319 }
2320 writeln!(self.out, " }};")?;
2321 final_name
2322 }
2323 None => variable_name,
2324 };
2325 writeln!(self.out, "{level}return {final_name};")?;
2326 } else {
2327 write!(self.out, "{level}return ")?;
2328 self.write_expr(module, expr, func_ctx)?;
2329 writeln!(self.out, ";")?
2330 }
2331 }
2332 Statement::Store { pointer, value } => {
2333 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334 if ty_inner.is_atomic_pointer(&module.types) {
2335 let pointer_space = ty_inner.pointer_space().unwrap();
2336 let dummy = self.namer.call("dummy");
2337 write!(self.out, "{level}{{ ")?;
2338 if let TypeInner::Pointer { base, .. } = *ty_inner {
2339 self.write_value_type(module, &module.types[base].inner)?;
2340 }
2341 write!(self.out, " {dummy} = 0; ")?;
2342 match pointer_space {
2343 crate::AddressSpace::WorkGroup => {
2344 write!(self.out, "InterlockedExchange(")?;
2345 self.write_expr(module, pointer, func_ctx)?;
2346 }
2347 crate::AddressSpace::Storage { .. } => {
2348 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350 write!(self.out, "{var_name}.InterlockedExchange(")?;
2351 let chain = mem::take(&mut self.temp_access_chain);
2352 self.write_storage_address(module, &chain, func_ctx)?;
2353 self.temp_access_chain = chain;
2354 }
2355 _ => unreachable!(),
2356 }
2357 write!(self.out, ", ")?;
2358 self.write_expr(module, value, func_ctx)?;
2359 writeln!(self.out, ", {dummy}); }}")?;
2360 } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362 self.write_storage_store(
2363 module,
2364 var_handle,
2365 StoreValue::Expression(value),
2366 func_ctx,
2367 level,
2368 None,
2369 )?;
2370 } else {
2371 enum MatrixAccess {
2377 Direct {
2378 base: Handle<crate::Expression>,
2379 index: u32,
2380 },
2381 Struct {
2382 columns: crate::VectorSize,
2383 base: Handle<crate::Expression>,
2384 },
2385 }
2386
2387 let get_members = |expr: Handle<crate::Expression>| {
2388 let resolved = func_ctx.resolve_type(expr, &module.types);
2389 match *resolved {
2390 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2391 TypeInner::Struct { ref members, .. } => Some(members),
2392 _ => None,
2393 },
2394 _ => None,
2395 }
2396 };
2397
2398 write!(self.out, "{level}")?;
2399
2400 let matrix_access_on_lhs =
2401 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2402 |(matrix_expr, vector, scalar)| match (
2403 func_ctx.resolve_type(matrix_expr, &module.types),
2404 &func_ctx.expressions[matrix_expr],
2405 ) {
2406 (
2407 &TypeInner::Pointer { base: ty, .. },
2408 &crate::Expression::AccessIndex { base, index },
2409 ) if matches!(
2410 module.types[ty].inner,
2411 TypeInner::Matrix {
2412 rows: crate::VectorSize::Bi,
2413 ..
2414 }
2415 ) && get_members(base)
2416 .map(|members| members[index as usize].binding.is_none())
2417 == Some(true) =>
2418 {
2419 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2420 }
2421 _ => {
2422 if let Some(MatrixType {
2423 columns,
2424 rows: crate::VectorSize::Bi,
2425 width: 4,
2426 }) = get_inner_matrix_of_struct_array_member(
2427 module,
2428 matrix_expr,
2429 func_ctx,
2430 true,
2431 ) {
2432 Some((
2433 MatrixAccess::Struct {
2434 columns,
2435 base: matrix_expr,
2436 },
2437 vector,
2438 scalar,
2439 ))
2440 } else {
2441 None
2442 }
2443 }
2444 },
2445 );
2446
2447 match matrix_access_on_lhs {
2448 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2449 let base_ty_res = &func_ctx.info[base].ty;
2450 let resolved = base_ty_res.inner_with(&module.types);
2451 let ty = match *resolved {
2452 TypeInner::Pointer { base, .. } => base,
2453 _ => base_ty_res.handle().unwrap(),
2454 };
2455
2456 if let Some(Index::Static(vec_index)) = vector {
2457 self.write_expr(module, base, func_ctx)?;
2458 write!(
2459 self.out,
2460 ".{}_{}",
2461 &self.names[&NameKey::StructMember(ty, index)],
2462 vec_index
2463 )?;
2464
2465 if let Some(scalar_index) = scalar {
2466 write!(self.out, "[")?;
2467 self.write_index(module, scalar_index, func_ctx)?;
2468 write!(self.out, "]")?;
2469 }
2470
2471 write!(self.out, " = ")?;
2472 self.write_expr(module, value, func_ctx)?;
2473 writeln!(self.out, ";")?;
2474 } else {
2475 let access = WrappedStructMatrixAccess { ty, index };
2476 match (&vector, &scalar) {
2477 (&Some(_), &Some(_)) => {
2478 self.write_wrapped_struct_matrix_set_scalar_function_name(
2479 access,
2480 )?;
2481 }
2482 (&Some(_), &None) => {
2483 self.write_wrapped_struct_matrix_set_vec_function_name(
2484 access,
2485 )?;
2486 }
2487 (&None, _) => {
2488 self.write_wrapped_struct_matrix_set_function_name(access)?;
2489 }
2490 }
2491
2492 write!(self.out, "(")?;
2493 self.write_expr(module, base, func_ctx)?;
2494 write!(self.out, ", ")?;
2495 self.write_expr(module, value, func_ctx)?;
2496
2497 if let Some(Index::Expression(vec_index)) = vector {
2498 write!(self.out, ", ")?;
2499 self.write_expr(module, vec_index, func_ctx)?;
2500
2501 if let Some(scalar_index) = scalar {
2502 write!(self.out, ", ")?;
2503 self.write_index(module, scalar_index, func_ctx)?;
2504 }
2505 }
2506 writeln!(self.out, ");")?;
2507 }
2508 }
2509 Some((
2510 MatrixAccess::Struct { columns, base },
2511 Some(Index::Expression(vec_index)),
2512 scalar,
2513 )) => {
2514 if scalar.is_some() {
2518 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2519 } else {
2520 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2521 }
2522 write!(self.out, "(")?;
2523 self.write_expr(module, base, func_ctx)?;
2524 write!(self.out, ", ")?;
2525 self.write_expr(module, vec_index, func_ctx)?;
2526
2527 if let Some(scalar_index) = scalar {
2528 write!(self.out, ", ")?;
2529 self.write_index(module, scalar_index, func_ctx)?;
2530 }
2531
2532 write!(self.out, ", ")?;
2533 self.write_expr(module, value, func_ctx)?;
2534
2535 writeln!(self.out, ");")?;
2536 }
2537 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2538 | Some((MatrixAccess::Struct { .. }, None, _))
2539 | None => {
2540 self.write_expr(module, pointer, func_ctx)?;
2541 write!(self.out, " = ")?;
2542
2543 if let Some(MatrixType {
2548 columns,
2549 rows: crate::VectorSize::Bi,
2550 width: 4,
2551 }) = get_inner_matrix_of_struct_array_member(
2552 module, pointer, func_ctx, false,
2553 ) {
2554 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2555 if let TypeInner::Pointer { base, .. } = *resolved {
2556 resolved = &module.types[base].inner;
2557 }
2558
2559 write!(self.out, "(__mat{}x2", columns as u8)?;
2560 if let TypeInner::Array { base, size, .. } = *resolved {
2561 self.write_array_size(module, base, size)?;
2562 }
2563 write!(self.out, ")")?;
2564 }
2565
2566 self.write_expr(module, value, func_ctx)?;
2567 writeln!(self.out, ";")?
2568 }
2569 }
2570 }
2571 }
2572 Statement::Loop {
2573 ref body,
2574 ref continuing,
2575 break_if,
2576 } => {
2577 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2578 let gate_name = (!continuing.is_empty() || break_if.is_some())
2579 .then(|| self.namer.call("loop_init"));
2580
2581 if let Some((ref decl, _)) = force_loop_bound_statements {
2582 writeln!(self.out, "{decl}")?;
2583 }
2584 if let Some(ref gate_name) = gate_name {
2585 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2586 }
2587
2588 self.continue_ctx.enter_loop();
2589 writeln!(self.out, "{level}while(true) {{")?;
2590 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2591 writeln!(self.out, "{break_and_inc}")?;
2592 }
2593 let l2 = level.next();
2594 if let Some(gate_name) = gate_name {
2595 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2596 let l3 = l2.next();
2597 for sta in continuing.iter() {
2598 self.write_stmt(module, sta, func_ctx, l3)?;
2599 }
2600 if let Some(condition) = break_if {
2601 write!(self.out, "{l3}if (")?;
2602 self.write_expr(module, condition, func_ctx)?;
2603 writeln!(self.out, ") {{")?;
2604 writeln!(self.out, "{}break;", l3.next())?;
2605 writeln!(self.out, "{l3}}}")?;
2606 }
2607 writeln!(self.out, "{l2}}}")?;
2608 writeln!(self.out, "{l2}{gate_name} = false;")?;
2609 }
2610
2611 for sta in body.iter() {
2612 self.write_stmt(module, sta, func_ctx, l2)?;
2613 }
2614
2615 writeln!(self.out, "{level}}}")?;
2616 self.continue_ctx.exit_loop();
2617 }
2618 Statement::Break => writeln!(self.out, "{level}break;")?,
2619 Statement::Continue => {
2620 if let Some(variable) = self.continue_ctx.continue_encountered() {
2621 writeln!(self.out, "{level}{variable} = true;")?;
2622 writeln!(self.out, "{level}break;")?
2623 } else {
2624 writeln!(self.out, "{level}continue;")?
2625 }
2626 }
2627 Statement::ControlBarrier(barrier) => {
2628 self.write_control_barrier(barrier, level)?;
2629 }
2630 Statement::MemoryBarrier(barrier) => {
2631 self.write_memory_barrier(barrier, level)?;
2632 }
2633 Statement::ImageStore {
2634 image,
2635 coordinate,
2636 array_index,
2637 value,
2638 } => {
2639 write!(self.out, "{level}")?;
2640 self.write_expr(module, image, func_ctx)?;
2641
2642 write!(self.out, "[")?;
2643 if let Some(index) = array_index {
2644 write!(self.out, "int3(")?;
2646 self.write_expr(module, coordinate, func_ctx)?;
2647 write!(self.out, ", ")?;
2648 self.write_expr(module, index, func_ctx)?;
2649 write!(self.out, ")")?;
2650 } else {
2651 self.write_expr(module, coordinate, func_ctx)?;
2652 }
2653 write!(self.out, "]")?;
2654
2655 write!(self.out, " = ")?;
2656 self.write_expr(module, value, func_ctx)?;
2657 writeln!(self.out, ";")?;
2658 }
2659 Statement::Call {
2660 function,
2661 ref arguments,
2662 result,
2663 } => {
2664 write!(self.out, "{level}")?;
2665
2666 if let Some(expr) = result {
2667 write!(self.out, "const ")?;
2668 let name = Baked(expr).to_string();
2669 let expr_ty = &func_ctx.info[expr].ty;
2670 let ty_inner = match *expr_ty {
2671 proc::TypeResolution::Handle(handle) => {
2672 self.write_type(module, handle)?;
2673 &module.types[handle].inner
2674 }
2675 proc::TypeResolution::Value(ref value) => {
2676 self.write_value_type(module, value)?;
2677 value
2678 }
2679 };
2680 write!(self.out, " {name}")?;
2681 if let TypeInner::Array { base, size, .. } = *ty_inner {
2682 self.write_array_size(module, base, size)?;
2683 }
2684 write!(self.out, " = ")?;
2685 self.named_expressions.insert(expr, name);
2686 }
2687 let func_name = &self.names[&NameKey::Function(function)];
2688 write!(self.out, "{func_name}(")?;
2689 let mut any_args_written = false;
2690 let mut separator = || {
2691 if any_args_written {
2692 ", "
2693 } else {
2694 any_args_written = true;
2695 ""
2696 }
2697 };
2698 for argument in arguments {
2699 write!(self.out, "{}", separator())?;
2700 self.write_expr(module, *argument, func_ctx)?;
2701 }
2702 if let Some(&var) = self.function_task_payload_var.get(&function) {
2703 let name = &self.names[&NameKey::GlobalVariable(var)];
2704 write!(self.out, "{}{name}", separator())?;
2706 }
2707 writeln!(self.out, ");")?;
2708 }
2709 Statement::Atomic {
2710 pointer,
2711 ref fun,
2712 value,
2713 result,
2714 } => {
2715 write!(self.out, "{level}")?;
2716 let res_var_info = if let Some(res_handle) = result {
2717 let name = Baked(res_handle).to_string();
2718 match func_ctx.info[res_handle].ty {
2719 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2720 proc::TypeResolution::Value(ref value) => {
2721 self.write_value_type(module, value)?
2722 }
2723 };
2724 write!(self.out, " {name}; ")?;
2725 self.named_expressions.insert(res_handle, name.clone());
2726 Some((res_handle, name))
2727 } else {
2728 None
2729 };
2730 let pointer_space = func_ctx
2731 .resolve_type(pointer, &module.types)
2732 .pointer_space()
2733 .unwrap();
2734 let fun_str = fun.to_hlsl_suffix();
2735 let compare_expr = match *fun {
2736 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2737 _ => None,
2738 };
2739 match pointer_space {
2740 crate::AddressSpace::WorkGroup => {
2741 write!(self.out, "Interlocked{fun_str}(")?;
2742 self.write_expr(module, pointer, func_ctx)?;
2743 self.emit_hlsl_atomic_tail(
2744 module,
2745 func_ctx,
2746 fun,
2747 compare_expr,
2748 value,
2749 &res_var_info,
2750 )?;
2751 }
2752 crate::AddressSpace::Storage { .. } => {
2753 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2754 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2755 let width = match func_ctx.resolve_type(value, &module.types) {
2756 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2757 _ => "",
2758 };
2759 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2760 let chain = mem::take(&mut self.temp_access_chain);
2761 self.write_storage_address(module, &chain, func_ctx)?;
2762 self.temp_access_chain = chain;
2763 self.emit_hlsl_atomic_tail(
2764 module,
2765 func_ctx,
2766 fun,
2767 compare_expr,
2768 value,
2769 &res_var_info,
2770 )?;
2771 }
2772 ref other => {
2773 return Err(Error::Custom(format!(
2774 "invalid address space {other:?} for atomic statement"
2775 )))
2776 }
2777 }
2778 if let Some(cmp) = compare_expr {
2779 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2780 write!(
2781 self.out,
2782 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2783 )?;
2784 self.write_expr(module, cmp, func_ctx)?;
2785 writeln!(self.out, ");")?;
2786 }
2787 }
2788 }
2789 Statement::ImageAtomic {
2790 image,
2791 coordinate,
2792 array_index,
2793 fun,
2794 value,
2795 } => {
2796 write!(self.out, "{level}")?;
2797
2798 let fun_str = fun.to_hlsl_suffix();
2799 write!(self.out, "Interlocked{fun_str}(")?;
2800 self.write_expr(module, image, func_ctx)?;
2801 write!(self.out, "[")?;
2802 self.write_texture_coordinates(
2803 "int",
2804 coordinate,
2805 array_index,
2806 None,
2807 module,
2808 func_ctx,
2809 )?;
2810 write!(self.out, "],")?;
2811
2812 self.write_expr(module, value, func_ctx)?;
2813 writeln!(self.out, ");")?;
2814 }
2815 Statement::WorkGroupUniformLoad { pointer, result } => {
2816 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2817 write!(self.out, "{level}")?;
2818 let name = Baked(result).to_string();
2819 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2820
2821 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2822 }
2823 Statement::Switch {
2824 selector,
2825 ref cases,
2826 } => {
2827 self.write_switch(module, func_ctx, level, selector, cases)?;
2828 }
2829 Statement::RayQuery { query, ref fun } => {
2830 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2842 else {
2843 unreachable!()
2844 };
2845
2846 let tracker_expr_name = format!(
2847 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2848 self.names[&func_ctx.name_key(query_var)]
2849 );
2850
2851 match *fun {
2852 RayQueryFunction::Initialize {
2853 acceleration_structure,
2854 descriptor,
2855 } => {
2856 self.write_initialize_function(
2857 module,
2858 level,
2859 query,
2860 acceleration_structure,
2861 descriptor,
2862 &tracker_expr_name,
2863 func_ctx,
2864 )?;
2865 }
2866 RayQueryFunction::Proceed { result } => {
2867 self.write_proceed(
2868 module,
2869 level,
2870 query,
2871 result,
2872 &tracker_expr_name,
2873 func_ctx,
2874 )?;
2875 }
2876 RayQueryFunction::GenerateIntersection { hit_t } => {
2877 self.write_generate_intersection(
2878 module,
2879 level,
2880 query,
2881 hit_t,
2882 &tracker_expr_name,
2883 func_ctx,
2884 )?;
2885 }
2886 RayQueryFunction::ConfirmIntersection => {
2887 self.write_confirm_intersection(
2888 module,
2889 level,
2890 query,
2891 &tracker_expr_name,
2892 func_ctx,
2893 )?;
2894 }
2895 RayQueryFunction::Terminate => {
2896 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2897 }
2898 }
2899 }
2900 Statement::SubgroupBallot { result, predicate } => {
2901 write!(self.out, "{level}")?;
2902 let name = Baked(result).to_string();
2903 write!(self.out, "const uint4 {name} = ")?;
2904 self.named_expressions.insert(result, name);
2905
2906 write!(self.out, "WaveActiveBallot(")?;
2907 match predicate {
2908 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2909 None => write!(self.out, "true")?,
2910 }
2911 writeln!(self.out, ");")?;
2912 }
2913 Statement::SubgroupCollectiveOperation {
2914 op,
2915 collective_op,
2916 argument,
2917 result,
2918 } => {
2919 write!(self.out, "{level}")?;
2920 write!(self.out, "const ")?;
2921 let name = Baked(result).to_string();
2922 match func_ctx.info[result].ty {
2923 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2924 proc::TypeResolution::Value(ref value) => {
2925 self.write_value_type(module, value)?
2926 }
2927 };
2928 write!(self.out, " {name} = ")?;
2929 self.named_expressions.insert(result, name);
2930
2931 match (collective_op, op) {
2932 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2933 write!(self.out, "WaveActiveAllTrue(")?
2934 }
2935 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2936 write!(self.out, "WaveActiveAnyTrue(")?
2937 }
2938 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2939 write!(self.out, "WaveActiveSum(")?
2940 }
2941 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2942 write!(self.out, "WaveActiveProduct(")?
2943 }
2944 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2945 write!(self.out, "WaveActiveMax(")?
2946 }
2947 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2948 write!(self.out, "WaveActiveMin(")?
2949 }
2950 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2951 write!(self.out, "WaveActiveBitAnd(")?
2952 }
2953 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2954 write!(self.out, "WaveActiveBitOr(")?
2955 }
2956 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2957 write!(self.out, "WaveActiveBitXor(")?
2958 }
2959 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2960 write!(self.out, "WavePrefixSum(")?
2961 }
2962 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2963 write!(self.out, "WavePrefixProduct(")?
2964 }
2965 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2966 self.write_expr(module, argument, func_ctx)?;
2967 write!(self.out, " + WavePrefixSum(")?;
2968 }
2969 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2970 self.write_expr(module, argument, func_ctx)?;
2971 write!(self.out, " * WavePrefixProduct(")?;
2972 }
2973 _ => unimplemented!(),
2974 }
2975 self.write_expr(module, argument, func_ctx)?;
2976 writeln!(self.out, ");")?;
2977 }
2978 Statement::SubgroupGather {
2979 mode,
2980 argument,
2981 result,
2982 } => {
2983 write!(self.out, "{level}")?;
2984 write!(self.out, "const ")?;
2985 let name = Baked(result).to_string();
2986 match func_ctx.info[result].ty {
2987 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2988 proc::TypeResolution::Value(ref value) => {
2989 self.write_value_type(module, value)?
2990 }
2991 };
2992 write!(self.out, " {name} = ")?;
2993 self.named_expressions.insert(result, name);
2994 match mode {
2995 crate::GatherMode::BroadcastFirst => {
2996 write!(self.out, "WaveReadLaneFirst(")?;
2997 self.write_expr(module, argument, func_ctx)?;
2998 }
2999 crate::GatherMode::QuadBroadcast(index) => {
3000 write!(self.out, "QuadReadLaneAt(")?;
3001 self.write_expr(module, argument, func_ctx)?;
3002 write!(self.out, ", ")?;
3003 self.write_expr(module, index, func_ctx)?;
3004 }
3005 crate::GatherMode::QuadSwap(direction) => {
3006 match direction {
3007 crate::Direction::X => {
3008 write!(self.out, "QuadReadAcrossX(")?;
3009 }
3010 crate::Direction::Y => {
3011 write!(self.out, "QuadReadAcrossY(")?;
3012 }
3013 crate::Direction::Diagonal => {
3014 write!(self.out, "QuadReadAcrossDiagonal(")?;
3015 }
3016 }
3017 self.write_expr(module, argument, func_ctx)?;
3018 }
3019 _ => {
3020 write!(self.out, "WaveReadLaneAt(")?;
3021 self.write_expr(module, argument, func_ctx)?;
3022 write!(self.out, ", ")?;
3023 match mode {
3024 crate::GatherMode::BroadcastFirst => unreachable!(),
3025 crate::GatherMode::Broadcast(index)
3026 | crate::GatherMode::Shuffle(index) => {
3027 self.write_expr(module, index, func_ctx)?;
3028 }
3029 crate::GatherMode::ShuffleDown(index) => {
3030 write!(self.out, "WaveGetLaneIndex() + ")?;
3031 self.write_expr(module, index, func_ctx)?;
3032 }
3033 crate::GatherMode::ShuffleUp(index) => {
3034 write!(self.out, "WaveGetLaneIndex() - ")?;
3035 self.write_expr(module, index, func_ctx)?;
3036 }
3037 crate::GatherMode::ShuffleXor(index) => {
3038 write!(self.out, "WaveGetLaneIndex() ^ ")?;
3039 self.write_expr(module, index, func_ctx)?;
3040 }
3041 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3042 crate::GatherMode::QuadSwap(_) => unreachable!(),
3043 }
3044 }
3045 }
3046 writeln!(self.out, ");")?;
3047 }
3048 Statement::CooperativeStore { .. } => unimplemented!(),
3049 Statement::RayPipelineFunction(_) => unreachable!(),
3050 }
3051
3052 Ok(())
3053 }
3054
3055 fn write_const_expression(
3056 &mut self,
3057 module: &Module,
3058 expr: Handle<crate::Expression>,
3059 arena: &crate::Arena<crate::Expression>,
3060 ) -> BackendResult {
3061 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3062 writer.write_const_expression(module, expr, arena)
3063 })
3064 }
3065
3066 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3067 match literal {
3068 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3069 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3070 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3071 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3072 crate::Literal::I32(value) if value == i32::MIN => {
3078 write!(self.out, "int({} - 1)", value + 1)?
3079 }
3080 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3084 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3085 crate::Literal::I64(value) if value == i64::MIN => {
3087 write!(self.out, "({}L - 1L)", value + 1)?;
3088 }
3089 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3090 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3091 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3092 return Err(Error::Custom(
3093 "Abstract types should not appear in IR presented to backends".into(),
3094 ));
3095 }
3096 }
3097 Ok(())
3098 }
3099
3100 fn write_possibly_const_expression<E>(
3101 &mut self,
3102 module: &Module,
3103 expr: Handle<crate::Expression>,
3104 expressions: &crate::Arena<crate::Expression>,
3105 write_expression: E,
3106 ) -> BackendResult
3107 where
3108 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3109 {
3110 use crate::Expression;
3111
3112 match expressions[expr] {
3113 Expression::Literal(literal) => {
3114 self.write_literal(literal)?;
3115 }
3116 Expression::Constant(handle) => {
3117 let constant = &module.constants[handle];
3118 if constant.name.is_some() {
3119 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3120 } else {
3121 self.write_const_expression(module, constant.init, &module.global_expressions)?;
3122 }
3123 }
3124 Expression::ZeroValue(ty) => {
3125 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3126 write!(self.out, "()")?;
3127 }
3128 Expression::Compose { ty, ref components } => {
3129 match module.types[ty].inner {
3130 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3131 self.write_wrapped_constructor_function_name(
3132 module,
3133 WrappedConstructor { ty },
3134 )?;
3135 }
3136 _ => {
3137 self.write_type(module, ty)?;
3138 }
3139 };
3140 write!(self.out, "(")?;
3141 for (index, component) in components.iter().enumerate() {
3142 if index != 0 {
3143 write!(self.out, ", ")?;
3144 }
3145 write_expression(self, *component)?;
3146 }
3147 write!(self.out, ")")?;
3148 }
3149 Expression::Splat { size, value } => {
3150 let number_of_components = match size {
3154 crate::VectorSize::Bi => "xx",
3155 crate::VectorSize::Tri => "xxx",
3156 crate::VectorSize::Quad => "xxxx",
3157 };
3158 write!(self.out, "(")?;
3159 write_expression(self, value)?;
3160 write!(self.out, ").{number_of_components}")?
3161 }
3162 _ => {
3163 return Err(Error::Override);
3164 }
3165 }
3166
3167 Ok(())
3168 }
3169
3170 pub(super) fn write_expr(
3175 &mut self,
3176 module: &Module,
3177 expr: Handle<crate::Expression>,
3178 func_ctx: &back::FunctionCtx<'_>,
3179 ) -> BackendResult {
3180 use crate::Expression;
3181
3182 let ff_input = if self.options.special_constants_binding.is_some() {
3184 func_ctx.is_fixed_function_input(expr, module)
3185 } else {
3186 None
3187 };
3188 let closing_bracket = match ff_input {
3189 Some(crate::BuiltIn::VertexIndex) => {
3190 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3191 ")"
3192 }
3193 Some(crate::BuiltIn::InstanceIndex) => {
3194 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3195 ")"
3196 }
3197 Some(crate::BuiltIn::NumWorkGroups) => {
3198 write!(
3202 self.out,
3203 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3204 )?;
3205 return Ok(());
3206 }
3207 _ => "",
3208 };
3209
3210 if let Some(name) = self.named_expressions.get(&expr) {
3211 write!(self.out, "{name}{closing_bracket}")?;
3212 return Ok(());
3213 }
3214
3215 let expression = &func_ctx.expressions[expr];
3216
3217 match *expression {
3218 Expression::Literal(_)
3219 | Expression::Constant(_)
3220 | Expression::ZeroValue(_)
3221 | Expression::Compose { .. }
3222 | Expression::Splat { .. } => {
3223 self.write_possibly_const_expression(
3224 module,
3225 expr,
3226 func_ctx.expressions,
3227 |writer, expr| writer.write_expr(module, expr, func_ctx),
3228 )?;
3229 }
3230 Expression::Override(_) => return Err(Error::Override),
3231 Expression::Binary {
3238 op:
3239 op @ crate::BinaryOperator::Add
3240 | op @ crate::BinaryOperator::Subtract
3241 | op @ crate::BinaryOperator::Multiply,
3242 left,
3243 right,
3244 } if matches!(
3245 func_ctx.resolve_type(expr, &module.types).scalar(),
3246 Some(Scalar::I32)
3247 ) =>
3248 {
3249 write!(self.out, "asint(asuint(",)?;
3250 self.write_expr(module, left, func_ctx)?;
3251 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3252 self.write_expr(module, right, func_ctx)?;
3253 write!(self.out, "))")?;
3254 }
3255 Expression::Binary {
3258 op: crate::BinaryOperator::Multiply,
3259 left,
3260 right,
3261 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3262 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3263 {
3264 write!(self.out, "mul(")?;
3266 self.write_expr(module, right, func_ctx)?;
3267 write!(self.out, ", ")?;
3268 self.write_expr(module, left, func_ctx)?;
3269 write!(self.out, ")")?;
3270 }
3271
3272 Expression::Binary {
3284 op: crate::BinaryOperator::Divide,
3285 left,
3286 right,
3287 } if matches!(
3288 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3289 Some(ScalarKind::Sint | ScalarKind::Uint)
3290 ) =>
3291 {
3292 write!(self.out, "{DIV_FUNCTION}(")?;
3293 self.write_expr(module, left, func_ctx)?;
3294 write!(self.out, ", ")?;
3295 self.write_expr(module, right, func_ctx)?;
3296 write!(self.out, ")")?;
3297 }
3298
3299 Expression::Binary {
3300 op: crate::BinaryOperator::Modulo,
3301 left,
3302 right,
3303 } if matches!(
3304 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3305 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3306 ) =>
3307 {
3308 write!(self.out, "{MOD_FUNCTION}(")?;
3309 self.write_expr(module, left, func_ctx)?;
3310 write!(self.out, ", ")?;
3311 self.write_expr(module, right, func_ctx)?;
3312 write!(self.out, ")")?;
3313 }
3314
3315 Expression::Binary { op, left, right } => {
3316 write!(self.out, "(")?;
3317 self.write_expr(module, left, func_ctx)?;
3318 write!(self.out, " {} ", back::binary_operation_str(op))?;
3319 self.write_expr(module, right, func_ctx)?;
3320 write!(self.out, ")")?;
3321 }
3322 Expression::Access { base, index } => {
3323 if let Some(crate::AddressSpace::Storage { .. }) =
3324 func_ctx.resolve_type(expr, &module.types).pointer_space()
3325 {
3326 } else {
3328 if let Some(MatrixType {
3335 columns,
3336 rows: crate::VectorSize::Bi,
3337 width: 4,
3338 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3339 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3340 {
3341 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3342 self.write_expr(module, base, func_ctx)?;
3343 write!(self.out, ", ")?;
3344 self.write_expr(module, index, func_ctx)?;
3345 write!(self.out, ")")?;
3346 return Ok(());
3347 }
3348
3349 let resolved = func_ctx.resolve_type(base, &module.types);
3350
3351 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3352 TypeInner::BindingArray { .. } => {
3353 let uniformity = &func_ctx.info[index].uniformity;
3354
3355 (true, uniformity.non_uniform_result.is_some())
3356 }
3357 _ => (false, false),
3358 };
3359
3360 self.write_expr(module, base, func_ctx)?;
3361
3362 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3363 module, func_ctx, base, resolved,
3364 );
3365
3366 if let Some(ref info) = array_sampler_info {
3367 write!(self.out, "{}[", info.sampler_heap_name)?;
3368 } else {
3369 write!(self.out, "[")?;
3370 }
3371
3372 let needs_bound_check = self.options.restrict_indexing
3373 && !indexing_binding_array
3374 && match resolved.pointer_space() {
3375 Some(
3376 crate::AddressSpace::Function
3377 | crate::AddressSpace::Private
3378 | crate::AddressSpace::WorkGroup
3379 | crate::AddressSpace::Immediate
3380 | crate::AddressSpace::TaskPayload
3381 | crate::AddressSpace::RayPayload
3382 | crate::AddressSpace::IncomingRayPayload,
3383 )
3384 | None => true,
3385 Some(crate::AddressSpace::Uniform) => {
3386 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3388 let bind_target = self
3389 .options
3390 .resolve_resource_binding(
3391 module.global_variables[var_handle]
3392 .binding
3393 .as_ref()
3394 .unwrap(),
3395 )
3396 .unwrap();
3397 bind_target.restrict_indexing
3398 }
3399 Some(
3400 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3401 ) => unreachable!(),
3402 };
3403 let restriction_needed = if needs_bound_check {
3405 index::access_needs_check(
3406 base,
3407 index::GuardedIndex::Expression(index),
3408 module,
3409 func_ctx.expressions,
3410 func_ctx.info,
3411 )
3412 } else {
3413 None
3414 };
3415 if let Some(limit) = restriction_needed {
3416 write!(self.out, "min(uint(")?;
3417 self.write_expr(module, index, func_ctx)?;
3418 write!(self.out, "), ")?;
3419 match limit {
3420 index::IndexableLength::Known(limit) => {
3421 write!(self.out, "{}u", limit - 1)?;
3422 }
3423 index::IndexableLength::Dynamic => unreachable!(),
3424 }
3425 write!(self.out, ")")?;
3426 } else {
3427 if non_uniform_qualifier {
3428 write!(self.out, "NonUniformResourceIndex(")?;
3429 }
3430 if let Some(ref info) = array_sampler_info {
3431 write!(
3432 self.out,
3433 "{}[{} + ",
3434 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3435 )?;
3436 }
3437 self.write_expr(module, index, func_ctx)?;
3438 if array_sampler_info.is_some() {
3439 write!(self.out, "]")?;
3440 }
3441 if non_uniform_qualifier {
3442 write!(self.out, ")")?;
3443 }
3444 }
3445
3446 write!(self.out, "]")?;
3447 }
3448 }
3449 Expression::AccessIndex { base, index } => {
3450 if let Some(crate::AddressSpace::Storage { .. }) =
3451 func_ctx.resolve_type(expr, &module.types).pointer_space()
3452 {
3453 } else {
3455 if let Some(MatrixType {
3459 rows: crate::VectorSize::Bi,
3460 width: 4,
3461 ..
3462 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3463 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3464 {
3465 self.write_expr(module, base, func_ctx)?;
3466 write!(self.out, "._{index}")?;
3467 return Ok(());
3468 }
3469
3470 let base_ty_res = &func_ctx.info[base].ty;
3471 let mut resolved = base_ty_res.inner_with(&module.types);
3472 let base_ty_handle = match *resolved {
3473 TypeInner::Pointer { base, .. } => {
3474 resolved = &module.types[base].inner;
3475 Some(base)
3476 }
3477 _ => base_ty_res.handle(),
3478 };
3479
3480 if let TypeInner::Struct { ref members, .. } = *resolved {
3486 let member = &members[index as usize];
3487
3488 match module.types[member.ty].inner {
3489 TypeInner::Matrix {
3490 rows: crate::VectorSize::Bi,
3491 ..
3492 } if member.binding.is_none() => {
3493 let ty = base_ty_handle.unwrap();
3494 self.write_wrapped_struct_matrix_get_function_name(
3495 WrappedStructMatrixAccess { ty, index },
3496 )?;
3497 write!(self.out, "(")?;
3498 self.write_expr(module, base, func_ctx)?;
3499 write!(self.out, ")")?;
3500 return Ok(());
3501 }
3502 _ => {}
3503 }
3504 }
3505
3506 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3507 module, func_ctx, base, resolved,
3508 );
3509
3510 if let Some(ref info) = array_sampler_info {
3511 write!(
3512 self.out,
3513 "{}[{}",
3514 info.sampler_heap_name, info.sampler_index_buffer_name
3515 )?;
3516 }
3517
3518 self.write_expr(module, base, func_ctx)?;
3519
3520 match *resolved {
3521 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3527 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3529 }
3530 TypeInner::Matrix { .. }
3531 | TypeInner::Array { .. }
3532 | TypeInner::BindingArray { .. } => {
3533 if let Some(ref info) = array_sampler_info {
3534 write!(
3535 self.out,
3536 "[{} + {index}]",
3537 info.binding_array_base_index_name
3538 )?;
3539 } else {
3540 write!(self.out, "[{index}]")?;
3541 }
3542 }
3543 TypeInner::Struct { .. } => {
3544 let ty = base_ty_handle.unwrap();
3547
3548 write!(
3549 self.out,
3550 ".{}",
3551 &self.names[&NameKey::StructMember(ty, index)]
3552 )?
3553 }
3554 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3555 }
3556
3557 if array_sampler_info.is_some() {
3558 write!(self.out, "]")?;
3559 }
3560 }
3561 }
3562 Expression::FunctionArgument(pos) => {
3563 let ty = func_ctx.resolve_type(expr, &module.types);
3564
3565 if let TypeInner::Image {
3571 class: crate::ImageClass::External,
3572 ..
3573 } = *ty
3574 {
3575 let plane_names = [0, 1, 2].map(|i| {
3576 &self.names[&func_ctx
3577 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3578 });
3579 let params_name = &self.names[&func_ctx
3580 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3581 write!(
3582 self.out,
3583 "{}, {}, {}, {}",
3584 plane_names[0], plane_names[1], plane_names[2], params_name
3585 )?;
3586 } else {
3587 let key = func_ctx.argument_key(pos);
3588 let name = &self.names[&key];
3589 write!(self.out, "{name}")?;
3590 }
3591 }
3592 Expression::ImageSample {
3593 coordinate,
3594 image,
3595 sampler,
3596 clamp_to_edge: true,
3597 gather: None,
3598 array_index: None,
3599 offset: None,
3600 level: crate::SampleLevel::Zero,
3601 depth_ref: None,
3602 } => {
3603 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3604 self.write_expr(module, image, func_ctx)?;
3605 write!(self.out, ", ")?;
3606 self.write_expr(module, sampler, func_ctx)?;
3607 write!(self.out, ", ")?;
3608 self.write_expr(module, coordinate, func_ctx)?;
3609 write!(self.out, ")")?;
3610 }
3611 Expression::ImageSample {
3612 image,
3613 sampler,
3614 gather,
3615 coordinate,
3616 array_index,
3617 offset,
3618 level,
3619 depth_ref,
3620 clamp_to_edge,
3621 } => {
3622 if clamp_to_edge {
3623 return Err(Error::Custom(
3624 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3625 ));
3626 }
3627
3628 use crate::SampleLevel as Sl;
3629 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3630
3631 let (base_str, component_str) = match gather {
3632 Some(component) => ("Gather", COMPONENTS[component as usize]),
3633 None => ("Sample", ""),
3634 };
3635 let cmp_str = match depth_ref {
3636 Some(_) => "Cmp",
3637 None => "",
3638 };
3639 let level_str = match level {
3640 Sl::Zero if gather.is_none() => "LevelZero",
3641 Sl::Auto | Sl::Zero => "",
3642 Sl::Exact(_) => "Level",
3643 Sl::Bias(_) => "Bias",
3644 Sl::Gradient { .. } => "Grad",
3645 };
3646
3647 self.write_expr(module, image, func_ctx)?;
3648 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3649 self.write_expr(module, sampler, func_ctx)?;
3650 write!(self.out, ", ")?;
3651 self.write_texture_coordinates(
3652 "float",
3653 coordinate,
3654 array_index,
3655 None,
3656 module,
3657 func_ctx,
3658 )?;
3659
3660 if let Some(depth_ref) = depth_ref {
3661 write!(self.out, ", ")?;
3662 self.write_expr(module, depth_ref, func_ctx)?;
3663 }
3664
3665 match level {
3666 Sl::Auto | Sl::Zero => {}
3667 Sl::Exact(expr) => {
3668 write!(self.out, ", ")?;
3669 self.write_expr(module, expr, func_ctx)?;
3670 }
3671 Sl::Bias(expr) => {
3672 write!(self.out, ", ")?;
3673 self.write_expr(module, expr, func_ctx)?;
3674 }
3675 Sl::Gradient { x, y } => {
3676 write!(self.out, ", ")?;
3677 self.write_expr(module, x, func_ctx)?;
3678 write!(self.out, ", ")?;
3679 self.write_expr(module, y, func_ctx)?;
3680 }
3681 }
3682
3683 if let Some(offset) = offset {
3684 write!(self.out, ", ")?;
3685 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3687 write!(self.out, ")")?;
3688 }
3689
3690 write!(self.out, ")")?;
3691 }
3692 Expression::ImageQuery { image, query } => {
3693 if let TypeInner::Image {
3695 dim,
3696 arrayed,
3697 class,
3698 } = *func_ctx.resolve_type(image, &module.types)
3699 {
3700 let wrapped_image_query = WrappedImageQuery {
3701 dim,
3702 arrayed,
3703 class,
3704 query: query.into(),
3705 };
3706
3707 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3708 write!(self.out, "(")?;
3709 self.write_expr(module, image, func_ctx)?;
3711 if let crate::ImageQuery::Size { level: Some(level) } = query {
3712 write!(self.out, ", ")?;
3713 self.write_expr(module, level, func_ctx)?;
3714 }
3715 write!(self.out, ")")?;
3716 }
3717 }
3718 Expression::ImageLoad {
3719 image,
3720 coordinate,
3721 array_index,
3722 sample,
3723 level,
3724 } => self.write_image_load(
3725 &module,
3726 expr,
3727 func_ctx,
3728 image,
3729 coordinate,
3730 array_index,
3731 sample,
3732 level,
3733 )?,
3734 Expression::GlobalVariable(handle) => {
3735 let global_variable = &module.global_variables[handle];
3736 let ty = &module.types[global_variable.ty].inner;
3737
3738 let is_binding_array_of_samplers = match *ty {
3743 TypeInner::BindingArray { base, .. } => {
3744 let base_ty = &module.types[base].inner;
3745 matches!(*base_ty, TypeInner::Sampler { .. })
3746 }
3747 _ => false,
3748 };
3749
3750 let is_storage_space =
3751 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3752
3753 if let TypeInner::Image {
3761 class: crate::ImageClass::External,
3762 ..
3763 } = *ty
3764 {
3765 let plane_names = [0, 1, 2].map(|i| {
3766 &self.names[&NameKey::ExternalTextureGlobalVariable(
3767 handle,
3768 ExternalTextureNameKey::Plane(i),
3769 )]
3770 });
3771 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3772 handle,
3773 ExternalTextureNameKey::Params,
3774 )];
3775 write!(
3776 self.out,
3777 "{}, {}, {}, {}",
3778 plane_names[0], plane_names[1], plane_names[2], params_name
3779 )?;
3780 } else if !is_binding_array_of_samplers && !is_storage_space {
3781 let name = &self.names[&NameKey::GlobalVariable(handle)];
3782 write!(self.out, "{name}")?;
3783 }
3784 }
3785 Expression::LocalVariable(handle) => {
3786 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3787 }
3788 Expression::Load { pointer } => {
3789 match func_ctx
3790 .resolve_type(pointer, &module.types)
3791 .pointer_space()
3792 {
3793 Some(crate::AddressSpace::Storage { .. }) => {
3794 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3795 let result_ty = func_ctx.info[expr].ty.clone();
3796 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3797 }
3798 _ => {
3799 let mut close_paren = false;
3800
3801 if let Some(MatrixType {
3806 rows: crate::VectorSize::Bi,
3807 width: 4,
3808 ..
3809 }) = get_inner_matrix_of_struct_array_member(
3810 module, pointer, func_ctx, false,
3811 )
3812 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3813 {
3814 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3815 let ptr_tr = resolved.pointer_base_type();
3816 if let Some(ptr_ty) =
3817 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3818 {
3819 resolved = ptr_ty;
3820 }
3821
3822 write!(self.out, "((")?;
3823 if let TypeInner::Array { base, size, .. } = *resolved {
3824 self.write_type(module, base)?;
3825 self.write_array_size(module, base, size)?;
3826 } else {
3827 self.write_value_type(module, resolved)?;
3828 }
3829 write!(self.out, ")")?;
3830 close_paren = true;
3831 }
3832
3833 self.write_expr(module, pointer, func_ctx)?;
3834
3835 if close_paren {
3836 write!(self.out, ")")?;
3837 }
3838 }
3839 }
3840 }
3841 Expression::Unary { op, expr } => {
3842 let op_str = match op {
3844 crate::UnaryOperator::Negate => {
3845 match func_ctx.resolve_type(expr, &module.types).scalar() {
3846 Some(Scalar::I32) => NEG_FUNCTION,
3847 _ => "-",
3848 }
3849 }
3850 crate::UnaryOperator::LogicalNot => "!",
3851 crate::UnaryOperator::BitwiseNot => "~",
3852 };
3853 write!(self.out, "{op_str}(")?;
3854 self.write_expr(module, expr, func_ctx)?;
3855 write!(self.out, ")")?;
3856 }
3857 Expression::As {
3858 expr,
3859 kind,
3860 convert,
3861 } => {
3862 let inner = func_ctx.resolve_type(expr, &module.types);
3863 if inner.scalar_kind() == Some(ScalarKind::Float)
3864 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3865 && convert.is_some()
3866 {
3867 let fun_name = match (kind, convert) {
3871 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3872 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3873 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3874 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3875 _ => unreachable!(),
3876 };
3877 write!(self.out, "{fun_name}(")?;
3878 self.write_expr(module, expr, func_ctx)?;
3879 write!(self.out, ")")?;
3880 } else {
3881 let close_paren = match convert {
3882 Some(dst_width) => {
3883 let scalar = Scalar {
3884 kind,
3885 width: dst_width,
3886 };
3887 match *inner {
3888 TypeInner::Vector { size, .. } => {
3889 write!(
3890 self.out,
3891 "{}{}(",
3892 scalar.to_hlsl_str()?,
3893 common::vector_size_str(size)
3894 )?;
3895 }
3896 TypeInner::Scalar(_) => {
3897 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3898 }
3899 TypeInner::Matrix { columns, rows, .. } => {
3900 write!(
3901 self.out,
3902 "{}{}x{}(",
3903 scalar.to_hlsl_str()?,
3904 common::vector_size_str(columns),
3905 common::vector_size_str(rows)
3906 )?;
3907 }
3908 _ => {
3909 return Err(Error::Unimplemented(format!(
3910 "write_expr expression::as {inner:?}"
3911 )));
3912 }
3913 };
3914 true
3915 }
3916 None => {
3917 if inner.scalar_width() == Some(8) {
3918 false
3919 } else {
3920 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3921 true
3922 }
3923 }
3924 };
3925 self.write_expr(module, expr, func_ctx)?;
3926 if close_paren {
3927 write!(self.out, ")")?;
3928 }
3929 }
3930 }
3931 Expression::Math {
3932 fun,
3933 arg,
3934 arg1,
3935 arg2,
3936 arg3,
3937 } => {
3938 use crate::MathFunction as Mf;
3939
3940 enum Function {
3941 Asincosh { is_sin: bool },
3942 Atanh,
3943 Pack2x16float,
3944 Pack2x16snorm,
3945 Pack2x16unorm,
3946 Pack4x8snorm,
3947 Pack4x8unorm,
3948 Pack4xI8,
3949 Pack4xU8,
3950 Pack4xI8Clamp,
3951 Pack4xU8Clamp,
3952 Unpack2x16float,
3953 Unpack2x16snorm,
3954 Unpack2x16unorm,
3955 Unpack4x8snorm,
3956 Unpack4x8unorm,
3957 Unpack4xI8,
3958 Unpack4xU8,
3959 Dot4I8Packed,
3960 Dot4U8Packed,
3961 QuantizeToF16,
3962 Regular(&'static str),
3963 MissingIntOverload(&'static str),
3964 MissingIntReturnType(&'static str),
3965 CountTrailingZeros,
3966 CountLeadingZeros,
3967 }
3968
3969 let fun = match fun {
3970 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3972 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3973 _ => Function::Regular("abs"),
3974 },
3975 Mf::Min => Function::Regular("min"),
3976 Mf::Max => Function::Regular("max"),
3977 Mf::Clamp => Function::Regular("clamp"),
3978 Mf::Saturate => Function::Regular("saturate"),
3979 Mf::Cos => Function::Regular("cos"),
3981 Mf::Cosh => Function::Regular("cosh"),
3982 Mf::Sin => Function::Regular("sin"),
3983 Mf::Sinh => Function::Regular("sinh"),
3984 Mf::Tan => Function::Regular("tan"),
3985 Mf::Tanh => Function::Regular("tanh"),
3986 Mf::Acos => Function::Regular("acos"),
3987 Mf::Asin => Function::Regular("asin"),
3988 Mf::Atan => Function::Regular("atan"),
3989 Mf::Atan2 => Function::Regular("atan2"),
3990 Mf::Asinh => Function::Asincosh { is_sin: true },
3991 Mf::Acosh => Function::Asincosh { is_sin: false },
3992 Mf::Atanh => Function::Atanh,
3993 Mf::Radians => Function::Regular("radians"),
3994 Mf::Degrees => Function::Regular("degrees"),
3995 Mf::Ceil => Function::Regular("ceil"),
3997 Mf::Floor => Function::Regular("floor"),
3998 Mf::Round => Function::Regular("round"),
3999 Mf::Fract => Function::Regular("frac"),
4000 Mf::Trunc => Function::Regular("trunc"),
4001 Mf::Modf => Function::Regular(MODF_FUNCTION),
4002 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4003 Mf::Ldexp => Function::Regular("ldexp"),
4004 Mf::Exp => Function::Regular("exp"),
4006 Mf::Exp2 => Function::Regular("exp2"),
4007 Mf::Log => Function::Regular("log"),
4008 Mf::Log2 => Function::Regular("log2"),
4009 Mf::Pow => Function::Regular("pow"),
4010 Mf::Dot => Function::Regular("dot"),
4012 Mf::Dot4I8Packed => Function::Dot4I8Packed,
4013 Mf::Dot4U8Packed => Function::Dot4U8Packed,
4014 Mf::Cross => Function::Regular("cross"),
4016 Mf::Distance => Function::Regular("distance"),
4017 Mf::Length => Function::Regular("length"),
4018 Mf::Normalize => Function::Regular("normalize"),
4019 Mf::FaceForward => Function::Regular("faceforward"),
4020 Mf::Reflect => Function::Regular("reflect"),
4021 Mf::Refract => Function::Regular("refract"),
4022 Mf::Sign => Function::Regular("sign"),
4024 Mf::Fma => Function::Regular("mad"),
4025 Mf::Mix => Function::Regular("lerp"),
4026 Mf::Step => Function::Regular("step"),
4027 Mf::SmoothStep => Function::Regular("smoothstep"),
4028 Mf::Sqrt => Function::Regular("sqrt"),
4029 Mf::InverseSqrt => Function::Regular("rsqrt"),
4030 Mf::Transpose => Function::Regular("transpose"),
4032 Mf::Determinant => Function::Regular("determinant"),
4033 Mf::QuantizeToF16 => Function::QuantizeToF16,
4034 Mf::CountTrailingZeros => Function::CountTrailingZeros,
4036 Mf::CountLeadingZeros => Function::CountLeadingZeros,
4037 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4038 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4039 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4040 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4041 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4042 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4043 Mf::Pack2x16float => Function::Pack2x16float,
4045 Mf::Pack2x16snorm => Function::Pack2x16snorm,
4046 Mf::Pack2x16unorm => Function::Pack2x16unorm,
4047 Mf::Pack4x8snorm => Function::Pack4x8snorm,
4048 Mf::Pack4x8unorm => Function::Pack4x8unorm,
4049 Mf::Pack4xI8 => Function::Pack4xI8,
4050 Mf::Pack4xU8 => Function::Pack4xU8,
4051 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4052 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4053 Mf::Unpack2x16float => Function::Unpack2x16float,
4055 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4056 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4057 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4058 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4059 Mf::Unpack4xI8 => Function::Unpack4xI8,
4060 Mf::Unpack4xU8 => Function::Unpack4xU8,
4061 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4062 };
4063
4064 match fun {
4065 Function::Asincosh { is_sin } => {
4066 write!(self.out, "log(")?;
4067 self.write_expr(module, arg, func_ctx)?;
4068 write!(self.out, " + sqrt(")?;
4069 self.write_expr(module, arg, func_ctx)?;
4070 write!(self.out, " * ")?;
4071 self.write_expr(module, arg, func_ctx)?;
4072 match is_sin {
4073 true => write!(self.out, " + 1.0))")?,
4074 false => write!(self.out, " - 1.0))")?,
4075 }
4076 }
4077 Function::Atanh => {
4078 write!(self.out, "0.5 * log((1.0 + ")?;
4079 self.write_expr(module, arg, func_ctx)?;
4080 write!(self.out, ") / (1.0 - ")?;
4081 self.write_expr(module, arg, func_ctx)?;
4082 write!(self.out, "))")?;
4083 }
4084 Function::Pack2x16float => {
4085 write!(self.out, "(f32tof16(")?;
4086 self.write_expr(module, arg, func_ctx)?;
4087 write!(self.out, "[0]) | f32tof16(")?;
4088 self.write_expr(module, arg, func_ctx)?;
4089 write!(self.out, "[1]) << 16)")?;
4090 }
4091 Function::Pack2x16snorm => {
4092 let scale = 32767;
4093
4094 write!(self.out, "uint((int(round(clamp(")?;
4095 self.write_expr(module, arg, func_ctx)?;
4096 write!(
4097 self.out,
4098 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4099 )?;
4100 self.write_expr(module, arg, func_ctx)?;
4101 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4102 }
4103 Function::Pack2x16unorm => {
4104 let scale = 65535;
4105
4106 write!(self.out, "(uint(round(clamp(")?;
4107 self.write_expr(module, arg, func_ctx)?;
4108 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4109 self.write_expr(module, arg, func_ctx)?;
4110 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4111 }
4112 Function::Pack4x8snorm => {
4113 let scale = 127;
4114
4115 write!(self.out, "uint((int(round(clamp(")?;
4116 self.write_expr(module, arg, func_ctx)?;
4117 write!(
4118 self.out,
4119 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4120 )?;
4121 self.write_expr(module, arg, func_ctx)?;
4122 write!(
4123 self.out,
4124 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4125 )?;
4126 self.write_expr(module, arg, func_ctx)?;
4127 write!(
4128 self.out,
4129 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4130 )?;
4131 self.write_expr(module, arg, func_ctx)?;
4132 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4133 }
4134 Function::Pack4x8unorm => {
4135 let scale = 255;
4136
4137 write!(self.out, "(uint(round(clamp(")?;
4138 self.write_expr(module, arg, func_ctx)?;
4139 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4140 self.write_expr(module, arg, func_ctx)?;
4141 write!(
4142 self.out,
4143 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4144 )?;
4145 self.write_expr(module, arg, func_ctx)?;
4146 write!(
4147 self.out,
4148 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4149 )?;
4150 self.write_expr(module, arg, func_ctx)?;
4151 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4152 }
4153 fun @ (Function::Pack4xI8
4154 | Function::Pack4xU8
4155 | Function::Pack4xI8Clamp
4156 | Function::Pack4xU8Clamp) => {
4157 let was_signed =
4158 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4159 let clamp_bounds = match fun {
4160 Function::Pack4xI8Clamp => Some(("-128", "127")),
4161 Function::Pack4xU8Clamp => Some(("0", "255")),
4162 _ => None,
4163 };
4164 if was_signed {
4165 write!(self.out, "uint(")?;
4166 }
4167 let write_arg = |this: &mut Self| -> BackendResult {
4168 if let Some((min, max)) = clamp_bounds {
4169 write!(this.out, "clamp(")?;
4170 this.write_expr(module, arg, func_ctx)?;
4171 write!(this.out, ", {min}, {max})")?;
4172 } else {
4173 this.write_expr(module, arg, func_ctx)?;
4174 }
4175 Ok(())
4176 };
4177 write!(self.out, "(")?;
4178 write_arg(self)?;
4179 write!(self.out, "[0] & 0xFF) | ((")?;
4180 write_arg(self)?;
4181 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4182 write_arg(self)?;
4183 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4184 write_arg(self)?;
4185 write!(self.out, "[3] & 0xFF) << 24)")?;
4186 if was_signed {
4187 write!(self.out, ")")?;
4188 }
4189 }
4190
4191 Function::Unpack2x16float => {
4192 write!(self.out, "float2(f16tof32(")?;
4193 self.write_expr(module, arg, func_ctx)?;
4194 write!(self.out, "), f16tof32((")?;
4195 self.write_expr(module, arg, func_ctx)?;
4196 write!(self.out, ") >> 16))")?;
4197 }
4198 Function::Unpack2x16snorm => {
4199 let scale = 32767;
4200
4201 write!(self.out, "(float2(int2(")?;
4202 self.write_expr(module, arg, func_ctx)?;
4203 write!(self.out, " << 16, ")?;
4204 self.write_expr(module, arg, func_ctx)?;
4205 write!(self.out, ") >> 16) / {scale}.0)")?;
4206 }
4207 Function::Unpack2x16unorm => {
4208 let scale = 65535;
4209
4210 write!(self.out, "(float2(")?;
4211 self.write_expr(module, arg, func_ctx)?;
4212 write!(self.out, " & 0xFFFF, ")?;
4213 self.write_expr(module, arg, func_ctx)?;
4214 write!(self.out, " >> 16) / {scale}.0)")?;
4215 }
4216 Function::Unpack4x8snorm => {
4217 let scale = 127;
4218
4219 write!(self.out, "(float4(int4(")?;
4220 self.write_expr(module, arg, func_ctx)?;
4221 write!(self.out, " << 24, ")?;
4222 self.write_expr(module, arg, func_ctx)?;
4223 write!(self.out, " << 16, ")?;
4224 self.write_expr(module, arg, func_ctx)?;
4225 write!(self.out, " << 8, ")?;
4226 self.write_expr(module, arg, func_ctx)?;
4227 write!(self.out, ") >> 24) / {scale}.0)")?;
4228 }
4229 Function::Unpack4x8unorm => {
4230 let scale = 255;
4231
4232 write!(self.out, "(float4(")?;
4233 self.write_expr(module, arg, func_ctx)?;
4234 write!(self.out, " & 0xFF, ")?;
4235 self.write_expr(module, arg, func_ctx)?;
4236 write!(self.out, " >> 8 & 0xFF, ")?;
4237 self.write_expr(module, arg, func_ctx)?;
4238 write!(self.out, " >> 16 & 0xFF, ")?;
4239 self.write_expr(module, arg, func_ctx)?;
4240 write!(self.out, " >> 24) / {scale}.0)")?;
4241 }
4242 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4243 write!(self.out, "(")?;
4244 if matches!(fun, Function::Unpack4xU8) {
4245 write!(self.out, "u")?;
4246 }
4247 write!(self.out, "int4(")?;
4248 self.write_expr(module, arg, func_ctx)?;
4249 write!(self.out, ", ")?;
4250 self.write_expr(module, arg, func_ctx)?;
4251 write!(self.out, " >> 8, ")?;
4252 self.write_expr(module, arg, func_ctx)?;
4253 write!(self.out, " >> 16, ")?;
4254 self.write_expr(module, arg, func_ctx)?;
4255 write!(self.out, " >> 24) << 24 >> 24)")?;
4256 }
4257 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4258 let arg1 = arg1.unwrap();
4259
4260 if self.options.shader_model >= ShaderModel::V6_4 {
4261 let function_name = match fun {
4263 Function::Dot4I8Packed => "dot4add_i8packed",
4264 Function::Dot4U8Packed => "dot4add_u8packed",
4265 _ => unreachable!(),
4266 };
4267 write!(self.out, "{function_name}(")?;
4268 self.write_expr(module, arg, func_ctx)?;
4269 write!(self.out, ", ")?;
4270 self.write_expr(module, arg1, func_ctx)?;
4271 write!(self.out, ", 0)")?;
4272 } else {
4273 write!(self.out, "dot(")?;
4275
4276 if matches!(fun, Function::Dot4U8Packed) {
4277 write!(self.out, "u")?;
4278 }
4279 write!(self.out, "int4(")?;
4280 self.write_expr(module, arg, func_ctx)?;
4281 write!(self.out, ", ")?;
4282 self.write_expr(module, arg, func_ctx)?;
4283 write!(self.out, " >> 8, ")?;
4284 self.write_expr(module, arg, func_ctx)?;
4285 write!(self.out, " >> 16, ")?;
4286 self.write_expr(module, arg, func_ctx)?;
4287 write!(self.out, " >> 24) << 24 >> 24, ")?;
4288
4289 if matches!(fun, Function::Dot4U8Packed) {
4290 write!(self.out, "u")?;
4291 }
4292 write!(self.out, "int4(")?;
4293 self.write_expr(module, arg1, func_ctx)?;
4294 write!(self.out, ", ")?;
4295 self.write_expr(module, arg1, func_ctx)?;
4296 write!(self.out, " >> 8, ")?;
4297 self.write_expr(module, arg1, func_ctx)?;
4298 write!(self.out, " >> 16, ")?;
4299 self.write_expr(module, arg1, func_ctx)?;
4300 write!(self.out, " >> 24) << 24 >> 24)")?;
4301 }
4302 }
4303 Function::QuantizeToF16 => {
4304 write!(self.out, "f16tof32(f32tof16(")?;
4305 self.write_expr(module, arg, func_ctx)?;
4306 write!(self.out, "))")?;
4307 }
4308 Function::Regular(fun_name) => {
4309 write!(self.out, "{fun_name}(")?;
4310 self.write_expr(module, arg, func_ctx)?;
4311 if let Some(arg) = arg1 {
4312 write!(self.out, ", ")?;
4313 self.write_expr(module, arg, func_ctx)?;
4314 }
4315 if let Some(arg) = arg2 {
4316 write!(self.out, ", ")?;
4317 self.write_expr(module, arg, func_ctx)?;
4318 }
4319 if let Some(arg) = arg3 {
4320 write!(self.out, ", ")?;
4321 self.write_expr(module, arg, func_ctx)?;
4322 }
4323 write!(self.out, ")")?
4324 }
4325 Function::MissingIntOverload(fun_name) => {
4328 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4329 if let Some(Scalar::I32) = scalar_kind {
4330 write!(self.out, "asint({fun_name}(asuint(")?;
4331 self.write_expr(module, arg, func_ctx)?;
4332 write!(self.out, ")))")?;
4333 } else {
4334 write!(self.out, "{fun_name}(")?;
4335 self.write_expr(module, arg, func_ctx)?;
4336 write!(self.out, ")")?;
4337 }
4338 }
4339 Function::MissingIntReturnType(fun_name) => {
4342 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4343 if let Some(Scalar::I32) = scalar_kind {
4344 write!(self.out, "asint({fun_name}(")?;
4345 self.write_expr(module, arg, func_ctx)?;
4346 write!(self.out, "))")?;
4347 } else {
4348 write!(self.out, "{fun_name}(")?;
4349 self.write_expr(module, arg, func_ctx)?;
4350 write!(self.out, ")")?;
4351 }
4352 }
4353 Function::CountTrailingZeros => {
4354 match *func_ctx.resolve_type(arg, &module.types) {
4355 TypeInner::Vector { size, scalar } => {
4356 let s = match size {
4357 crate::VectorSize::Bi => ".xx",
4358 crate::VectorSize::Tri => ".xxx",
4359 crate::VectorSize::Quad => ".xxxx",
4360 };
4361
4362 let scalar_width_bits = scalar.width * 8;
4363
4364 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4365 write!(
4366 self.out,
4367 "min(({scalar_width_bits}u){s}, firstbitlow("
4368 )?;
4369 self.write_expr(module, arg, func_ctx)?;
4370 write!(self.out, "))")?;
4371 } else {
4372 write!(
4374 self.out,
4375 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4376 )?;
4377 self.write_expr(module, arg, func_ctx)?;
4378 write!(self.out, ")))")?;
4379 }
4380 }
4381 TypeInner::Scalar(scalar) => {
4382 let scalar_width_bits = scalar.width * 8;
4383
4384 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4385 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4386 self.write_expr(module, arg, func_ctx)?;
4387 write!(self.out, "))")?;
4388 } else {
4389 write!(
4391 self.out,
4392 "asint(min({scalar_width_bits}u, firstbitlow("
4393 )?;
4394 self.write_expr(module, arg, func_ctx)?;
4395 write!(self.out, ")))")?;
4396 }
4397 }
4398 _ => unreachable!(),
4399 }
4400
4401 return Ok(());
4402 }
4403 Function::CountLeadingZeros => {
4404 match *func_ctx.resolve_type(arg, &module.types) {
4405 TypeInner::Vector { size, scalar } => {
4406 let s = match size {
4407 crate::VectorSize::Bi => ".xx",
4408 crate::VectorSize::Tri => ".xxx",
4409 crate::VectorSize::Quad => ".xxxx",
4410 };
4411
4412 let constant = scalar.width * 8 - 1;
4414
4415 if scalar.kind == ScalarKind::Uint {
4416 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4417 self.write_expr(module, arg, func_ctx)?;
4418 write!(self.out, "))")?;
4419 } else {
4420 let conversion_func = match scalar.width {
4421 4 => "asint",
4422 _ => "",
4423 };
4424 write!(self.out, "(")?;
4425 self.write_expr(module, arg, func_ctx)?;
4426 write!(
4427 self.out,
4428 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4429 )?;
4430 self.write_expr(module, arg, func_ctx)?;
4431 write!(self.out, ")))")?;
4432 }
4433 }
4434 TypeInner::Scalar(scalar) => {
4435 let constant = scalar.width * 8 - 1;
4437
4438 if let ScalarKind::Uint = scalar.kind {
4439 write!(self.out, "({constant}u - firstbithigh(")?;
4440 self.write_expr(module, arg, func_ctx)?;
4441 write!(self.out, "))")?;
4442 } else {
4443 let conversion_func = match scalar.width {
4444 4 => "asint",
4445 _ => "",
4446 };
4447 write!(self.out, "(")?;
4448 self.write_expr(module, arg, func_ctx)?;
4449 write!(
4450 self.out,
4451 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4452 )?;
4453 self.write_expr(module, arg, func_ctx)?;
4454 write!(self.out, ")))")?;
4455 }
4456 }
4457 _ => unreachable!(),
4458 }
4459
4460 return Ok(());
4461 }
4462 }
4463 }
4464 Expression::Swizzle {
4465 size,
4466 vector,
4467 pattern,
4468 } => {
4469 self.write_expr(module, vector, func_ctx)?;
4470 write!(self.out, ".")?;
4471 for &sc in pattern[..size as usize].iter() {
4472 self.out.write_char(back::COMPONENTS[sc as usize])?;
4473 }
4474 }
4475 Expression::ArrayLength(expr) => {
4476 let var_handle = match func_ctx.expressions[expr] {
4477 Expression::AccessIndex { base, index: _ } => {
4478 match func_ctx.expressions[base] {
4479 Expression::GlobalVariable(handle) => handle,
4480 _ => unreachable!(),
4481 }
4482 }
4483 Expression::GlobalVariable(handle) => handle,
4484 _ => unreachable!(),
4485 };
4486
4487 let var = &module.global_variables[var_handle];
4488 let (offset, stride) = match module.types[var.ty].inner {
4489 TypeInner::Array { stride, .. } => (0, stride),
4490 TypeInner::Struct { ref members, .. } => {
4491 let last = members.last().unwrap();
4492 let stride = match module.types[last.ty].inner {
4493 TypeInner::Array { stride, .. } => stride,
4494 _ => unreachable!(),
4495 };
4496 (last.offset, stride)
4497 }
4498 _ => unreachable!(),
4499 };
4500
4501 let storage_access = match var.space {
4502 crate::AddressSpace::Storage { access } => access,
4503 _ => crate::StorageAccess::default(),
4504 };
4505 let wrapped_array_length = WrappedArrayLength {
4506 writable: storage_access.contains(crate::StorageAccess::STORE),
4507 };
4508
4509 write!(self.out, "((")?;
4510 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4511 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4512 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4513 }
4514 Expression::Derivative { axis, ctrl, expr } => {
4515 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4516 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4517 let tail = match ctrl {
4518 Ctrl::Coarse => "coarse",
4519 Ctrl::Fine => "fine",
4520 Ctrl::None => unreachable!(),
4521 };
4522 write!(self.out, "abs(ddx_{tail}(")?;
4523 self.write_expr(module, expr, func_ctx)?;
4524 write!(self.out, ")) + abs(ddy_{tail}(")?;
4525 self.write_expr(module, expr, func_ctx)?;
4526 write!(self.out, "))")?
4527 } else {
4528 let fun_str = match (axis, ctrl) {
4529 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4530 (Axis::X, Ctrl::Fine) => "ddx_fine",
4531 (Axis::X, Ctrl::None) => "ddx",
4532 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4533 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4534 (Axis::Y, Ctrl::None) => "ddy",
4535 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4536 (Axis::Width, Ctrl::None) => "fwidth",
4537 };
4538 write!(self.out, "{fun_str}(")?;
4539 self.write_expr(module, expr, func_ctx)?;
4540 write!(self.out, ")")?
4541 }
4542 }
4543 Expression::Relational { fun, argument } => {
4544 use crate::RelationalFunction as Rf;
4545
4546 let fun_str = match fun {
4547 Rf::All => "all",
4548 Rf::Any => "any",
4549 Rf::IsNan => "isnan",
4550 Rf::IsInf => "isinf",
4551 };
4552 write!(self.out, "{fun_str}(")?;
4553 self.write_expr(module, argument, func_ctx)?;
4554 write!(self.out, ")")?
4555 }
4556 Expression::Select {
4557 condition,
4558 accept,
4559 reject,
4560 } => {
4561 write!(self.out, "(")?;
4562 self.write_expr(module, condition, func_ctx)?;
4563 write!(self.out, " ? ")?;
4564 self.write_expr(module, accept, func_ctx)?;
4565 write!(self.out, " : ")?;
4566 self.write_expr(module, reject, func_ctx)?;
4567 write!(self.out, ")")?
4568 }
4569 Expression::RayQueryGetIntersection { query, committed } => {
4570 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4572 unreachable!()
4573 };
4574
4575 let tracker_expr_name = format!(
4576 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4577 self.names[&func_ctx.name_key(query_var)]
4578 );
4579
4580 if committed {
4581 write!(self.out, "GetCommittedIntersection(")?;
4582 self.write_expr(module, query, func_ctx)?;
4583 write!(self.out, ", {tracker_expr_name})")?;
4584 } else {
4585 write!(self.out, "GetCandidateIntersection(")?;
4586 self.write_expr(module, query, func_ctx)?;
4587 write!(self.out, ", {tracker_expr_name})")?;
4588 }
4589 }
4590 Expression::RayQueryVertexPositions { .. }
4592 | Expression::CooperativeLoad { .. }
4593 | Expression::CooperativeMultiplyAdd { .. } => {
4594 unreachable!()
4595 }
4596 Expression::CallResult(_)
4598 | Expression::AtomicResult { .. }
4599 | Expression::WorkGroupUniformLoadResult { .. }
4600 | Expression::RayQueryProceedResult
4601 | Expression::SubgroupBallotResult
4602 | Expression::SubgroupOperationResult { .. } => {}
4603 }
4604
4605 if !closing_bracket.is_empty() {
4606 write!(self.out, "{closing_bracket}")?;
4607 }
4608 Ok(())
4609 }
4610
4611 #[allow(clippy::too_many_arguments)]
4612 fn write_image_load(
4613 &mut self,
4614 module: &&Module,
4615 expr: Handle<crate::Expression>,
4616 func_ctx: &back::FunctionCtx,
4617 image: Handle<crate::Expression>,
4618 coordinate: Handle<crate::Expression>,
4619 array_index: Option<Handle<crate::Expression>>,
4620 sample: Option<Handle<crate::Expression>>,
4621 level: Option<Handle<crate::Expression>>,
4622 ) -> Result<(), Error> {
4623 let mut wrapping_type = None;
4624 match *func_ctx.resolve_type(image, &module.types) {
4625 TypeInner::Image {
4626 class: crate::ImageClass::External,
4627 ..
4628 } => {
4629 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4630 self.write_expr(module, image, func_ctx)?;
4631 write!(self.out, ", ")?;
4632 self.write_expr(module, coordinate, func_ctx)?;
4633 write!(self.out, ")")?;
4634 return Ok(());
4635 }
4636 TypeInner::Image {
4637 class: crate::ImageClass::Storage { format, .. },
4638 ..
4639 } => {
4640 if format.single_component() {
4641 wrapping_type = Some(Scalar::from(format));
4642 }
4643 }
4644 _ => {}
4645 }
4646 if let Some(scalar) = wrapping_type {
4647 write!(
4648 self.out,
4649 "{}{}(",
4650 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4651 scalar.to_hlsl_str()?
4652 )?;
4653 }
4654 self.write_expr(module, image, func_ctx)?;
4656 write!(self.out, ".Load(")?;
4657
4658 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4659
4660 if let Some(sample) = sample {
4661 write!(self.out, ", ")?;
4662 self.write_expr(module, sample, func_ctx)?;
4663 }
4664
4665 write!(self.out, ")")?;
4667
4668 if wrapping_type.is_some() {
4669 write!(self.out, ")")?;
4670 }
4671
4672 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4674 write!(self.out, ".x")?;
4675 }
4676 Ok(())
4677 }
4678
4679 fn sampler_binding_array_info_from_expression(
4682 &mut self,
4683 module: &Module,
4684 func_ctx: &back::FunctionCtx<'_>,
4685 base: Handle<crate::Expression>,
4686 resolved: &TypeInner,
4687 ) -> Option<BindingArraySamplerInfo> {
4688 if let TypeInner::BindingArray {
4689 base: base_ty_handle,
4690 ..
4691 } = *resolved
4692 {
4693 let base_ty = &module.types[base_ty_handle].inner;
4694 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4695 let base = &func_ctx.expressions[base];
4696
4697 if let crate::Expression::GlobalVariable(handle) = *base {
4698 let variable = &module.global_variables[handle];
4699
4700 let sampler_heap_name = match comparison {
4701 true => COMPARISON_SAMPLER_HEAP_VAR,
4702 false => SAMPLER_HEAP_VAR,
4703 };
4704
4705 return Some(BindingArraySamplerInfo {
4706 sampler_heap_name,
4707 sampler_index_buffer_name: self
4708 .wrapped
4709 .sampler_index_buffers
4710 .get(&super::SamplerIndexBufferKey {
4711 group: variable.binding.unwrap().group,
4712 })
4713 .unwrap()
4714 .clone(),
4715 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4716 .clone(),
4717 });
4718 }
4719 }
4720 }
4721
4722 None
4723 }
4724
4725 fn write_named_expr(
4726 &mut self,
4727 module: &Module,
4728 handle: Handle<crate::Expression>,
4729 name: String,
4730 expr: Handle<crate::Expression>,
4733 func_ctx: &back::FunctionCtx,
4734 ) -> BackendResult {
4735 if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4736 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4737 if ty_inner.is_atomic_pointer(&module.types) {
4738 let pointer_space = ty_inner.pointer_space().unwrap();
4739 self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4740 write!(self.out, " {name}; ")?;
4741 match pointer_space {
4742 crate::AddressSpace::WorkGroup => {
4743 write!(self.out, "InterlockedOr(")?;
4744 self.write_expr(module, pointer, func_ctx)?;
4745 }
4746 crate::AddressSpace::Storage { .. } => {
4747 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4748 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4749 write!(self.out, "{var_name}.InterlockedOr(")?;
4750 let chain = mem::take(&mut self.temp_access_chain);
4751 self.write_storage_address(module, &chain, func_ctx)?;
4752 self.temp_access_chain = chain;
4753 }
4754 _ => unreachable!(),
4755 }
4756 writeln!(self.out, ", 0, {name});")?;
4757 self.named_expressions.insert(expr, name);
4758 return Ok(());
4759 }
4760 }
4761 match func_ctx.info[expr].ty {
4762 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4763 TypeInner::Struct { .. } => {
4764 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4765 write!(self.out, "{ty_name}")?;
4766 }
4767 _ => {
4768 self.write_type(module, ty_handle)?;
4769 }
4770 },
4771 proc::TypeResolution::Value(ref inner) => {
4772 self.write_value_type(module, inner)?;
4773 }
4774 }
4775
4776 let resolved = func_ctx.resolve_type(expr, &module.types);
4777
4778 write!(self.out, " {name}")?;
4779 if let TypeInner::Array { base, size, .. } = *resolved {
4781 self.write_array_size(module, base, size)?;
4782 }
4783 write!(self.out, " = ")?;
4784 self.write_expr(module, handle, func_ctx)?;
4785 writeln!(self.out, ";")?;
4786 self.named_expressions.insert(expr, name);
4787
4788 Ok(())
4789 }
4790
4791 pub(super) fn write_default_init(
4793 &mut self,
4794 module: &Module,
4795 ty: Handle<crate::Type>,
4796 ) -> BackendResult {
4797 write!(self.out, "(")?;
4798 self.write_type(module, ty)?;
4799 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4800 self.write_array_size(module, base, size)?;
4801 }
4802 write!(self.out, ")0")?;
4803 Ok(())
4804 }
4805
4806 pub(super) fn write_control_barrier(
4807 &mut self,
4808 barrier: crate::Barrier,
4809 level: back::Level,
4810 ) -> BackendResult {
4811 if barrier.contains(crate::Barrier::STORAGE) {
4812 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4813 }
4814 if barrier.contains(crate::Barrier::WORK_GROUP) {
4815 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4816 }
4817 if barrier.contains(crate::Barrier::SUB_GROUP) {
4818 }
4820 if barrier.contains(crate::Barrier::TEXTURE) {
4821 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4822 }
4823 Ok(())
4824 }
4825
4826 fn write_memory_barrier(
4827 &mut self,
4828 barrier: crate::Barrier,
4829 level: back::Level,
4830 ) -> BackendResult {
4831 if barrier.contains(crate::Barrier::STORAGE) {
4832 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4833 }
4834 if barrier.contains(crate::Barrier::WORK_GROUP) {
4835 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4836 }
4837 if barrier.contains(crate::Barrier::SUB_GROUP) {
4838 }
4840 if barrier.contains(crate::Barrier::TEXTURE) {
4841 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4842 }
4843 Ok(())
4844 }
4845
4846 fn emit_hlsl_atomic_tail(
4848 &mut self,
4849 module: &Module,
4850 func_ctx: &back::FunctionCtx<'_>,
4851 fun: &crate::AtomicFunction,
4852 compare_expr: Option<Handle<crate::Expression>>,
4853 value: Handle<crate::Expression>,
4854 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4855 ) -> BackendResult {
4856 if let Some(cmp) = compare_expr {
4857 write!(self.out, ", ")?;
4858 self.write_expr(module, cmp, func_ctx)?;
4859 }
4860 write!(self.out, ", ")?;
4861 if let crate::AtomicFunction::Subtract = *fun {
4862 write!(self.out, "-")?;
4864 }
4865 self.write_expr(module, value, func_ctx)?;
4866 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4867 write!(self.out, ", ")?;
4868 if compare_expr.is_some() {
4869 write!(self.out, "{res_name}.old_value")?;
4870 } else {
4871 write!(self.out, "{res_name}")?;
4872 }
4873 }
4874 writeln!(self.out, ");")?;
4875 Ok(())
4876 }
4877}
4878
/// Dimensions and scalar width of a matrix type, as copied out of a
/// `TypeInner::Matrix` by the lookup helpers in this file.
pub(super) struct MatrixType {
    /// The `columns` field of the matrix type.
    pub(super) columns: crate::VectorSize,
    /// The `rows` field of the matrix type.
    pub(super) rows: crate::VectorSize,
    /// The `width` of the matrix's scalar type.
    pub(super) width: crate::Bytes,
}
4884
4885pub(super) fn get_inner_matrix_data(
4886 module: &Module,
4887 handle: Handle<crate::Type>,
4888) -> Option<MatrixType> {
4889 match module.types[handle].inner {
4890 TypeInner::Matrix {
4891 columns,
4892 rows,
4893 scalar,
4894 } => Some(MatrixType {
4895 columns,
4896 rows,
4897 width: scalar.width,
4898 }),
4899 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4900 _ => None,
4901 }
4902}
4903
4904fn find_matrix_in_access_chain(
4908 module: &Module,
4909 base: Handle<crate::Expression>,
4910 func_ctx: &back::FunctionCtx<'_>,
4911) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4912 let mut current_base = base;
4913 let mut vector = None;
4914 let mut scalar = None;
4915 loop {
4916 let resolved_tr = func_ctx
4917 .resolve_type(current_base, &module.types)
4918 .pointer_base_type();
4919 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4920
4921 match *resolved {
4922 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4923 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4924 _ => return None,
4925 }
4926
4927 let index;
4928 (current_base, index) = match func_ctx.expressions[current_base] {
4929 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4930 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4931 _ => return None,
4932 };
4933
4934 match *resolved {
4935 TypeInner::Scalar(_) => scalar = Some(index),
4936 TypeInner::Vector { .. } => vector = Some(index),
4937 _ => unreachable!(),
4938 }
4939 }
4940}
4941
4942pub(super) fn get_inner_matrix_of_struct_array_member(
4947 module: &Module,
4948 base: Handle<crate::Expression>,
4949 func_ctx: &back::FunctionCtx<'_>,
4950 direct: bool,
4951) -> Option<MatrixType> {
4952 let mut mat_data = None;
4953 let mut array_base = None;
4954
4955 let mut current_base = base;
4956 loop {
4957 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4958 if let TypeInner::Pointer { base, .. } = *resolved {
4959 resolved = &module.types[base].inner;
4960 };
4961
4962 match *resolved {
4963 TypeInner::Matrix {
4964 columns,
4965 rows,
4966 scalar,
4967 } => {
4968 mat_data = Some(MatrixType {
4969 columns,
4970 rows,
4971 width: scalar.width,
4972 })
4973 }
4974 TypeInner::Array { base, .. } => {
4975 array_base = Some(base);
4976 }
4977 TypeInner::Struct { .. } => {
4978 if let Some(array_base) = array_base {
4979 if direct {
4980 return mat_data;
4981 } else {
4982 return get_inner_matrix_data(module, array_base);
4983 }
4984 }
4985
4986 break;
4987 }
4988 _ => break,
4989 }
4990
4991 current_base = match func_ctx.expressions[current_base] {
4992 crate::Expression::Access { base, .. } => base,
4993 crate::Expression::AccessIndex { base, .. } => base,
4994 _ => break,
4995 };
4996 }
4997 None
4998}
4999
5000fn get_global_uniform_matrix(
5003 module: &Module,
5004 base: Handle<crate::Expression>,
5005 func_ctx: &back::FunctionCtx<'_>,
5006) -> Option<MatrixType> {
5007 let base_tr = func_ctx
5008 .resolve_type(base, &module.types)
5009 .pointer_base_type();
5010 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
5011 match (&func_ctx.expressions[base], base_ty) {
5012 (
5013 &crate::Expression::GlobalVariable(handle),
5014 Some(&TypeInner::Matrix {
5015 columns,
5016 rows,
5017 scalar,
5018 }),
5019 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
5020 Some(MatrixType {
5021 columns,
5022 rows,
5023 width: scalar.width,
5024 })
5025 }
5026 _ => None,
5027 }
5028}
5029
5030fn get_inner_matrix_of_global_uniform(
5035 module: &Module,
5036 base: Handle<crate::Expression>,
5037 func_ctx: &back::FunctionCtx<'_>,
5038) -> Option<MatrixType> {
5039 let mut mat_data = None;
5040 let mut array_base = None;
5041
5042 let mut current_base = base;
5043 loop {
5044 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5045 if let TypeInner::Pointer { base, .. } = *resolved {
5046 resolved = &module.types[base].inner;
5047 };
5048
5049 match *resolved {
5050 TypeInner::Matrix {
5051 columns,
5052 rows,
5053 scalar,
5054 } => {
5055 mat_data = Some(MatrixType {
5056 columns,
5057 rows,
5058 width: scalar.width,
5059 })
5060 }
5061 TypeInner::Array { base, .. } => {
5062 array_base = Some(base);
5063 }
5064 _ => break,
5065 }
5066
5067 current_base = match func_ctx.expressions[current_base] {
5068 crate::Expression::Access { base, .. } => base,
5069 crate::Expression::AccessIndex { base, .. } => base,
5070 crate::Expression::GlobalVariable(handle)
5071 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5072 {
5073 return mat_data.or_else(|| {
5074 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5075 })
5076 }
5077 _ => break,
5078 };
5079 }
5080 None
5081}