1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9 help,
10 help::{
11 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12 WrappedZeroValue,
13 },
14 storage::StoreValue,
15 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18 back::{self, get_entry_points, Baked},
19 common,
20 proc::{self, index, ExternalTextureNameKey, NameKey},
21 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Prefix of the semantic given to `@location(N)` interface members that are
/// not fragment outputs; such a member gets the semantic `LOC{N}`.
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable whose type is `SPECIAL_CBUF_TYPE`.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special constants struct.
const SPECIAL_OTHER: &str = "other";

// Identifiers of helper functions and variables generated by the HLSL
// backend. They are `pub(crate)` so other parts of the backend can refer to
// the same names. Most of their definitions and uses are outside this view.
// NOTE(review): the purpose of each helper is inferred from its name.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Sampler heap indexed by non-comparison samplers (see `write_global_sampler`).
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Sampler heap indexed by comparison samplers (see `write_global_sampler`).
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
49
/// An index that is either a run-time expression or a compile-time constant.
/// NOTE(review): its uses are outside this view; presumably it selects an
/// element/member when writing accesses — confirm against the users.
enum Index {
    /// Index computed by an expression at run time.
    Expression(Handle<crate::Expression>),
    /// Index known while generating code.
    Static(u32),
}
54
/// One member of a synthesized entry-point interface struct
/// (built by `write_ep_input_struct` / `write_ep_output_struct`).
struct EpStructMember {
    /// Member name, allocated from the writer's namer.
    name: String,
    /// Type of the member.
    ty: Handle<crate::Type>,
    /// Binding of the member; interface members are expected to have one
    /// (`write_interface_struct` debug-asserts this).
    binding: Option<crate::Binding>,
    /// Position of the member in its original order: the flattened argument
    /// position for inputs, the result-struct member index for outputs.
    /// Input structs are re-sorted by this after being emitted in binding order.
    index: u32,
}
63
/// Description of a synthesized entry-point input or output struct.
struct EntryPointBinding {
    /// Name for a value of this struct type (the lowercased struct name,
    /// passed through the namer).
    arg_name: String,
    /// Name of the generated HLSL struct type.
    ty_name: String,
    /// Struct members. Inputs are in original order; outputs stay in
    /// binding order.
    members: Vec<EpStructMember>,
}
75
/// Interface structs synthesized for one entry point, if any were needed.
pub(super) struct EntryPointInterface {
    /// Input struct. Created only for entry points with arguments that are
    /// fragment shaders or take a subgroup built-in (see `write_ep_interface`).
    input: Option<EntryPointBinding>,
    /// Output struct. Created only for vertex entry points whose result has
    /// no direct binding (see `write_ep_interface`).
    output: Option<EntryPointBinding>,
}
87
88#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
89enum InterfaceKey {
90 Location(u32),
91 BuiltIn(crate::BuiltIn),
92 Other,
93}
94
95impl InterfaceKey {
96 const fn new(binding: Option<&crate::Binding>) -> Self {
97 match binding {
98 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
99 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
100 None => Self::Other,
101 }
102 }
103}
104
/// Direction of an entry-point interface struct.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    /// Struct holding the entry point's inputs.
    Input,
    /// Struct holding the entry point's outputs.
    Output,
}
110
111const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
112 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
113 return false;
114 };
115 matches!(
116 builtin,
117 crate::BuiltIn::SubgroupSize
118 | crate::BuiltIn::SubgroupInvocationId
119 | crate::BuiltIn::NumSubgroups
120 | crate::BuiltIn::SubgroupId
121 )
122}
123
/// Names needed to access a sampler stored in a binding array of samplers.
/// NOTE(review): this type is built and consumed outside this view. The
/// field meanings below are inferred from their names and from the globals
/// written by `write_global_sampler` — confirm against the users.
struct BindingArraySamplerInfo {
    /// Name of the sampler heap variable (presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`).
    sampler_heap_name: &'static str,
    /// Name of the sampler index buffer for the binding's group.
    sampler_index_buffer_name: String,
    /// Name of the `static const uint` holding the binding array's base
    /// register (as written by `write_global_sampler`).
    binding_array_base_index_name: String,
}
133
134impl<'a, W: fmt::Write> super::Writer<'a, W> {
135 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
136 Self {
137 out,
138 names: crate::FastHashMap::default(),
139 namer: proc::Namer::default(),
140 options,
141 pipeline_options,
142 entry_point_io: crate::FastHashMap::default(),
143 named_expressions: crate::NamedExpressions::default(),
144 wrapped: super::Wrapped::default(),
145 written_committed_intersection: false,
146 written_candidate_intersection: false,
147 continue_ctx: back::continue_forward::ContinueCtx::default(),
148 temp_access_chain: Vec::new(),
149 need_bake_expressions: Default::default(),
150 }
151 }
152
153 fn reset(&mut self, module: &Module) {
154 self.names.clear();
155 self.namer.reset(
156 module,
157 &super::keywords::RESERVED_SET,
158 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
159 super::keywords::RESERVED_PREFIXES,
160 &mut self.names,
161 );
162 self.entry_point_io.clear();
163 self.named_expressions.clear();
164 self.wrapped.clear();
165 self.written_committed_intersection = false;
166 self.written_candidate_intersection = false;
167 self.continue_ctx.clear();
168 self.need_bake_expressions.clear();
169 }
170
171 fn gen_force_bounded_loop_statements(
179 &mut self,
180 level: back::Level,
181 ) -> Option<(String, String)> {
182 if !self.options.force_loop_bounding {
183 return None;
184 }
185
186 let loop_bound_name = self.namer.call("loop_bound");
187 let max = u32::MAX;
188 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
191 let level = level.next();
192 let break_and_inc = format!(
193 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
194{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
195 );
196
197 Some((decl, break_and_inc))
198 }
199
200 fn update_expressions_to_bake(
205 &mut self,
206 module: &Module,
207 func: &crate::Function,
208 info: &valid::FunctionInfo,
209 ) {
210 use crate::Expression;
211 self.need_bake_expressions.clear();
212 for (exp_handle, expr) in func.expressions.iter() {
213 let expr_info = &info[exp_handle];
214 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
215 if min_ref_count <= expr_info.ref_count {
216 self.need_bake_expressions.insert(exp_handle);
217 }
218
219 if let Expression::Math { fun, arg, arg1, .. } = *expr {
220 match fun {
221 crate::MathFunction::Asinh
222 | crate::MathFunction::Acosh
223 | crate::MathFunction::Atanh
224 | crate::MathFunction::Unpack2x16float
225 | crate::MathFunction::Unpack2x16snorm
226 | crate::MathFunction::Unpack2x16unorm
227 | crate::MathFunction::Unpack4x8snorm
228 | crate::MathFunction::Unpack4x8unorm
229 | crate::MathFunction::Unpack4xI8
230 | crate::MathFunction::Unpack4xU8
231 | crate::MathFunction::Pack2x16float
232 | crate::MathFunction::Pack2x16snorm
233 | crate::MathFunction::Pack2x16unorm
234 | crate::MathFunction::Pack4x8snorm
235 | crate::MathFunction::Pack4x8unorm
236 | crate::MathFunction::Pack4xI8
237 | crate::MathFunction::Pack4xU8
238 | crate::MathFunction::Pack4xI8Clamp
239 | crate::MathFunction::Pack4xU8Clamp => {
240 self.need_bake_expressions.insert(arg);
241 }
242 crate::MathFunction::CountLeadingZeros => {
243 let inner = info[exp_handle].ty.inner_with(&module.types);
244 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
245 self.need_bake_expressions.insert(arg);
246 }
247 }
248 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
249 self.need_bake_expressions.insert(arg);
250 self.need_bake_expressions.insert(arg1.unwrap());
251 }
252 _ => {}
253 }
254 }
255
256 if let Expression::Derivative { axis, ctrl, expr } = *expr {
257 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
258 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
259 self.need_bake_expressions.insert(expr);
260 }
261 }
262
263 if let Expression::GlobalVariable(_) = *expr {
264 let inner = info[exp_handle].ty.inner_with(&module.types);
265
266 if let TypeInner::Sampler { .. } = *inner {
267 self.need_bake_expressions.insert(exp_handle);
268 }
269 }
270 }
271 for statement in func.body.iter() {
272 match *statement {
273 crate::Statement::SubgroupCollectiveOperation {
274 op: _,
275 collective_op: crate::CollectiveOperation::InclusiveScan,
276 argument,
277 result: _,
278 } => {
279 self.need_bake_expressions.insert(argument);
280 }
281 crate::Statement::Atomic {
282 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
283 ..
284 } => {
285 self.need_bake_expressions.insert(cmp);
286 }
287 _ => {}
288 }
289 }
290 }
291
292 pub fn write(
293 &mut self,
294 module: &Module,
295 module_info: &valid::ModuleInfo,
296 fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
297 ) -> Result<super::ReflectionInfo, Error> {
298 self.reset(module);
299
300 if let Some(ref bt) = self.options.special_constants_binding {
302 writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
303 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
304 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
305 writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
306 writeln!(self.out, "}};")?;
307 write!(
308 self.out,
309 "ConstantBuffer<{}> {}: register(b{}",
310 SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
311 )?;
312 if bt.space != 0 {
313 write!(self.out, ", space{}", bt.space)?;
314 }
315 writeln!(self.out, ");")?;
316
317 writeln!(self.out)?;
319 }
320
321 for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
322 writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
323 for i in 0..bt.size {
324 writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
325 }
326 writeln!(self.out, "}};")?;
327 writeln!(
328 self.out,
329 "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
330 group, group, bt.register, bt.space
331 )?;
332
333 writeln!(self.out)?;
335 }
336
337 let ep_results = module
339 .entry_points
340 .iter()
341 .map(|ep| (ep.stage, ep.function.result.clone()))
342 .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
343
344 self.write_all_mat_cx2_typedefs_and_functions(module)?;
345
346 for (handle, ty) in module.types.iter() {
348 if let TypeInner::Struct { ref members, span } = ty.inner {
349 if module.types[members.last().unwrap().ty]
350 .inner
351 .is_dynamically_sized(&module.types)
352 {
353 continue;
356 }
357
358 let ep_result = ep_results.iter().find(|e| {
359 if let Some(ref result) = e.1 {
360 result.ty == handle
361 } else {
362 false
363 }
364 });
365
366 self.write_struct(
367 module,
368 handle,
369 members,
370 span,
371 ep_result.map(|r| (r.0, Io::Output)),
372 )?;
373 writeln!(self.out)?;
374 }
375 }
376
377 self.write_special_functions(module)?;
378
379 self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
380 self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
381
382 let mut constants = module
384 .constants
385 .iter()
386 .filter(|&(_, c)| c.name.is_some())
387 .peekable();
388 while let Some((handle, _)) = constants.next() {
389 self.write_global_constant(module, handle)?;
390 if constants.peek().is_none() {
392 writeln!(self.out)?;
393 }
394 }
395
396 for (global, _) in module.global_variables.iter() {
398 self.write_global(module, global)?;
399 }
400
401 if !module.global_variables.is_empty() {
402 writeln!(self.out)?;
404 }
405
406 let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
407 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
408
409 for index in ep_range.clone() {
411 let ep = &module.entry_points[index];
412 let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
413 let ep_io = self.write_ep_interface(
414 module,
415 &ep.function,
416 ep.stage,
417 &ep_name,
418 fragment_entry_point,
419 )?;
420 self.entry_point_io.insert(index, ep_io);
421 }
422
423 for (handle, function) in module.functions.iter() {
425 let info = &module_info[handle];
426
427 if !self.options.fake_missing_bindings {
429 if let Some((var_handle, _)) =
430 module
431 .global_variables
432 .iter()
433 .find(|&(var_handle, var)| match var.binding {
434 Some(ref binding) if !info[var_handle].is_empty() => {
435 self.options.resolve_resource_binding(binding).is_err()
436 && self
437 .options
438 .resolve_external_texture_resource_binding(binding)
439 .is_err()
440 }
441 _ => false,
442 })
443 {
444 log::info!(
445 "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
446 handle,
447 function.name,
448 var_handle
449 );
450 continue;
451 }
452 }
453
454 let ctx = back::FunctionCtx {
455 ty: back::FunctionType::Function(handle),
456 info,
457 expressions: &function.expressions,
458 named_expressions: &function.named_expressions,
459 };
460 let name = self.names[&NameKey::Function(handle)].clone();
461
462 self.write_wrapped_functions(module, &ctx)?;
463
464 self.write_function(module, name.as_str(), function, &ctx, info)?;
465
466 writeln!(self.out)?;
467 }
468
469 let mut translated_ep_names = Vec::with_capacity(ep_range.len());
470
471 for index in ep_range {
473 let ep = &module.entry_points[index];
474 let info = module_info.get_entry_point(index);
475
476 if !self.options.fake_missing_bindings {
477 let mut ep_error = None;
478 for (var_handle, var) in module.global_variables.iter() {
479 match var.binding {
480 Some(ref binding) if !info[var_handle].is_empty() => {
481 if let Err(err) = self.options.resolve_resource_binding(binding) {
482 if self
483 .options
484 .resolve_external_texture_resource_binding(binding)
485 .is_err()
486 {
487 ep_error = Some(err);
488 break;
489 }
490 }
491 }
492 _ => {}
493 }
494 }
495 if let Some(err) = ep_error {
496 translated_ep_names.push(Err(err));
497 continue;
498 }
499 }
500
501 let ctx = back::FunctionCtx {
502 ty: back::FunctionType::EntryPoint(index as u16),
503 info,
504 expressions: &ep.function.expressions,
505 named_expressions: &ep.function.named_expressions,
506 };
507
508 self.write_wrapped_functions(module, &ctx)?;
509
510 if ep.stage == ShaderStage::Compute {
511 let num_threads = ep.workgroup_size;
513 writeln!(
514 self.out,
515 "[numthreads({}, {}, {})]",
516 num_threads[0], num_threads[1], num_threads[2]
517 )?;
518 }
519
520 let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
521 self.write_function(module, &name, &ep.function, &ctx, info)?;
522
523 if index < module.entry_points.len() - 1 {
524 writeln!(self.out)?;
525 }
526
527 translated_ep_names.push(Ok(name));
528 }
529
530 Ok(super::ReflectionInfo {
531 entry_point_names: translated_ep_names,
532 })
533 }
534
535 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
536 match *binding {
537 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
538 write!(self.out, "precise ")?;
539 }
540 crate::Binding::Location {
541 interpolation,
542 sampling,
543 ..
544 } => {
545 if let Some(interpolation) = interpolation {
546 if let Some(string) = interpolation.to_hlsl_str() {
547 write!(self.out, "{string} ")?
548 }
549 }
550
551 if let Some(sampling) = sampling {
552 if let Some(string) = sampling.to_hlsl_str() {
553 write!(self.out, "{string} ")?
554 }
555 }
556 }
557 crate::Binding::BuiltIn(_) => {}
558 }
559
560 Ok(())
561 }
562
563 fn write_semantic(
566 &mut self,
567 binding: &Option<crate::Binding>,
568 stage: Option<(ShaderStage, Io)>,
569 ) -> BackendResult {
570 match *binding {
571 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
572 let builtin_str = builtin.to_hlsl_str()?;
573 write!(self.out, " : {builtin_str}")?;
574 }
575 Some(crate::Binding::Location {
576 blend_src: Some(1), ..
577 }) => {
578 write!(self.out, " : SV_Target1")?;
579 }
580 Some(crate::Binding::Location { location, .. }) => {
581 if stage == Some((ShaderStage::Fragment, Io::Output)) {
582 write!(self.out, " : SV_Target{location}")?;
583 } else {
584 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
585 }
586 }
587 _ => {}
588 }
589
590 Ok(())
591 }
592
593 fn write_interface_struct(
594 &mut self,
595 module: &Module,
596 shader_stage: (ShaderStage, Io),
597 struct_name: String,
598 mut members: Vec<EpStructMember>,
599 ) -> Result<EntryPointBinding, Error> {
600 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
604
605 write!(self.out, "struct {struct_name}")?;
606 writeln!(self.out, " {{")?;
607 for m in members.iter() {
608 debug_assert!(m.binding.is_some());
611
612 if is_subgroup_builtin_binding(&m.binding) {
613 continue;
614 }
615 write!(self.out, "{}", back::INDENT)?;
616 if let Some(ref binding) = m.binding {
617 self.write_modifier(binding)?;
618 }
619 self.write_type(module, m.ty)?;
620 write!(self.out, " {}", &m.name)?;
621 self.write_semantic(&m.binding, Some(shader_stage))?;
622 writeln!(self.out, ";")?;
623 }
624 if members.iter().any(|arg| {
625 matches!(
626 arg.binding,
627 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
628 )
629 }) {
630 writeln!(
631 self.out,
632 "{}uint __local_invocation_index : SV_GroupIndex;",
633 back::INDENT
634 )?;
635 }
636 writeln!(self.out, "}};")?;
637 writeln!(self.out)?;
638
639 match shader_stage.1 {
641 Io::Input => {
642 members.sort_by_key(|m| m.index);
644 }
645 Io::Output => {
646 }
648 }
649
650 Ok(EntryPointBinding {
651 arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
652 ty_name: struct_name,
653 members,
654 })
655 }
656
657 fn write_ep_input_struct(
661 &mut self,
662 module: &Module,
663 func: &crate::Function,
664 stage: ShaderStage,
665 entry_point_name: &str,
666 ) -> Result<EntryPointBinding, Error> {
667 let struct_name = format!("{stage:?}Input_{entry_point_name}");
668
669 let mut fake_members = Vec::new();
670 for arg in func.arguments.iter() {
671 match module.types[arg.ty].inner {
676 TypeInner::Struct { ref members, .. } => {
677 for member in members.iter() {
678 let name = self.namer.call_or(&member.name, "member");
679 let index = fake_members.len() as u32;
680 fake_members.push(EpStructMember {
681 name,
682 ty: member.ty,
683 binding: member.binding.clone(),
684 index,
685 });
686 }
687 }
688 _ => {
689 let member_name = self.namer.call_or(&arg.name, "member");
690 let index = fake_members.len() as u32;
691 fake_members.push(EpStructMember {
692 name: member_name,
693 ty: arg.ty,
694 binding: arg.binding.clone(),
695 index,
696 });
697 }
698 }
699 }
700
701 self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
702 }
703
704 fn write_ep_output_struct(
708 &mut self,
709 module: &Module,
710 result: &crate::FunctionResult,
711 stage: ShaderStage,
712 entry_point_name: &str,
713 frag_ep: Option<&FragmentEntryPoint<'_>>,
714 ) -> Result<EntryPointBinding, Error> {
715 let struct_name = format!("{stage:?}Output_{entry_point_name}");
716
717 let empty = [];
718 let members = match module.types[result.ty].inner {
719 TypeInner::Struct { ref members, .. } => members,
720 ref other => {
721 log::error!("Unexpected {other:?} output type without a binding");
722 &empty[..]
723 }
724 };
725
726 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
731 let mut fs_input_locs = Vec::new();
732 for arg in frag_ep.func.arguments.iter() {
733 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
734 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
735 Some(crate::Binding::BuiltIn(_)) | None => {}
736 };
737
738 match frag_ep.module.types[arg.ty].inner {
741 TypeInner::Struct { ref members, .. } => {
742 for member in members.iter() {
743 push_if_location(&member.binding);
744 }
745 }
746 _ => push_if_location(&arg.binding),
747 }
748 }
749 fs_input_locs.sort();
750 Some(fs_input_locs)
751 } else {
752 None
753 };
754
755 let mut fake_members = Vec::new();
756 for (index, member) in members.iter().enumerate() {
757 if let Some(ref fs_input_locs) = fs_input_locs {
758 match member.binding {
759 Some(crate::Binding::Location { location, .. }) => {
760 if fs_input_locs.binary_search(&location).is_err() {
761 continue;
762 }
763 }
764 Some(crate::Binding::BuiltIn(_)) | None => {}
765 }
766 }
767
768 let member_name = self.namer.call_or(&member.name, "member");
769 fake_members.push(EpStructMember {
770 name: member_name,
771 ty: member.ty,
772 binding: member.binding.clone(),
773 index: index as u32,
774 });
775 }
776
777 self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
778 }
779
780 fn write_ep_interface(
784 &mut self,
785 module: &Module,
786 func: &crate::Function,
787 stage: ShaderStage,
788 ep_name: &str,
789 frag_ep: Option<&FragmentEntryPoint<'_>>,
790 ) -> Result<EntryPointInterface, Error> {
791 Ok(EntryPointInterface {
792 input: if !func.arguments.is_empty()
793 && (stage == ShaderStage::Fragment
794 || func
795 .arguments
796 .iter()
797 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
798 {
799 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
800 } else {
801 None
802 },
803 output: match func.result {
804 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
805 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
806 }
807 _ => None,
808 },
809 })
810 }
811
812 fn write_ep_argument_initialization(
813 &mut self,
814 ep: &crate::EntryPoint,
815 ep_input: &EntryPointBinding,
816 fake_member: &EpStructMember,
817 ) -> BackendResult {
818 match fake_member.binding {
819 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
820 write!(self.out, "WaveGetLaneCount()")?
821 }
822 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
823 write!(self.out, "WaveGetLaneIndex()")?
824 }
825 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
826 self.out,
827 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
828 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
829 )?,
830 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
831 write!(
832 self.out,
833 "{}.__local_invocation_index / WaveGetLaneCount()",
834 ep_input.arg_name
835 )?;
836 }
837 _ => {
838 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
839 }
840 }
841 Ok(())
842 }
843
844 fn write_ep_arguments_initialization(
846 &mut self,
847 module: &Module,
848 func: &crate::Function,
849 ep_index: u16,
850 ) -> BackendResult {
851 let ep = &module.entry_points[ep_index as usize];
852 let ep_input = match self
853 .entry_point_io
854 .get_mut(&(ep_index as usize))
855 .unwrap()
856 .input
857 .take()
858 {
859 Some(ep_input) => ep_input,
860 None => return Ok(()),
861 };
862 let mut fake_iter = ep_input.members.iter();
863 for (arg_index, arg) in func.arguments.iter().enumerate() {
864 write!(self.out, "{}", back::INDENT)?;
865 self.write_type(module, arg.ty)?;
866 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
867 write!(self.out, " {arg_name}")?;
868 match module.types[arg.ty].inner {
869 TypeInner::Array { base, size, .. } => {
870 self.write_array_size(module, base, size)?;
871 write!(self.out, " = ")?;
872 self.write_ep_argument_initialization(
873 ep,
874 &ep_input,
875 fake_iter.next().unwrap(),
876 )?;
877 writeln!(self.out, ";")?;
878 }
879 TypeInner::Struct { ref members, .. } => {
880 write!(self.out, " = {{ ")?;
881 for index in 0..members.len() {
882 if index != 0 {
883 write!(self.out, ", ")?;
884 }
885 self.write_ep_argument_initialization(
886 ep,
887 &ep_input,
888 fake_iter.next().unwrap(),
889 )?;
890 }
891 writeln!(self.out, " }};")?;
892 }
893 _ => {
894 write!(self.out, " = ")?;
895 self.write_ep_argument_initialization(
896 ep,
897 &ep_input,
898 fake_iter.next().unwrap(),
899 )?;
900 writeln!(self.out, ";")?;
901 }
902 }
903 }
904 assert!(fake_iter.next().is_none());
905 Ok(())
906 }
907
    /// Writes the declaration of the global variable `handle`.
    ///
    /// External textures and samplers are delegated to
    /// `write_global_external_texture` and `write_global_sampler`. Any other
    /// global whose resource binding cannot be resolved is skipped, with an
    /// info-level log. Otherwise the declaration depends on the address space:
    /// - private: `static T`
    /// - workgroup: `groupshared T`
    /// - uniform: a `cbuffer` block
    /// - storage: `[RW]ByteAddressBuffer`
    /// - handle: the handle type itself
    /// - push constant: `ConstantBuffer<T>`
    ///
    /// The declaration is followed by its `register(...)` binding, if any.
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for a push constant of non-struct type.
    ///
    /// # Panics
    /// Panics for globals in the function address space, and for push
    /// constants when `options.push_constants_target` is unset.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // Binding arrays are classified by their element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures resolve their own set of bindings (planes plus a
        // params buffer), so they are handled before the generic check below.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // Globals whose binding has no mapping are omitted from the output.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers (and binding arrays of samplers) are written as lookups
        // into the sampler heaps.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the declaration prefix and choose the register class letter
        // used in `register(...)` below ("" for spaces without a register).
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // The `{ T name; }` body is written after the register below.
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                // Writable storage: `RWByteAddressBuffer` in a `u` register;
                // read-only storage: `ByteAddressBuffer` in a `t` register.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                // Storage images bind to `u` registers, other handles to `t`.
                let register = match *handle_ty {
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The template argument and closing `>` are written below.
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            // An array size is written inside the template argument.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        if global.space == crate::AddressSpace::PushConstant {
            // Only struct-typed push constants are supported.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            // Push constants take their register from the options, not from
            // a resource binding.
            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Resolution already succeeded in the check near the top.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // Binding arrays get a size suffix, which the bind target may
            // override.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // Globals without a resource binding: plain array suffix, plus an
            // initializer (explicit or default) for private globals.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Close the `cbuffer` with a single member that has the global's
            // type and the global's own name.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1101
1102 fn write_global_sampler(
1103 &mut self,
1104 module: &Module,
1105 handle: Handle<crate::GlobalVariable>,
1106 global: &crate::GlobalVariable,
1107 ) -> BackendResult {
1108 let binding = *global.binding.as_ref().unwrap();
1109
1110 let key = super::SamplerIndexBufferKey {
1111 group: binding.group,
1112 };
1113 self.write_wrapped_sampler_buffer(key)?;
1114
1115 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1117
1118 match module.types[global.ty].inner {
1119 TypeInner::Sampler { comparison } => {
1120 write!(self.out, "static const ")?;
1127 self.write_type(module, global.ty)?;
1128
1129 let heap_var = if comparison {
1130 COMPARISON_SAMPLER_HEAP_VAR
1131 } else {
1132 SAMPLER_HEAP_VAR
1133 };
1134
1135 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1136 let name = &self.names[&NameKey::GlobalVariable(handle)];
1137 writeln!(
1138 self.out,
1139 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1140 register = bt.register
1141 )?;
1142 }
1143 TypeInner::BindingArray { .. } => {
1144 let name = &self.names[&NameKey::GlobalVariable(handle)];
1150 writeln!(
1151 self.out,
1152 "static const uint {name} = {register};",
1153 register = bt.register
1154 )?;
1155 }
1156 _ => unreachable!(),
1157 };
1158
1159 Ok(())
1160 }
1161
1162 fn write_global_external_texture(
1166 &mut self,
1167 module: &Module,
1168 handle: Handle<crate::GlobalVariable>,
1169 global: &crate::GlobalVariable,
1170 ) -> BackendResult {
1171 let res_binding = global
1172 .binding
1173 .as_ref()
1174 .expect("External texture global variables must have a resource binding");
1175 let ext_tex_bindings = match self
1176 .options
1177 .resolve_external_texture_resource_binding(res_binding)
1178 {
1179 Ok(bindings) => bindings,
1180 Err(err) => {
1181 log::info!(
1182 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1183 handle,
1184 global.name,
1185 err,
1186 );
1187 return Ok(());
1188 }
1189 };
1190
1191 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1192 write!(
1193 self.out,
1194 "Texture2D<float4> {}: register(t{}",
1195 name, bt.register
1196 )?;
1197 if bt.space != 0 {
1198 write!(self.out, ", space{}", bt.space)?;
1199 }
1200 writeln!(self.out, ");")?;
1201 Ok(())
1202 };
1203 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1204 let plane_name = &self.names
1205 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1206 write_plane(bt, plane_name)?;
1207 }
1208
1209 let params_name = &self.names
1210 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1211 let params_ty_name =
1212 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1213 write!(
1214 self.out,
1215 "cbuffer {}: register(b{}",
1216 params_name, ext_tex_bindings.params.register
1217 )?;
1218 if ext_tex_bindings.params.space != 0 {
1219 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1220 }
1221 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1222
1223 Ok(())
1224 }
1225
1226 fn write_global_constant(
1231 &mut self,
1232 module: &Module,
1233 handle: Handle<crate::Constant>,
1234 ) -> BackendResult {
1235 write!(self.out, "static const ")?;
1236 let constant = &module.constants[handle];
1237 self.write_type(module, constant.ty)?;
1238 let name = &self.names[&NameKey::Constant(handle)];
1239 write!(self.out, " {name}")?;
1240 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1242 self.write_array_size(module, base, size)?;
1243 }
1244 write!(self.out, " = ")?;
1245 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1246 writeln!(self.out, ";")?;
1247 Ok(())
1248 }
1249
1250 pub(super) fn write_array_size(
1251 &mut self,
1252 module: &Module,
1253 base: Handle<crate::Type>,
1254 size: crate::ArraySize,
1255 ) -> BackendResult {
1256 write!(self.out, "[")?;
1257
1258 match size.resolve(module.to_ctx())? {
1259 proc::IndexableLength::Known(size) => {
1260 write!(self.out, "{size}")?;
1261 }
1262 proc::IndexableLength::Dynamic => unreachable!(),
1263 }
1264
1265 write!(self.out, "]")?;
1266
1267 if let TypeInner::Array {
1268 base: next_base,
1269 size: next_size,
1270 ..
1271 } = module.types[base].inner
1272 {
1273 self.write_array_size(module, next_base, next_size)?;
1274 }
1275
1276 Ok(())
1277 }
1278
1279 fn write_struct(
1284 &mut self,
1285 module: &Module,
1286 handle: Handle<crate::Type>,
1287 members: &[crate::StructMember],
1288 span: u32,
1289 shader_stage: Option<(ShaderStage, Io)>,
1290 ) -> BackendResult {
1291 let struct_name = &self.names[&NameKey::Type(handle)];
1293 writeln!(self.out, "struct {struct_name} {{")?;
1294
1295 let mut last_offset = 0;
1296 for (index, member) in members.iter().enumerate() {
1297 if member.binding.is_none() && member.offset > last_offset {
1298 let padding = (member.offset - last_offset) / 4;
1302 for i in 0..padding {
1303 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1304 }
1305 }
1306 let ty_inner = &module.types[member.ty].inner;
1307 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1308
1309 write!(self.out, "{}", back::INDENT)?;
1311
1312 match module.types[member.ty].inner {
1313 TypeInner::Array { base, size, .. } => {
1314 self.write_global_type(module, member.ty)?;
1317
1318 write!(
1320 self.out,
1321 " {}",
1322 &self.names[&NameKey::StructMember(handle, index as u32)]
1323 )?;
1324 self.write_array_size(module, base, size)?;
1326 }
1327 TypeInner::Matrix {
1330 rows,
1331 columns,
1332 scalar,
1333 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1334 let vec_ty = TypeInner::Vector { size: rows, scalar };
1335 let field_name_key = NameKey::StructMember(handle, index as u32);
1336
1337 for i in 0..columns as u8 {
1338 if i != 0 {
1339 write!(self.out, "; ")?;
1340 }
1341 self.write_value_type(module, &vec_ty)?;
1342 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1343 }
1344 }
1345 _ => {
1346 if let Some(ref binding) = member.binding {
1348 self.write_modifier(binding)?;
1349 }
1350
1351 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1355 write!(self.out, "row_major ")?;
1356 }
1357
1358 self.write_type(module, member.ty)?;
1360 write!(
1361 self.out,
1362 " {}",
1363 &self.names[&NameKey::StructMember(handle, index as u32)]
1364 )?;
1365 }
1366 }
1367
1368 self.write_semantic(&member.binding, shader_stage)?;
1369 writeln!(self.out, ";")?;
1370 }
1371
1372 if members.last().unwrap().binding.is_none() && span > last_offset {
1374 let padding = (span - last_offset) / 4;
1375 for i in 0..padding {
1376 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1377 }
1378 }
1379
1380 writeln!(self.out, "}};")?;
1381 Ok(())
1382 }
1383
1384 pub(super) fn write_global_type(
1389 &mut self,
1390 module: &Module,
1391 ty: Handle<crate::Type>,
1392 ) -> BackendResult {
1393 let matrix_data = get_inner_matrix_data(module, ty);
1394
1395 if let Some(MatrixType {
1398 columns,
1399 rows: crate::VectorSize::Bi,
1400 width: 4,
1401 }) = matrix_data
1402 {
1403 write!(self.out, "__mat{}x2", columns as u8)?;
1404 } else {
1405 if matrix_data.is_some() {
1409 write!(self.out, "row_major ")?;
1410 }
1411
1412 self.write_type(module, ty)?;
1413 }
1414
1415 Ok(())
1416 }
1417
1418 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1423 let inner = &module.types[ty].inner;
1424 match *inner {
1425 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1426 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1428 self.write_type(module, base)?
1429 }
1430 ref other => self.write_value_type(module, other)?,
1431 }
1432
1433 Ok(())
1434 }
1435
1436 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1441 match *inner {
1442 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1443 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1444 }
1445 TypeInner::Vector { size, scalar } => {
1446 write!(
1447 self.out,
1448 "{}{}",
1449 scalar.to_hlsl_str()?,
1450 common::vector_size_str(size)
1451 )?;
1452 }
1453 TypeInner::Matrix {
1454 columns,
1455 rows,
1456 scalar,
1457 } => {
1458 write!(
1463 self.out,
1464 "{}{}x{}",
1465 scalar.to_hlsl_str()?,
1466 common::vector_size_str(columns),
1467 common::vector_size_str(rows),
1468 )?;
1469 }
1470 TypeInner::Image {
1471 dim,
1472 arrayed,
1473 class,
1474 } => {
1475 self.write_image_type(dim, arrayed, class)?;
1476 }
1477 TypeInner::Sampler { comparison } => {
1478 let sampler = if comparison {
1479 "SamplerComparisonState"
1480 } else {
1481 "SamplerState"
1482 };
1483 write!(self.out, "{sampler}")?;
1484 }
1485 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1489 self.write_array_size(module, base, size)?;
1490 }
1491 TypeInner::AccelerationStructure { .. } => {
1492 write!(self.out, "RaytracingAccelerationStructure")?;
1493 }
1494 TypeInner::RayQuery { .. } => {
1495 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1497 }
1498 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1499 }
1500
1501 Ok(())
1502 }
1503
    /// Writes a complete HLSL function (signature and body) for `func`, named `name`.
    ///
    /// Steps, in output order:
    /// 1. Refresh the set of expressions to bake (`update_expressions_to_bake`).
    /// 2. For a result of array type, emit `typedef <elem> ret_<name>[dims];` first
    ///    and use that typedef as the return type of ordinary functions.
    /// 3. Emit the return type: `void`, the entry point's output struct name (when
    ///    `entry_point_io` records one), or the result type.
    /// 4. Emit the parameter list. Entry points either take their input struct or
    ///    take individual arguments with semantics. Compute entry points that need
    ///    workgroup zero-initialization also get `__local_invocation_id`.
    /// 5. Emit the output semantic for entry points, then the body: optional
    ///    workgroup init, entry-point argument initialization, local variable
    ///    declarations, then each statement at indent level 1.
    ///
    /// `named_expressions` is cleared once the function is finished.
    fn write_function(
        &mut self,
        module: &Module,
        name: &str,
        func: &crate::Function,
        func_ctx: &back::FunctionCtx<'_>,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        self.update_expressions_to_bake(module, func, info);

        if let Some(ref result) = func.result {
            // Array return types are spelled through a typedef, because the `[N]`
            // suffix cannot follow a function's return type directly.
            let array_return_type = match module.types[result.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    let array_return_type = self.namer.call(&format!("ret_{name}"));
                    write!(self.out, "typedef ")?;
                    self.write_type(module, result.ty)?;
                    write!(self.out, " {array_return_type}")?;
                    self.write_array_size(module, base, size)?;
                    writeln!(self.out, ";")?;
                    Some(array_return_type)
                }
                _ => None,
            };

            // An invariant position output carries its modifier before the return type.
            if let Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
            ) = result.binding
            {
                self.write_modifier(binding)?;
            }

            match func_ctx.ty {
                back::FunctionType::Function(_) => {
                    if let Some(array_return_type) = array_return_type {
                        write!(self.out, "{array_return_type}")?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
                back::FunctionType::EntryPoint(index) => {
                    // Entry points with a recorded output struct return that struct type.
                    if let Some(ref ep_output) =
                        self.entry_point_io.get(&(index as usize)).unwrap().output
                    {
                        write!(self.out, "{}", ep_output.ty_name)?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
            }
        } else {
            write!(self.out, "void")?;
        }

        write!(self.out, " {name}(")?;

        let need_workgroup_variables_initialization =
            self.need_workgroup_variables_initialization(func_ctx, module);

        match func_ctx.ty {
            back::FunctionType::Function(handle) => {
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }

                    self.write_function_argument(module, handle, arg, index)?;
                }
            }
            back::FunctionType::EntryPoint(ep_index) => {
                // Either a single input struct parameter, or one parameter per argument
                // with its own semantic.
                if let Some(ref ep_input) =
                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
                {
                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                } else {
                    let stage = module.entry_points[ep_index as usize].stage;
                    for (index, arg) in func.arguments.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_type(module, arg.ty)?;

                        let argument_name =
                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                        write!(self.out, " {argument_name}")?;
                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                            self.write_array_size(module, base, size)?;
                        }

                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                    }
                }
                // Extra parameter used by the workgroup zero-initialization prologue;
                // it needs a leading separator only if something was written before it.
                if need_workgroup_variables_initialization {
                    if self
                        .entry_point_io
                        .get(&(ep_index as usize))
                        .unwrap()
                        .input
                        .is_some()
                        || !func.arguments.is_empty()
                    {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
                }
            }
        }
        write!(self.out, ")")?;

        // Entry-point return semantic goes after the parameter list.
        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            let stage = module.entry_points[index as usize].stage;
            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
                self.write_semantic(binding, Some((stage, Io::Output)))?;
            }
        }

        writeln!(self.out)?;
        writeln!(self.out, "{{")?;

        if need_workgroup_variables_initialization {
            self.write_workgroup_variables_initialization(func_ctx, module)?;
        }

        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            self.write_ep_arguments_initialization(module, func, index)?;
        }

        // Declare every local variable up front, with its initializer or a default value.
        for (handle, local) in func.local_variables.iter() {
            write!(self.out, "{}", back::INDENT)?;

            self.write_type(module, local.ty)?;
            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            match module.types[local.ty].inner {
                // Ray query locals are declared without any initializer.
                TypeInner::RayQuery { .. } => {}
                _ => {
                    write!(self.out, " = ")?;
                    if let Some(init) = local.init {
                        self.write_expr(module, init, func_ctx)?;
                    } else {
                        self.write_default_init(module, local.ty)?;
                    }
                }
            }
            writeln!(self.out, ";")?
        }

        if !func.local_variables.is_empty() {
            writeln!(self.out)?;
        }

        for sta in func.body.iter() {
            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
        }

        writeln!(self.out, "}}")?;

        // Expression names are per-function; reset them for the next function.
        self.named_expressions.clear();

        Ok(())
    }
1692
1693 fn write_function_argument(
1694 &mut self,
1695 module: &Module,
1696 handle: Handle<crate::Function>,
1697 arg: &crate::FunctionArgument,
1698 index: usize,
1699 ) -> BackendResult {
1700 if let TypeInner::Image {
1703 class: crate::ImageClass::External,
1704 ..
1705 } = module.types[arg.ty].inner
1706 {
1707 return self.write_function_external_texture_argument(module, handle, index);
1708 }
1709
1710 let arg_ty = match module.types[arg.ty].inner {
1712 TypeInner::Pointer { base, .. } => {
1714 write!(self.out, "inout ")?;
1716 base
1717 }
1718 _ => arg.ty,
1719 };
1720 self.write_type(module, arg_ty)?;
1721
1722 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1723
1724 write!(self.out, " {argument_name}")?;
1726 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1727 self.write_array_size(module, base, size)?;
1728 }
1729
1730 Ok(())
1731 }
1732
1733 fn write_function_external_texture_argument(
1734 &mut self,
1735 module: &Module,
1736 handle: Handle<crate::Function>,
1737 index: usize,
1738 ) -> BackendResult {
1739 let plane_names = [0, 1, 2].map(|i| {
1740 &self.names[&NameKey::ExternalTextureFunctionArgument(
1741 handle,
1742 index as u32,
1743 ExternalTextureNameKey::Plane(i),
1744 )]
1745 });
1746 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1747 handle,
1748 index as u32,
1749 ExternalTextureNameKey::Params,
1750 )];
1751 let params_ty_name =
1752 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1753 write!(
1754 self.out,
1755 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1756 plane_names[0], plane_names[1], plane_names[2],
1757 )?;
1758 Ok(())
1759 }
1760
1761 fn need_workgroup_variables_initialization(
1762 &mut self,
1763 func_ctx: &back::FunctionCtx,
1764 module: &Module,
1765 ) -> bool {
1766 self.options.zero_initialize_workgroup_memory
1767 && func_ctx.ty.is_compute_entry_point(module)
1768 && module.global_variables.iter().any(|(handle, var)| {
1769 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1770 })
1771 }
1772
1773 fn write_workgroup_variables_initialization(
1774 &mut self,
1775 func_ctx: &back::FunctionCtx,
1776 module: &Module,
1777 ) -> BackendResult {
1778 let level = back::Level(1);
1779
1780 writeln!(
1781 self.out,
1782 "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1783 )?;
1784
1785 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1786 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1787 });
1788
1789 for (handle, var) in vars {
1790 let name = &self.names[&NameKey::GlobalVariable(handle)];
1791 write!(self.out, "{}{} = ", level.next(), name)?;
1792 self.write_default_init(module, var.ty)?;
1793 writeln!(self.out, ";")?;
1794 }
1795
1796 writeln!(self.out, "{level}}}")?;
1797 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1798 }
1799
    /// Writes a `Statement::Switch` at indentation `level`.
    ///
    /// Two shapes are produced:
    /// - If every case except the last is an empty fall-through, only the last
    ///   case's body matters. It is emitted inside `do { ... } while(false);` so
    ///   `break` still exits it.
    /// - Otherwise a real `switch(...)` is emitted. A fall-through case with a
    ///   non-empty body gets its own body plus the bodies of all following cases up
    ///   to and including the first non-fall-through one, each copy in its own
    ///   `{ }` block. A `break;` is added unless that last case ends in a
    ///   terminator. NOTE(review): this duplication presumably exists because HLSL
    ///   compilers reject fall-through between non-empty cases; confirm.
    ///
    /// `continue_ctx` tracks `continue` statements inside the switch. If
    /// `enter_switch` yields a variable, a `bool` flag is declared before the
    /// switch. After the switch, `exit_switch` may require an
    /// `if (flag) { continue; }` or `if (flag) { break; }` forwarder.
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // Flag used to forward `continue` out of the switch, if one is needed.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // True when all cases but the last are empty fall-throughs (also true for 0 or 1 cases).
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            writeln!(self.out, "{level}do {{")?;
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // Empty fall-through labels are stacked without a block of their own.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                if case.fall_through && !case.body.is_empty() {
                    // Find the first case at or after i + 1 that does not fall through.
                    // The unwrap assumes the final case never falls through.
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    // Inline this case and every case it falls into, each in its own scope.
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Forget names registered inside this copy; they are scoped to its braces.
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Re-issue any `continue`/`break` that had to be deferred out of the switch.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
1944
1945 fn write_index(
1946 &mut self,
1947 module: &Module,
1948 index: Index,
1949 func_ctx: &back::FunctionCtx<'_>,
1950 ) -> BackendResult {
1951 match index {
1952 Index::Static(index) => {
1953 write!(self.out, "{index}")?;
1954 }
1955 Index::Expression(index) => {
1956 self.write_expr(module, index, func_ctx)?;
1957 }
1958 }
1959 Ok(())
1960 }
1961
1962 fn write_stmt(
1967 &mut self,
1968 module: &Module,
1969 stmt: &crate::Statement,
1970 func_ctx: &back::FunctionCtx<'_>,
1971 level: back::Level,
1972 ) -> BackendResult {
1973 use crate::Statement;
1974
1975 match *stmt {
1976 Statement::Emit(ref range) => {
1977 for handle in range.clone() {
1978 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1979 let expr_name = if ptr_class.is_some() {
1980 None
1984 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1985 Some(self.namer.call(name))
1990 } else if self.need_bake_expressions.contains(&handle) {
1991 Some(Baked(handle).to_string())
1992 } else {
1993 None
1994 };
1995
1996 if let Some(name) = expr_name {
1997 write!(self.out, "{level}")?;
1998 self.write_named_expr(module, handle, name, handle, func_ctx)?;
1999 }
2000 }
2001 }
2002 Statement::Block(ref block) => {
2004 write!(self.out, "{level}")?;
2005 writeln!(self.out, "{{")?;
2006 for sta in block.iter() {
2007 self.write_stmt(module, sta, func_ctx, level.next())?
2009 }
2010 writeln!(self.out, "{level}}}")?
2011 }
2012 Statement::If {
2014 condition,
2015 ref accept,
2016 ref reject,
2017 } => {
2018 write!(self.out, "{level}")?;
2019 write!(self.out, "if (")?;
2020 self.write_expr(module, condition, func_ctx)?;
2021 writeln!(self.out, ") {{")?;
2022
2023 let l2 = level.next();
2024 for sta in accept {
2025 self.write_stmt(module, sta, func_ctx, l2)?;
2027 }
2028
2029 if !reject.is_empty() {
2032 writeln!(self.out, "{level}}} else {{")?;
2033
2034 for sta in reject {
2035 self.write_stmt(module, sta, func_ctx, l2)?;
2037 }
2038 }
2039
2040 writeln!(self.out, "{level}}}")?
2041 }
2042 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2044 Statement::Return { value: None } => {
2045 writeln!(self.out, "{level}return;")?;
2046 }
2047 Statement::Return { value: Some(expr) } => {
2048 let base_ty_res = &func_ctx.info[expr].ty;
2049 let mut resolved = base_ty_res.inner_with(&module.types);
2050 if let TypeInner::Pointer { base, space: _ } = *resolved {
2051 resolved = &module.types[base].inner;
2052 }
2053
2054 if let TypeInner::Struct { .. } = *resolved {
2055 let ty = base_ty_res.handle().unwrap();
2057 let struct_name = &self.names[&NameKey::Type(ty)];
2058 let variable_name = self.namer.call(&struct_name.to_lowercase());
2059 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2060 self.write_expr(module, expr, func_ctx)?;
2061 writeln!(self.out, ";")?;
2062
2063 let ep_output = match func_ctx.ty {
2065 back::FunctionType::Function(_) => None,
2066 back::FunctionType::EntryPoint(index) => self
2067 .entry_point_io
2068 .get(&(index as usize))
2069 .unwrap()
2070 .output
2071 .as_ref(),
2072 };
2073 let final_name = match ep_output {
2074 Some(ep_output) => {
2075 let final_name = self.namer.call(&variable_name);
2076 write!(
2077 self.out,
2078 "{}const {} {} = {{ ",
2079 level, ep_output.ty_name, final_name,
2080 )?;
2081 for (index, m) in ep_output.members.iter().enumerate() {
2082 if index != 0 {
2083 write!(self.out, ", ")?;
2084 }
2085 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2086 write!(self.out, "{variable_name}.{member_name}")?;
2087 }
2088 writeln!(self.out, " }};")?;
2089 final_name
2090 }
2091 None => variable_name,
2092 };
2093 writeln!(self.out, "{level}return {final_name};")?;
2094 } else {
2095 write!(self.out, "{level}return ")?;
2096 self.write_expr(module, expr, func_ctx)?;
2097 writeln!(self.out, ";")?
2098 }
2099 }
2100 Statement::Store { pointer, value } => {
2101 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2102 if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2103 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2104 self.write_storage_store(
2105 module,
2106 var_handle,
2107 StoreValue::Expression(value),
2108 func_ctx,
2109 level,
2110 None,
2111 )?;
2112 } else {
2113 enum MatrixAccess {
2119 Direct {
2120 base: Handle<crate::Expression>,
2121 index: u32,
2122 },
2123 Struct {
2124 columns: crate::VectorSize,
2125 base: Handle<crate::Expression>,
2126 },
2127 }
2128
2129 let get_members = |expr: Handle<crate::Expression>| {
2130 let resolved = func_ctx.resolve_type(expr, &module.types);
2131 match *resolved {
2132 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2133 TypeInner::Struct { ref members, .. } => Some(members),
2134 _ => None,
2135 },
2136 _ => None,
2137 }
2138 };
2139
2140 write!(self.out, "{level}")?;
2141
2142 let matrix_access_on_lhs =
2143 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2144 |(matrix_expr, vector, scalar)| match (
2145 func_ctx.resolve_type(matrix_expr, &module.types),
2146 &func_ctx.expressions[matrix_expr],
2147 ) {
2148 (
2149 &TypeInner::Pointer { base: ty, .. },
2150 &crate::Expression::AccessIndex { base, index },
2151 ) if matches!(
2152 module.types[ty].inner,
2153 TypeInner::Matrix {
2154 rows: crate::VectorSize::Bi,
2155 ..
2156 }
2157 ) && get_members(base)
2158 .map(|members| members[index as usize].binding.is_none())
2159 == Some(true) =>
2160 {
2161 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2162 }
2163 _ => {
2164 if let Some(MatrixType {
2165 columns,
2166 rows: crate::VectorSize::Bi,
2167 width: 4,
2168 }) = get_inner_matrix_of_struct_array_member(
2169 module,
2170 matrix_expr,
2171 func_ctx,
2172 true,
2173 ) {
2174 Some((
2175 MatrixAccess::Struct {
2176 columns,
2177 base: matrix_expr,
2178 },
2179 vector,
2180 scalar,
2181 ))
2182 } else {
2183 None
2184 }
2185 }
2186 },
2187 );
2188
2189 match matrix_access_on_lhs {
2190 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2191 let base_ty_res = &func_ctx.info[base].ty;
2192 let resolved = base_ty_res.inner_with(&module.types);
2193 let ty = match *resolved {
2194 TypeInner::Pointer { base, .. } => base,
2195 _ => base_ty_res.handle().unwrap(),
2196 };
2197
2198 if let Some(Index::Static(vec_index)) = vector {
2199 self.write_expr(module, base, func_ctx)?;
2200 write!(
2201 self.out,
2202 ".{}_{}",
2203 &self.names[&NameKey::StructMember(ty, index)],
2204 vec_index
2205 )?;
2206
2207 if let Some(scalar_index) = scalar {
2208 write!(self.out, "[")?;
2209 self.write_index(module, scalar_index, func_ctx)?;
2210 write!(self.out, "]")?;
2211 }
2212
2213 write!(self.out, " = ")?;
2214 self.write_expr(module, value, func_ctx)?;
2215 writeln!(self.out, ";")?;
2216 } else {
2217 let access = WrappedStructMatrixAccess { ty, index };
2218 match (&vector, &scalar) {
2219 (&Some(_), &Some(_)) => {
2220 self.write_wrapped_struct_matrix_set_scalar_function_name(
2221 access,
2222 )?;
2223 }
2224 (&Some(_), &None) => {
2225 self.write_wrapped_struct_matrix_set_vec_function_name(
2226 access,
2227 )?;
2228 }
2229 (&None, _) => {
2230 self.write_wrapped_struct_matrix_set_function_name(access)?;
2231 }
2232 }
2233
2234 write!(self.out, "(")?;
2235 self.write_expr(module, base, func_ctx)?;
2236 write!(self.out, ", ")?;
2237 self.write_expr(module, value, func_ctx)?;
2238
2239 if let Some(Index::Expression(vec_index)) = vector {
2240 write!(self.out, ", ")?;
2241 self.write_expr(module, vec_index, func_ctx)?;
2242
2243 if let Some(scalar_index) = scalar {
2244 write!(self.out, ", ")?;
2245 self.write_index(module, scalar_index, func_ctx)?;
2246 }
2247 }
2248 writeln!(self.out, ");")?;
2249 }
2250 }
2251 Some((
2252 MatrixAccess::Struct { columns, base },
2253 Some(Index::Expression(vec_index)),
2254 scalar,
2255 )) => {
2256 if scalar.is_some() {
2260 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2261 } else {
2262 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2263 }
2264 write!(self.out, "(")?;
2265 self.write_expr(module, base, func_ctx)?;
2266 write!(self.out, ", ")?;
2267 self.write_expr(module, vec_index, func_ctx)?;
2268
2269 if let Some(scalar_index) = scalar {
2270 write!(self.out, ", ")?;
2271 self.write_index(module, scalar_index, func_ctx)?;
2272 }
2273
2274 write!(self.out, ", ")?;
2275 self.write_expr(module, value, func_ctx)?;
2276
2277 writeln!(self.out, ");")?;
2278 }
2279 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2280 | Some((MatrixAccess::Struct { .. }, None, _))
2281 | None => {
2282 self.write_expr(module, pointer, func_ctx)?;
2283 write!(self.out, " = ")?;
2284
2285 if let Some(MatrixType {
2290 columns,
2291 rows: crate::VectorSize::Bi,
2292 width: 4,
2293 }) = get_inner_matrix_of_struct_array_member(
2294 module, pointer, func_ctx, false,
2295 ) {
2296 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2297 if let TypeInner::Pointer { base, .. } = *resolved {
2298 resolved = &module.types[base].inner;
2299 }
2300
2301 write!(self.out, "(__mat{}x2", columns as u8)?;
2302 if let TypeInner::Array { base, size, .. } = *resolved {
2303 self.write_array_size(module, base, size)?;
2304 }
2305 write!(self.out, ")")?;
2306 }
2307
2308 self.write_expr(module, value, func_ctx)?;
2309 writeln!(self.out, ";")?
2310 }
2311 }
2312 }
2313 }
2314 Statement::Loop {
2315 ref body,
2316 ref continuing,
2317 break_if,
2318 } => {
2319 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2320 let gate_name = (!continuing.is_empty() || break_if.is_some())
2321 .then(|| self.namer.call("loop_init"));
2322
2323 if let Some((ref decl, _)) = force_loop_bound_statements {
2324 writeln!(self.out, "{decl}")?;
2325 }
2326 if let Some(ref gate_name) = gate_name {
2327 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2328 }
2329
2330 self.continue_ctx.enter_loop();
2331 writeln!(self.out, "{level}while(true) {{")?;
2332 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2333 writeln!(self.out, "{break_and_inc}")?;
2334 }
2335 let l2 = level.next();
2336 if let Some(gate_name) = gate_name {
2337 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2338 let l3 = l2.next();
2339 for sta in continuing.iter() {
2340 self.write_stmt(module, sta, func_ctx, l3)?;
2341 }
2342 if let Some(condition) = break_if {
2343 write!(self.out, "{l3}if (")?;
2344 self.write_expr(module, condition, func_ctx)?;
2345 writeln!(self.out, ") {{")?;
2346 writeln!(self.out, "{}break;", l3.next())?;
2347 writeln!(self.out, "{l3}}}")?;
2348 }
2349 writeln!(self.out, "{l2}}}")?;
2350 writeln!(self.out, "{l2}{gate_name} = false;")?;
2351 }
2352
2353 for sta in body.iter() {
2354 self.write_stmt(module, sta, func_ctx, l2)?;
2355 }
2356
2357 writeln!(self.out, "{level}}}")?;
2358 self.continue_ctx.exit_loop();
2359 }
2360 Statement::Break => writeln!(self.out, "{level}break;")?,
2361 Statement::Continue => {
2362 if let Some(variable) = self.continue_ctx.continue_encountered() {
2363 writeln!(self.out, "{level}{variable} = true;")?;
2364 writeln!(self.out, "{level}break;")?
2365 } else {
2366 writeln!(self.out, "{level}continue;")?
2367 }
2368 }
2369 Statement::ControlBarrier(barrier) => {
2370 self.write_control_barrier(barrier, level)?;
2371 }
2372 Statement::MemoryBarrier(barrier) => {
2373 self.write_memory_barrier(barrier, level)?;
2374 }
2375 Statement::ImageStore {
2376 image,
2377 coordinate,
2378 array_index,
2379 value,
2380 } => {
2381 write!(self.out, "{level}")?;
2382 self.write_expr(module, image, func_ctx)?;
2383
2384 write!(self.out, "[")?;
2385 if let Some(index) = array_index {
2386 write!(self.out, "int3(")?;
2388 self.write_expr(module, coordinate, func_ctx)?;
2389 write!(self.out, ", ")?;
2390 self.write_expr(module, index, func_ctx)?;
2391 write!(self.out, ")")?;
2392 } else {
2393 self.write_expr(module, coordinate, func_ctx)?;
2394 }
2395 write!(self.out, "]")?;
2396
2397 write!(self.out, " = ")?;
2398 self.write_expr(module, value, func_ctx)?;
2399 writeln!(self.out, ";")?;
2400 }
2401 Statement::Call {
2402 function,
2403 ref arguments,
2404 result,
2405 } => {
2406 write!(self.out, "{level}")?;
2407 if let Some(expr) = result {
2408 write!(self.out, "const ")?;
2409 let name = Baked(expr).to_string();
2410 let expr_ty = &func_ctx.info[expr].ty;
2411 let ty_inner = match *expr_ty {
2412 proc::TypeResolution::Handle(handle) => {
2413 self.write_type(module, handle)?;
2414 &module.types[handle].inner
2415 }
2416 proc::TypeResolution::Value(ref value) => {
2417 self.write_value_type(module, value)?;
2418 value
2419 }
2420 };
2421 write!(self.out, " {name}")?;
2422 if let TypeInner::Array { base, size, .. } = *ty_inner {
2423 self.write_array_size(module, base, size)?;
2424 }
2425 write!(self.out, " = ")?;
2426 self.named_expressions.insert(expr, name);
2427 }
2428 let func_name = &self.names[&NameKey::Function(function)];
2429 write!(self.out, "{func_name}(")?;
2430 for (index, argument) in arguments.iter().enumerate() {
2431 if index != 0 {
2432 write!(self.out, ", ")?;
2433 }
2434 self.write_expr(module, *argument, func_ctx)?;
2435 }
2436 writeln!(self.out, ");")?
2437 }
2438 Statement::Atomic {
2439 pointer,
2440 ref fun,
2441 value,
2442 result,
2443 } => {
2444 write!(self.out, "{level}")?;
2445 let res_var_info = if let Some(res_handle) = result {
2446 let name = Baked(res_handle).to_string();
2447 match func_ctx.info[res_handle].ty {
2448 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2449 proc::TypeResolution::Value(ref value) => {
2450 self.write_value_type(module, value)?
2451 }
2452 };
2453 write!(self.out, " {name}; ")?;
2454 self.named_expressions.insert(res_handle, name.clone());
2455 Some((res_handle, name))
2456 } else {
2457 None
2458 };
2459 let pointer_space = func_ctx
2460 .resolve_type(pointer, &module.types)
2461 .pointer_space()
2462 .unwrap();
2463 let fun_str = fun.to_hlsl_suffix();
2464 let compare_expr = match *fun {
2465 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2466 _ => None,
2467 };
2468 match pointer_space {
2469 crate::AddressSpace::WorkGroup => {
2470 write!(self.out, "Interlocked{fun_str}(")?;
2471 self.write_expr(module, pointer, func_ctx)?;
2472 self.emit_hlsl_atomic_tail(
2473 module,
2474 func_ctx,
2475 fun,
2476 compare_expr,
2477 value,
2478 &res_var_info,
2479 )?;
2480 }
2481 crate::AddressSpace::Storage { .. } => {
2482 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2483 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2484 let width = match func_ctx.resolve_type(value, &module.types) {
2485 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2486 _ => "",
2487 };
2488 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2489 let chain = mem::take(&mut self.temp_access_chain);
2490 self.write_storage_address(module, &chain, func_ctx)?;
2491 self.temp_access_chain = chain;
2492 self.emit_hlsl_atomic_tail(
2493 module,
2494 func_ctx,
2495 fun,
2496 compare_expr,
2497 value,
2498 &res_var_info,
2499 )?;
2500 }
2501 ref other => {
2502 return Err(Error::Custom(format!(
2503 "invalid address space {other:?} for atomic statement"
2504 )))
2505 }
2506 }
2507 if let Some(cmp) = compare_expr {
2508 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2509 write!(
2510 self.out,
2511 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2512 )?;
2513 self.write_expr(module, cmp, func_ctx)?;
2514 writeln!(self.out, ");")?;
2515 }
2516 }
2517 }
2518 Statement::ImageAtomic {
2519 image,
2520 coordinate,
2521 array_index,
2522 fun,
2523 value,
2524 } => {
2525 write!(self.out, "{level}")?;
2526
2527 let fun_str = fun.to_hlsl_suffix();
2528 write!(self.out, "Interlocked{fun_str}(")?;
2529 self.write_expr(module, image, func_ctx)?;
2530 write!(self.out, "[")?;
2531 self.write_texture_coordinates(
2532 "int",
2533 coordinate,
2534 array_index,
2535 None,
2536 module,
2537 func_ctx,
2538 )?;
2539 write!(self.out, "],")?;
2540
2541 self.write_expr(module, value, func_ctx)?;
2542 writeln!(self.out, ");")?;
2543 }
2544 Statement::WorkGroupUniformLoad { pointer, result } => {
2545 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2546 write!(self.out, "{level}")?;
2547 let name = Baked(result).to_string();
2548 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2549
2550 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2551 }
2552 Statement::Switch {
2553 selector,
2554 ref cases,
2555 } => {
2556 self.write_switch(module, func_ctx, level, selector, cases)?;
2557 }
2558 Statement::RayQuery { query, ref fun } => match *fun {
2559 RayQueryFunction::Initialize {
2560 acceleration_structure,
2561 descriptor,
2562 } => {
2563 write!(self.out, "{level}")?;
2564 self.write_expr(module, query, func_ctx)?;
2565 write!(self.out, ".TraceRayInline(")?;
2566 self.write_expr(module, acceleration_structure, func_ctx)?;
2567 write!(self.out, ", ")?;
2568 self.write_expr(module, descriptor, func_ctx)?;
2569 write!(self.out, ".flags, ")?;
2570 self.write_expr(module, descriptor, func_ctx)?;
2571 write!(self.out, ".cull_mask, ")?;
2572 write!(self.out, "RayDescFromRayDesc_(")?;
2573 self.write_expr(module, descriptor, func_ctx)?;
2574 writeln!(self.out, "));")?;
2575 }
2576 RayQueryFunction::Proceed { result } => {
2577 write!(self.out, "{level}")?;
2578 let name = Baked(result).to_string();
2579 write!(self.out, "const bool {name} = ")?;
2580 self.named_expressions.insert(result, name);
2581 self.write_expr(module, query, func_ctx)?;
2582 writeln!(self.out, ".Proceed();")?;
2583 }
2584 RayQueryFunction::GenerateIntersection { hit_t } => {
2585 write!(self.out, "{level}")?;
2586 self.write_expr(module, query, func_ctx)?;
2587 write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2588 self.write_expr(module, hit_t, func_ctx)?;
2589 writeln!(self.out, ");")?;
2590 }
2591 RayQueryFunction::ConfirmIntersection => {
2592 write!(self.out, "{level}")?;
2593 self.write_expr(module, query, func_ctx)?;
2594 writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2595 }
2596 RayQueryFunction::Terminate => {
2597 write!(self.out, "{level}")?;
2598 self.write_expr(module, query, func_ctx)?;
2599 writeln!(self.out, ".Abort();")?;
2600 }
2601 },
2602 Statement::SubgroupBallot { result, predicate } => {
2603 write!(self.out, "{level}")?;
2604 let name = Baked(result).to_string();
2605 write!(self.out, "const uint4 {name} = ")?;
2606 self.named_expressions.insert(result, name);
2607
2608 write!(self.out, "WaveActiveBallot(")?;
2609 match predicate {
2610 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2611 None => write!(self.out, "true")?,
2612 }
2613 writeln!(self.out, ");")?;
2614 }
2615 Statement::SubgroupCollectiveOperation {
2616 op,
2617 collective_op,
2618 argument,
2619 result,
2620 } => {
2621 write!(self.out, "{level}")?;
2622 write!(self.out, "const ")?;
2623 let name = Baked(result).to_string();
2624 match func_ctx.info[result].ty {
2625 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2626 proc::TypeResolution::Value(ref value) => {
2627 self.write_value_type(module, value)?
2628 }
2629 };
2630 write!(self.out, " {name} = ")?;
2631 self.named_expressions.insert(result, name);
2632
2633 match (collective_op, op) {
2634 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2635 write!(self.out, "WaveActiveAllTrue(")?
2636 }
2637 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2638 write!(self.out, "WaveActiveAnyTrue(")?
2639 }
2640 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2641 write!(self.out, "WaveActiveSum(")?
2642 }
2643 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2644 write!(self.out, "WaveActiveProduct(")?
2645 }
2646 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2647 write!(self.out, "WaveActiveMax(")?
2648 }
2649 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2650 write!(self.out, "WaveActiveMin(")?
2651 }
2652 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2653 write!(self.out, "WaveActiveBitAnd(")?
2654 }
2655 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2656 write!(self.out, "WaveActiveBitOr(")?
2657 }
2658 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2659 write!(self.out, "WaveActiveBitXor(")?
2660 }
2661 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2662 write!(self.out, "WavePrefixSum(")?
2663 }
2664 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2665 write!(self.out, "WavePrefixProduct(")?
2666 }
2667 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2668 self.write_expr(module, argument, func_ctx)?;
2669 write!(self.out, " + WavePrefixSum(")?;
2670 }
2671 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2672 self.write_expr(module, argument, func_ctx)?;
2673 write!(self.out, " * WavePrefixProduct(")?;
2674 }
2675 _ => unimplemented!(),
2676 }
2677 self.write_expr(module, argument, func_ctx)?;
2678 writeln!(self.out, ");")?;
2679 }
2680 Statement::SubgroupGather {
2681 mode,
2682 argument,
2683 result,
2684 } => {
2685 write!(self.out, "{level}")?;
2686 write!(self.out, "const ")?;
2687 let name = Baked(result).to_string();
2688 match func_ctx.info[result].ty {
2689 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2690 proc::TypeResolution::Value(ref value) => {
2691 self.write_value_type(module, value)?
2692 }
2693 };
2694 write!(self.out, " {name} = ")?;
2695 self.named_expressions.insert(result, name);
2696 match mode {
2697 crate::GatherMode::BroadcastFirst => {
2698 write!(self.out, "WaveReadLaneFirst(")?;
2699 self.write_expr(module, argument, func_ctx)?;
2700 }
2701 crate::GatherMode::QuadBroadcast(index) => {
2702 write!(self.out, "QuadReadLaneAt(")?;
2703 self.write_expr(module, argument, func_ctx)?;
2704 write!(self.out, ", ")?;
2705 self.write_expr(module, index, func_ctx)?;
2706 }
2707 crate::GatherMode::QuadSwap(direction) => {
2708 match direction {
2709 crate::Direction::X => {
2710 write!(self.out, "QuadReadAcrossX(")?;
2711 }
2712 crate::Direction::Y => {
2713 write!(self.out, "QuadReadAcrossY(")?;
2714 }
2715 crate::Direction::Diagonal => {
2716 write!(self.out, "QuadReadAcrossDiagonal(")?;
2717 }
2718 }
2719 self.write_expr(module, argument, func_ctx)?;
2720 }
2721 _ => {
2722 write!(self.out, "WaveReadLaneAt(")?;
2723 self.write_expr(module, argument, func_ctx)?;
2724 write!(self.out, ", ")?;
2725 match mode {
2726 crate::GatherMode::BroadcastFirst => unreachable!(),
2727 crate::GatherMode::Broadcast(index)
2728 | crate::GatherMode::Shuffle(index) => {
2729 self.write_expr(module, index, func_ctx)?;
2730 }
2731 crate::GatherMode::ShuffleDown(index) => {
2732 write!(self.out, "WaveGetLaneIndex() + ")?;
2733 self.write_expr(module, index, func_ctx)?;
2734 }
2735 crate::GatherMode::ShuffleUp(index) => {
2736 write!(self.out, "WaveGetLaneIndex() - ")?;
2737 self.write_expr(module, index, func_ctx)?;
2738 }
2739 crate::GatherMode::ShuffleXor(index) => {
2740 write!(self.out, "WaveGetLaneIndex() ^ ")?;
2741 self.write_expr(module, index, func_ctx)?;
2742 }
2743 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2744 crate::GatherMode::QuadSwap(_) => unreachable!(),
2745 }
2746 }
2747 }
2748 writeln!(self.out, ");")?;
2749 }
2750 }
2751
2752 Ok(())
2753 }
2754
2755 fn write_const_expression(
2756 &mut self,
2757 module: &Module,
2758 expr: Handle<crate::Expression>,
2759 arena: &crate::Arena<crate::Expression>,
2760 ) -> BackendResult {
2761 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2762 writer.write_const_expression(module, expr, arena)
2763 })
2764 }
2765
2766 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2767 match literal {
2768 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2769 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2770 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2771 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2772 crate::Literal::I32(value) if value == i32::MIN => {
2778 write!(self.out, "int({} - 1)", value + 1)?
2779 }
2780 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2784 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2785 crate::Literal::I64(value) if value == i64::MIN => {
2787 write!(self.out, "({}L - 1L)", value + 1)?;
2788 }
2789 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2790 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2791 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2792 return Err(Error::Custom(
2793 "Abstract types should not appear in IR presented to backends".into(),
2794 ));
2795 }
2796 }
2797 Ok(())
2798 }
2799
2800 fn write_possibly_const_expression<E>(
2801 &mut self,
2802 module: &Module,
2803 expr: Handle<crate::Expression>,
2804 expressions: &crate::Arena<crate::Expression>,
2805 write_expression: E,
2806 ) -> BackendResult
2807 where
2808 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2809 {
2810 use crate::Expression;
2811
2812 match expressions[expr] {
2813 Expression::Literal(literal) => {
2814 self.write_literal(literal)?;
2815 }
2816 Expression::Constant(handle) => {
2817 let constant = &module.constants[handle];
2818 if constant.name.is_some() {
2819 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2820 } else {
2821 self.write_const_expression(module, constant.init, &module.global_expressions)?;
2822 }
2823 }
2824 Expression::ZeroValue(ty) => {
2825 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2826 write!(self.out, "()")?;
2827 }
2828 Expression::Compose { ty, ref components } => {
2829 match module.types[ty].inner {
2830 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2831 self.write_wrapped_constructor_function_name(
2832 module,
2833 WrappedConstructor { ty },
2834 )?;
2835 }
2836 _ => {
2837 self.write_type(module, ty)?;
2838 }
2839 };
2840 write!(self.out, "(")?;
2841 for (index, component) in components.iter().enumerate() {
2842 if index != 0 {
2843 write!(self.out, ", ")?;
2844 }
2845 write_expression(self, *component)?;
2846 }
2847 write!(self.out, ")")?;
2848 }
2849 Expression::Splat { size, value } => {
2850 let number_of_components = match size {
2854 crate::VectorSize::Bi => "xx",
2855 crate::VectorSize::Tri => "xxx",
2856 crate::VectorSize::Quad => "xxxx",
2857 };
2858 write!(self.out, "(")?;
2859 write_expression(self, value)?;
2860 write!(self.out, ").{number_of_components}")?
2861 }
2862 _ => {
2863 return Err(Error::Override);
2864 }
2865 }
2866
2867 Ok(())
2868 }
2869
2870 pub(super) fn write_expr(
2875 &mut self,
2876 module: &Module,
2877 expr: Handle<crate::Expression>,
2878 func_ctx: &back::FunctionCtx<'_>,
2879 ) -> BackendResult {
2880 use crate::Expression;
2881
2882 let ff_input = if self.options.special_constants_binding.is_some() {
2884 func_ctx.is_fixed_function_input(expr, module)
2885 } else {
2886 None
2887 };
2888 let closing_bracket = match ff_input {
2889 Some(crate::BuiltIn::VertexIndex) => {
2890 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2891 ")"
2892 }
2893 Some(crate::BuiltIn::InstanceIndex) => {
2894 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2895 ")"
2896 }
2897 Some(crate::BuiltIn::NumWorkGroups) => {
2898 write!(
2902 self.out,
2903 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2904 )?;
2905 return Ok(());
2906 }
2907 _ => "",
2908 };
2909
2910 if let Some(name) = self.named_expressions.get(&expr) {
2911 write!(self.out, "{name}{closing_bracket}")?;
2912 return Ok(());
2913 }
2914
2915 let expression = &func_ctx.expressions[expr];
2916
2917 match *expression {
2918 Expression::Literal(_)
2919 | Expression::Constant(_)
2920 | Expression::ZeroValue(_)
2921 | Expression::Compose { .. }
2922 | Expression::Splat { .. } => {
2923 self.write_possibly_const_expression(
2924 module,
2925 expr,
2926 func_ctx.expressions,
2927 |writer, expr| writer.write_expr(module, expr, func_ctx),
2928 )?;
2929 }
2930 Expression::Override(_) => return Err(Error::Override),
2931 Expression::Binary {
2938 op:
2939 op @ crate::BinaryOperator::Add
2940 | op @ crate::BinaryOperator::Subtract
2941 | op @ crate::BinaryOperator::Multiply,
2942 left,
2943 right,
2944 } if matches!(
2945 func_ctx.resolve_type(expr, &module.types).scalar(),
2946 Some(Scalar::I32)
2947 ) =>
2948 {
2949 write!(self.out, "asint(asuint(",)?;
2950 self.write_expr(module, left, func_ctx)?;
2951 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2952 self.write_expr(module, right, func_ctx)?;
2953 write!(self.out, "))")?;
2954 }
2955 Expression::Binary {
2958 op: crate::BinaryOperator::Multiply,
2959 left,
2960 right,
2961 } if func_ctx.resolve_type(left, &module.types).is_matrix()
2962 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2963 {
2964 write!(self.out, "mul(")?;
2966 self.write_expr(module, right, func_ctx)?;
2967 write!(self.out, ", ")?;
2968 self.write_expr(module, left, func_ctx)?;
2969 write!(self.out, ")")?;
2970 }
2971
2972 Expression::Binary {
2984 op: crate::BinaryOperator::Divide,
2985 left,
2986 right,
2987 } if matches!(
2988 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2989 Some(ScalarKind::Sint | ScalarKind::Uint)
2990 ) =>
2991 {
2992 write!(self.out, "{DIV_FUNCTION}(")?;
2993 self.write_expr(module, left, func_ctx)?;
2994 write!(self.out, ", ")?;
2995 self.write_expr(module, right, func_ctx)?;
2996 write!(self.out, ")")?;
2997 }
2998
2999 Expression::Binary {
3000 op: crate::BinaryOperator::Modulo,
3001 left,
3002 right,
3003 } if matches!(
3004 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3005 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3006 ) =>
3007 {
3008 write!(self.out, "{MOD_FUNCTION}(")?;
3009 self.write_expr(module, left, func_ctx)?;
3010 write!(self.out, ", ")?;
3011 self.write_expr(module, right, func_ctx)?;
3012 write!(self.out, ")")?;
3013 }
3014
3015 Expression::Binary { op, left, right } => {
3016 write!(self.out, "(")?;
3017 self.write_expr(module, left, func_ctx)?;
3018 write!(self.out, " {} ", back::binary_operation_str(op))?;
3019 self.write_expr(module, right, func_ctx)?;
3020 write!(self.out, ")")?;
3021 }
3022 Expression::Access { base, index } => {
3023 if let Some(crate::AddressSpace::Storage { .. }) =
3024 func_ctx.resolve_type(expr, &module.types).pointer_space()
3025 {
3026 } else {
3028 if let Some(MatrixType {
3035 columns,
3036 rows: crate::VectorSize::Bi,
3037 width: 4,
3038 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3039 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3040 {
3041 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3042 self.write_expr(module, base, func_ctx)?;
3043 write!(self.out, ", ")?;
3044 self.write_expr(module, index, func_ctx)?;
3045 write!(self.out, ")")?;
3046 return Ok(());
3047 }
3048
3049 let resolved = func_ctx.resolve_type(base, &module.types);
3050
3051 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3052 TypeInner::BindingArray { .. } => {
3053 let uniformity = &func_ctx.info[index].uniformity;
3054
3055 (true, uniformity.non_uniform_result.is_some())
3056 }
3057 _ => (false, false),
3058 };
3059
3060 self.write_expr(module, base, func_ctx)?;
3061
3062 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3063 module, func_ctx, base, resolved,
3064 );
3065
3066 if let Some(ref info) = array_sampler_info {
3067 write!(self.out, "{}[", info.sampler_heap_name)?;
3068 } else {
3069 write!(self.out, "[")?;
3070 }
3071
3072 let needs_bound_check = self.options.restrict_indexing
3073 && !indexing_binding_array
3074 && match resolved.pointer_space() {
3075 Some(
3076 crate::AddressSpace::Function
3077 | crate::AddressSpace::Private
3078 | crate::AddressSpace::WorkGroup
3079 | crate::AddressSpace::PushConstant,
3080 )
3081 | None => true,
3082 Some(crate::AddressSpace::Uniform) => {
3083 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3085 let bind_target = self
3086 .options
3087 .resolve_resource_binding(
3088 module.global_variables[var_handle]
3089 .binding
3090 .as_ref()
3091 .unwrap(),
3092 )
3093 .unwrap();
3094 bind_target.restrict_indexing
3095 }
3096 Some(
3097 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3098 ) => unreachable!(),
3099 };
3100 let restriction_needed = if needs_bound_check {
3102 index::access_needs_check(
3103 base,
3104 index::GuardedIndex::Expression(index),
3105 module,
3106 func_ctx.expressions,
3107 func_ctx.info,
3108 )
3109 } else {
3110 None
3111 };
3112 if let Some(limit) = restriction_needed {
3113 write!(self.out, "min(uint(")?;
3114 self.write_expr(module, index, func_ctx)?;
3115 write!(self.out, "), ")?;
3116 match limit {
3117 index::IndexableLength::Known(limit) => {
3118 write!(self.out, "{}u", limit - 1)?;
3119 }
3120 index::IndexableLength::Dynamic => unreachable!(),
3121 }
3122 write!(self.out, ")")?;
3123 } else {
3124 if non_uniform_qualifier {
3125 write!(self.out, "NonUniformResourceIndex(")?;
3126 }
3127 if let Some(ref info) = array_sampler_info {
3128 write!(
3129 self.out,
3130 "{}[{} + ",
3131 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3132 )?;
3133 }
3134 self.write_expr(module, index, func_ctx)?;
3135 if array_sampler_info.is_some() {
3136 write!(self.out, "]")?;
3137 }
3138 if non_uniform_qualifier {
3139 write!(self.out, ")")?;
3140 }
3141 }
3142
3143 write!(self.out, "]")?;
3144 }
3145 }
3146 Expression::AccessIndex { base, index } => {
3147 if let Some(crate::AddressSpace::Storage { .. }) =
3148 func_ctx.resolve_type(expr, &module.types).pointer_space()
3149 {
3150 } else {
3152 if let Some(MatrixType {
3156 rows: crate::VectorSize::Bi,
3157 width: 4,
3158 ..
3159 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3160 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3161 {
3162 self.write_expr(module, base, func_ctx)?;
3163 write!(self.out, "._{index}")?;
3164 return Ok(());
3165 }
3166
3167 let base_ty_res = &func_ctx.info[base].ty;
3168 let mut resolved = base_ty_res.inner_with(&module.types);
3169 let base_ty_handle = match *resolved {
3170 TypeInner::Pointer { base, .. } => {
3171 resolved = &module.types[base].inner;
3172 Some(base)
3173 }
3174 _ => base_ty_res.handle(),
3175 };
3176
3177 if let TypeInner::Struct { ref members, .. } = *resolved {
3183 let member = &members[index as usize];
3184
3185 match module.types[member.ty].inner {
3186 TypeInner::Matrix {
3187 rows: crate::VectorSize::Bi,
3188 ..
3189 } if member.binding.is_none() => {
3190 let ty = base_ty_handle.unwrap();
3191 self.write_wrapped_struct_matrix_get_function_name(
3192 WrappedStructMatrixAccess { ty, index },
3193 )?;
3194 write!(self.out, "(")?;
3195 self.write_expr(module, base, func_ctx)?;
3196 write!(self.out, ")")?;
3197 return Ok(());
3198 }
3199 _ => {}
3200 }
3201 }
3202
3203 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3204 module, func_ctx, base, resolved,
3205 );
3206
3207 if let Some(ref info) = array_sampler_info {
3208 write!(
3209 self.out,
3210 "{}[{}",
3211 info.sampler_heap_name, info.sampler_index_buffer_name
3212 )?;
3213 }
3214
3215 self.write_expr(module, base, func_ctx)?;
3216
3217 match *resolved {
3218 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3224 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3226 }
3227 TypeInner::Matrix { .. }
3228 | TypeInner::Array { .. }
3229 | TypeInner::BindingArray { .. } => {
3230 if let Some(ref info) = array_sampler_info {
3231 write!(
3232 self.out,
3233 "[{} + {index}]",
3234 info.binding_array_base_index_name
3235 )?;
3236 } else {
3237 write!(self.out, "[{index}]")?;
3238 }
3239 }
3240 TypeInner::Struct { .. } => {
3241 let ty = base_ty_handle.unwrap();
3244
3245 write!(
3246 self.out,
3247 ".{}",
3248 &self.names[&NameKey::StructMember(ty, index)]
3249 )?
3250 }
3251 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3252 }
3253
3254 if array_sampler_info.is_some() {
3255 write!(self.out, "]")?;
3256 }
3257 }
3258 }
3259 Expression::FunctionArgument(pos) => {
3260 let ty = func_ctx.resolve_type(expr, &module.types);
3261
3262 if let TypeInner::Image {
3268 class: crate::ImageClass::External,
3269 ..
3270 } = *ty
3271 {
3272 let plane_names = [0, 1, 2].map(|i| {
3273 &self.names[&func_ctx
3274 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3275 });
3276 let params_name = &self.names[&func_ctx
3277 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3278 write!(
3279 self.out,
3280 "{}, {}, {}, {}",
3281 plane_names[0], plane_names[1], plane_names[2], params_name
3282 )?;
3283 } else {
3284 let key = func_ctx.argument_key(pos);
3285 let name = &self.names[&key];
3286 write!(self.out, "{name}")?;
3287 }
3288 }
3289 Expression::ImageSample {
3290 coordinate,
3291 image,
3292 sampler,
3293 clamp_to_edge: true,
3294 gather: None,
3295 array_index: None,
3296 offset: None,
3297 level: crate::SampleLevel::Zero,
3298 depth_ref: None,
3299 } => {
3300 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3301 self.write_expr(module, image, func_ctx)?;
3302 write!(self.out, ", ")?;
3303 self.write_expr(module, sampler, func_ctx)?;
3304 write!(self.out, ", ")?;
3305 self.write_expr(module, coordinate, func_ctx)?;
3306 write!(self.out, ")")?;
3307 }
3308 Expression::ImageSample {
3309 image,
3310 sampler,
3311 gather,
3312 coordinate,
3313 array_index,
3314 offset,
3315 level,
3316 depth_ref,
3317 clamp_to_edge,
3318 } => {
3319 if clamp_to_edge {
3320 return Err(Error::Custom(
3321 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3322 ));
3323 }
3324
3325 use crate::SampleLevel as Sl;
3326 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3327
3328 let (base_str, component_str) = match gather {
3329 Some(component) => ("Gather", COMPONENTS[component as usize]),
3330 None => ("Sample", ""),
3331 };
3332 let cmp_str = match depth_ref {
3333 Some(_) => "Cmp",
3334 None => "",
3335 };
3336 let level_str = match level {
3337 Sl::Zero if gather.is_none() => "LevelZero",
3338 Sl::Auto | Sl::Zero => "",
3339 Sl::Exact(_) => "Level",
3340 Sl::Bias(_) => "Bias",
3341 Sl::Gradient { .. } => "Grad",
3342 };
3343
3344 self.write_expr(module, image, func_ctx)?;
3345 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3346 self.write_expr(module, sampler, func_ctx)?;
3347 write!(self.out, ", ")?;
3348 self.write_texture_coordinates(
3349 "float",
3350 coordinate,
3351 array_index,
3352 None,
3353 module,
3354 func_ctx,
3355 )?;
3356
3357 if let Some(depth_ref) = depth_ref {
3358 write!(self.out, ", ")?;
3359 self.write_expr(module, depth_ref, func_ctx)?;
3360 }
3361
3362 match level {
3363 Sl::Auto | Sl::Zero => {}
3364 Sl::Exact(expr) => {
3365 write!(self.out, ", ")?;
3366 self.write_expr(module, expr, func_ctx)?;
3367 }
3368 Sl::Bias(expr) => {
3369 write!(self.out, ", ")?;
3370 self.write_expr(module, expr, func_ctx)?;
3371 }
3372 Sl::Gradient { x, y } => {
3373 write!(self.out, ", ")?;
3374 self.write_expr(module, x, func_ctx)?;
3375 write!(self.out, ", ")?;
3376 self.write_expr(module, y, func_ctx)?;
3377 }
3378 }
3379
3380 if let Some(offset) = offset {
3381 write!(self.out, ", ")?;
3382 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3384 write!(self.out, ")")?;
3385 }
3386
3387 write!(self.out, ")")?;
3388 }
3389 Expression::ImageQuery { image, query } => {
3390 if let TypeInner::Image {
3392 dim,
3393 arrayed,
3394 class,
3395 } = *func_ctx.resolve_type(image, &module.types)
3396 {
3397 let wrapped_image_query = WrappedImageQuery {
3398 dim,
3399 arrayed,
3400 class,
3401 query: query.into(),
3402 };
3403
3404 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3405 write!(self.out, "(")?;
3406 self.write_expr(module, image, func_ctx)?;
3408 if let crate::ImageQuery::Size { level: Some(level) } = query {
3409 write!(self.out, ", ")?;
3410 self.write_expr(module, level, func_ctx)?;
3411 }
3412 write!(self.out, ")")?;
3413 }
3414 }
3415 Expression::ImageLoad {
3416 image,
3417 coordinate,
3418 array_index,
3419 sample,
3420 level,
3421 } => self.write_image_load(
3422 &module,
3423 expr,
3424 func_ctx,
3425 image,
3426 coordinate,
3427 array_index,
3428 sample,
3429 level,
3430 )?,
3431 Expression::GlobalVariable(handle) => {
3432 let global_variable = &module.global_variables[handle];
3433 let ty = &module.types[global_variable.ty].inner;
3434
3435 let is_binding_array_of_samplers = match *ty {
3440 TypeInner::BindingArray { base, .. } => {
3441 let base_ty = &module.types[base].inner;
3442 matches!(*base_ty, TypeInner::Sampler { .. })
3443 }
3444 _ => false,
3445 };
3446
3447 let is_storage_space =
3448 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3449
3450 if let TypeInner::Image {
3458 class: crate::ImageClass::External,
3459 ..
3460 } = *ty
3461 {
3462 let plane_names = [0, 1, 2].map(|i| {
3463 &self.names[&NameKey::ExternalTextureGlobalVariable(
3464 handle,
3465 ExternalTextureNameKey::Plane(i),
3466 )]
3467 });
3468 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3469 handle,
3470 ExternalTextureNameKey::Params,
3471 )];
3472 write!(
3473 self.out,
3474 "{}, {}, {}, {}",
3475 plane_names[0], plane_names[1], plane_names[2], params_name
3476 )?;
3477 } else if !is_binding_array_of_samplers && !is_storage_space {
3478 let name = &self.names[&NameKey::GlobalVariable(handle)];
3479 write!(self.out, "{name}")?;
3480 }
3481 }
3482 Expression::LocalVariable(handle) => {
3483 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3484 }
3485 Expression::Load { pointer } => {
3486 match func_ctx
3487 .resolve_type(pointer, &module.types)
3488 .pointer_space()
3489 {
3490 Some(crate::AddressSpace::Storage { .. }) => {
3491 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3492 let result_ty = func_ctx.info[expr].ty.clone();
3493 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3494 }
3495 _ => {
3496 let mut close_paren = false;
3497
3498 if let Some(MatrixType {
3503 rows: crate::VectorSize::Bi,
3504 width: 4,
3505 ..
3506 }) = get_inner_matrix_of_struct_array_member(
3507 module, pointer, func_ctx, false,
3508 )
3509 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3510 {
3511 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3512 let ptr_tr = resolved.pointer_base_type();
3513 if let Some(ptr_ty) =
3514 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3515 {
3516 resolved = ptr_ty;
3517 }
3518
3519 write!(self.out, "((")?;
3520 if let TypeInner::Array { base, size, .. } = *resolved {
3521 self.write_type(module, base)?;
3522 self.write_array_size(module, base, size)?;
3523 } else {
3524 self.write_value_type(module, resolved)?;
3525 }
3526 write!(self.out, ")")?;
3527 close_paren = true;
3528 }
3529
3530 self.write_expr(module, pointer, func_ctx)?;
3531
3532 if close_paren {
3533 write!(self.out, ")")?;
3534 }
3535 }
3536 }
3537 }
3538 Expression::Unary { op, expr } => {
3539 let op_str = match op {
3541 crate::UnaryOperator::Negate => {
3542 match func_ctx.resolve_type(expr, &module.types).scalar() {
3543 Some(Scalar::I32) => NEG_FUNCTION,
3544 _ => "-",
3545 }
3546 }
3547 crate::UnaryOperator::LogicalNot => "!",
3548 crate::UnaryOperator::BitwiseNot => "~",
3549 };
3550 write!(self.out, "{op_str}(")?;
3551 self.write_expr(module, expr, func_ctx)?;
3552 write!(self.out, ")")?;
3553 }
3554 Expression::As {
3555 expr,
3556 kind,
3557 convert,
3558 } => {
3559 let inner = func_ctx.resolve_type(expr, &module.types);
3560 if inner.scalar_kind() == Some(ScalarKind::Float)
3561 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3562 && convert.is_some()
3563 {
3564 let fun_name = match (kind, convert) {
3568 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3569 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3570 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3571 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3572 _ => unreachable!(),
3573 };
3574 write!(self.out, "{fun_name}(")?;
3575 self.write_expr(module, expr, func_ctx)?;
3576 write!(self.out, ")")?;
3577 } else {
3578 let close_paren = match convert {
3579 Some(dst_width) => {
3580 let scalar = Scalar {
3581 kind,
3582 width: dst_width,
3583 };
3584 match *inner {
3585 TypeInner::Vector { size, .. } => {
3586 write!(
3587 self.out,
3588 "{}{}(",
3589 scalar.to_hlsl_str()?,
3590 common::vector_size_str(size)
3591 )?;
3592 }
3593 TypeInner::Scalar(_) => {
3594 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3595 }
3596 TypeInner::Matrix { columns, rows, .. } => {
3597 write!(
3598 self.out,
3599 "{}{}x{}(",
3600 scalar.to_hlsl_str()?,
3601 common::vector_size_str(columns),
3602 common::vector_size_str(rows)
3603 )?;
3604 }
3605 _ => {
3606 return Err(Error::Unimplemented(format!(
3607 "write_expr expression::as {inner:?}"
3608 )));
3609 }
3610 };
3611 true
3612 }
3613 None => {
3614 if inner.scalar_width() == Some(8) {
3615 false
3616 } else {
3617 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3618 true
3619 }
3620 }
3621 };
3622 self.write_expr(module, expr, func_ctx)?;
3623 if close_paren {
3624 write!(self.out, ")")?;
3625 }
3626 }
3627 }
3628 Expression::Math {
3629 fun,
3630 arg,
3631 arg1,
3632 arg2,
3633 arg3,
3634 } => {
3635 use crate::MathFunction as Mf;
3636
3637 enum Function {
3638 Asincosh { is_sin: bool },
3639 Atanh,
3640 Pack2x16float,
3641 Pack2x16snorm,
3642 Pack2x16unorm,
3643 Pack4x8snorm,
3644 Pack4x8unorm,
3645 Pack4xI8,
3646 Pack4xU8,
3647 Pack4xI8Clamp,
3648 Pack4xU8Clamp,
3649 Unpack2x16float,
3650 Unpack2x16snorm,
3651 Unpack2x16unorm,
3652 Unpack4x8snorm,
3653 Unpack4x8unorm,
3654 Unpack4xI8,
3655 Unpack4xU8,
3656 Dot4I8Packed,
3657 Dot4U8Packed,
3658 QuantizeToF16,
3659 Regular(&'static str),
3660 MissingIntOverload(&'static str),
3661 MissingIntReturnType(&'static str),
3662 CountTrailingZeros,
3663 CountLeadingZeros,
3664 }
3665
3666 let fun = match fun {
3667 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3669 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3670 _ => Function::Regular("abs"),
3671 },
3672 Mf::Min => Function::Regular("min"),
3673 Mf::Max => Function::Regular("max"),
3674 Mf::Clamp => Function::Regular("clamp"),
3675 Mf::Saturate => Function::Regular("saturate"),
3676 Mf::Cos => Function::Regular("cos"),
3678 Mf::Cosh => Function::Regular("cosh"),
3679 Mf::Sin => Function::Regular("sin"),
3680 Mf::Sinh => Function::Regular("sinh"),
3681 Mf::Tan => Function::Regular("tan"),
3682 Mf::Tanh => Function::Regular("tanh"),
3683 Mf::Acos => Function::Regular("acos"),
3684 Mf::Asin => Function::Regular("asin"),
3685 Mf::Atan => Function::Regular("atan"),
3686 Mf::Atan2 => Function::Regular("atan2"),
3687 Mf::Asinh => Function::Asincosh { is_sin: true },
3688 Mf::Acosh => Function::Asincosh { is_sin: false },
3689 Mf::Atanh => Function::Atanh,
3690 Mf::Radians => Function::Regular("radians"),
3691 Mf::Degrees => Function::Regular("degrees"),
3692 Mf::Ceil => Function::Regular("ceil"),
3694 Mf::Floor => Function::Regular("floor"),
3695 Mf::Round => Function::Regular("round"),
3696 Mf::Fract => Function::Regular("frac"),
3697 Mf::Trunc => Function::Regular("trunc"),
3698 Mf::Modf => Function::Regular(MODF_FUNCTION),
3699 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3700 Mf::Ldexp => Function::Regular("ldexp"),
3701 Mf::Exp => Function::Regular("exp"),
3703 Mf::Exp2 => Function::Regular("exp2"),
3704 Mf::Log => Function::Regular("log"),
3705 Mf::Log2 => Function::Regular("log2"),
3706 Mf::Pow => Function::Regular("pow"),
3707 Mf::Dot => Function::Regular("dot"),
3709 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3710 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3711 Mf::Cross => Function::Regular("cross"),
3713 Mf::Distance => Function::Regular("distance"),
3714 Mf::Length => Function::Regular("length"),
3715 Mf::Normalize => Function::Regular("normalize"),
3716 Mf::FaceForward => Function::Regular("faceforward"),
3717 Mf::Reflect => Function::Regular("reflect"),
3718 Mf::Refract => Function::Regular("refract"),
3719 Mf::Sign => Function::Regular("sign"),
3721 Mf::Fma => Function::Regular("mad"),
3722 Mf::Mix => Function::Regular("lerp"),
3723 Mf::Step => Function::Regular("step"),
3724 Mf::SmoothStep => Function::Regular("smoothstep"),
3725 Mf::Sqrt => Function::Regular("sqrt"),
3726 Mf::InverseSqrt => Function::Regular("rsqrt"),
3727 Mf::Transpose => Function::Regular("transpose"),
3729 Mf::Determinant => Function::Regular("determinant"),
3730 Mf::QuantizeToF16 => Function::QuantizeToF16,
3731 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3733 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3734 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3735 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3736 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3737 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3738 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3739 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3740 Mf::Pack2x16float => Function::Pack2x16float,
3742 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3743 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3744 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3745 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3746 Mf::Pack4xI8 => Function::Pack4xI8,
3747 Mf::Pack4xU8 => Function::Pack4xU8,
3748 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3749 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3750 Mf::Unpack2x16float => Function::Unpack2x16float,
3752 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3753 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3754 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3755 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3756 Mf::Unpack4xI8 => Function::Unpack4xI8,
3757 Mf::Unpack4xU8 => Function::Unpack4xU8,
3758 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3759 };
3760
3761 match fun {
3762 Function::Asincosh { is_sin } => {
3763 write!(self.out, "log(")?;
3764 self.write_expr(module, arg, func_ctx)?;
3765 write!(self.out, " + sqrt(")?;
3766 self.write_expr(module, arg, func_ctx)?;
3767 write!(self.out, " * ")?;
3768 self.write_expr(module, arg, func_ctx)?;
3769 match is_sin {
3770 true => write!(self.out, " + 1.0))")?,
3771 false => write!(self.out, " - 1.0))")?,
3772 }
3773 }
3774 Function::Atanh => {
3775 write!(self.out, "0.5 * log((1.0 + ")?;
3776 self.write_expr(module, arg, func_ctx)?;
3777 write!(self.out, ") / (1.0 - ")?;
3778 self.write_expr(module, arg, func_ctx)?;
3779 write!(self.out, "))")?;
3780 }
3781 Function::Pack2x16float => {
3782 write!(self.out, "(f32tof16(")?;
3783 self.write_expr(module, arg, func_ctx)?;
3784 write!(self.out, "[0]) | f32tof16(")?;
3785 self.write_expr(module, arg, func_ctx)?;
3786 write!(self.out, "[1]) << 16)")?;
3787 }
3788 Function::Pack2x16snorm => {
3789 let scale = 32767;
3790
3791 write!(self.out, "uint((int(round(clamp(")?;
3792 self.write_expr(module, arg, func_ctx)?;
3793 write!(
3794 self.out,
3795 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3796 )?;
3797 self.write_expr(module, arg, func_ctx)?;
3798 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3799 }
3800 Function::Pack2x16unorm => {
3801 let scale = 65535;
3802
3803 write!(self.out, "(uint(round(clamp(")?;
3804 self.write_expr(module, arg, func_ctx)?;
3805 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3806 self.write_expr(module, arg, func_ctx)?;
3807 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3808 }
3809 Function::Pack4x8snorm => {
3810 let scale = 127;
3811
3812 write!(self.out, "uint((int(round(clamp(")?;
3813 self.write_expr(module, arg, func_ctx)?;
3814 write!(
3815 self.out,
3816 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3817 )?;
3818 self.write_expr(module, arg, func_ctx)?;
3819 write!(
3820 self.out,
3821 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3822 )?;
3823 self.write_expr(module, arg, func_ctx)?;
3824 write!(
3825 self.out,
3826 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3827 )?;
3828 self.write_expr(module, arg, func_ctx)?;
3829 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3830 }
3831 Function::Pack4x8unorm => {
3832 let scale = 255;
3833
3834 write!(self.out, "(uint(round(clamp(")?;
3835 self.write_expr(module, arg, func_ctx)?;
3836 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3837 self.write_expr(module, arg, func_ctx)?;
3838 write!(
3839 self.out,
3840 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3841 )?;
3842 self.write_expr(module, arg, func_ctx)?;
3843 write!(
3844 self.out,
3845 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3846 )?;
3847 self.write_expr(module, arg, func_ctx)?;
3848 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3849 }
3850 fun @ (Function::Pack4xI8
3851 | Function::Pack4xU8
3852 | Function::Pack4xI8Clamp
3853 | Function::Pack4xU8Clamp) => {
3854 let was_signed =
3855 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3856 let clamp_bounds = match fun {
3857 Function::Pack4xI8Clamp => Some(("-128", "127")),
3858 Function::Pack4xU8Clamp => Some(("0", "255")),
3859 _ => None,
3860 };
3861 if was_signed {
3862 write!(self.out, "uint(")?;
3863 }
3864 let write_arg = |this: &mut Self| -> BackendResult {
3865 if let Some((min, max)) = clamp_bounds {
3866 write!(this.out, "clamp(")?;
3867 this.write_expr(module, arg, func_ctx)?;
3868 write!(this.out, ", {min}, {max})")?;
3869 } else {
3870 this.write_expr(module, arg, func_ctx)?;
3871 }
3872 Ok(())
3873 };
3874 write!(self.out, "(")?;
3875 write_arg(self)?;
3876 write!(self.out, "[0] & 0xFF) | ((")?;
3877 write_arg(self)?;
3878 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3879 write_arg(self)?;
3880 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3881 write_arg(self)?;
3882 write!(self.out, "[3] & 0xFF) << 24)")?;
3883 if was_signed {
3884 write!(self.out, ")")?;
3885 }
3886 }
3887
3888 Function::Unpack2x16float => {
3889 write!(self.out, "float2(f16tof32(")?;
3890 self.write_expr(module, arg, func_ctx)?;
3891 write!(self.out, "), f16tof32((")?;
3892 self.write_expr(module, arg, func_ctx)?;
3893 write!(self.out, ") >> 16))")?;
3894 }
3895 Function::Unpack2x16snorm => {
3896 let scale = 32767;
3897
3898 write!(self.out, "(float2(int2(")?;
3899 self.write_expr(module, arg, func_ctx)?;
3900 write!(self.out, " << 16, ")?;
3901 self.write_expr(module, arg, func_ctx)?;
3902 write!(self.out, ") >> 16) / {scale}.0)")?;
3903 }
3904 Function::Unpack2x16unorm => {
3905 let scale = 65535;
3906
3907 write!(self.out, "(float2(")?;
3908 self.write_expr(module, arg, func_ctx)?;
3909 write!(self.out, " & 0xFFFF, ")?;
3910 self.write_expr(module, arg, func_ctx)?;
3911 write!(self.out, " >> 16) / {scale}.0)")?;
3912 }
3913 Function::Unpack4x8snorm => {
3914 let scale = 127;
3915
3916 write!(self.out, "(float4(int4(")?;
3917 self.write_expr(module, arg, func_ctx)?;
3918 write!(self.out, " << 24, ")?;
3919 self.write_expr(module, arg, func_ctx)?;
3920 write!(self.out, " << 16, ")?;
3921 self.write_expr(module, arg, func_ctx)?;
3922 write!(self.out, " << 8, ")?;
3923 self.write_expr(module, arg, func_ctx)?;
3924 write!(self.out, ") >> 24) / {scale}.0)")?;
3925 }
3926 Function::Unpack4x8unorm => {
3927 let scale = 255;
3928
3929 write!(self.out, "(float4(")?;
3930 self.write_expr(module, arg, func_ctx)?;
3931 write!(self.out, " & 0xFF, ")?;
3932 self.write_expr(module, arg, func_ctx)?;
3933 write!(self.out, " >> 8 & 0xFF, ")?;
3934 self.write_expr(module, arg, func_ctx)?;
3935 write!(self.out, " >> 16 & 0xFF, ")?;
3936 self.write_expr(module, arg, func_ctx)?;
3937 write!(self.out, " >> 24) / {scale}.0)")?;
3938 }
3939 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3940 write!(self.out, "(")?;
3941 if matches!(fun, Function::Unpack4xU8) {
3942 write!(self.out, "u")?;
3943 }
3944 write!(self.out, "int4(")?;
3945 self.write_expr(module, arg, func_ctx)?;
3946 write!(self.out, ", ")?;
3947 self.write_expr(module, arg, func_ctx)?;
3948 write!(self.out, " >> 8, ")?;
3949 self.write_expr(module, arg, func_ctx)?;
3950 write!(self.out, " >> 16, ")?;
3951 self.write_expr(module, arg, func_ctx)?;
3952 write!(self.out, " >> 24) << 24 >> 24)")?;
3953 }
3954 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3955 let arg1 = arg1.unwrap();
3956
3957 if self.options.shader_model >= ShaderModel::V6_4 {
3958 let function_name = match fun {
3960 Function::Dot4I8Packed => "dot4add_i8packed",
3961 Function::Dot4U8Packed => "dot4add_u8packed",
3962 _ => unreachable!(),
3963 };
3964 write!(self.out, "{function_name}(")?;
3965 self.write_expr(module, arg, func_ctx)?;
3966 write!(self.out, ", ")?;
3967 self.write_expr(module, arg1, func_ctx)?;
3968 write!(self.out, ", 0)")?;
3969 } else {
3970 write!(self.out, "dot(")?;
3972
3973 if matches!(fun, Function::Dot4U8Packed) {
3974 write!(self.out, "u")?;
3975 }
3976 write!(self.out, "int4(")?;
3977 self.write_expr(module, arg, func_ctx)?;
3978 write!(self.out, ", ")?;
3979 self.write_expr(module, arg, func_ctx)?;
3980 write!(self.out, " >> 8, ")?;
3981 self.write_expr(module, arg, func_ctx)?;
3982 write!(self.out, " >> 16, ")?;
3983 self.write_expr(module, arg, func_ctx)?;
3984 write!(self.out, " >> 24) << 24 >> 24, ")?;
3985
3986 if matches!(fun, Function::Dot4U8Packed) {
3987 write!(self.out, "u")?;
3988 }
3989 write!(self.out, "int4(")?;
3990 self.write_expr(module, arg1, func_ctx)?;
3991 write!(self.out, ", ")?;
3992 self.write_expr(module, arg1, func_ctx)?;
3993 write!(self.out, " >> 8, ")?;
3994 self.write_expr(module, arg1, func_ctx)?;
3995 write!(self.out, " >> 16, ")?;
3996 self.write_expr(module, arg1, func_ctx)?;
3997 write!(self.out, " >> 24) << 24 >> 24)")?;
3998 }
3999 }
4000 Function::QuantizeToF16 => {
4001 write!(self.out, "f16tof32(f32tof16(")?;
4002 self.write_expr(module, arg, func_ctx)?;
4003 write!(self.out, "))")?;
4004 }
4005 Function::Regular(fun_name) => {
4006 write!(self.out, "{fun_name}(")?;
4007 self.write_expr(module, arg, func_ctx)?;
4008 if let Some(arg) = arg1 {
4009 write!(self.out, ", ")?;
4010 self.write_expr(module, arg, func_ctx)?;
4011 }
4012 if let Some(arg) = arg2 {
4013 write!(self.out, ", ")?;
4014 self.write_expr(module, arg, func_ctx)?;
4015 }
4016 if let Some(arg) = arg3 {
4017 write!(self.out, ", ")?;
4018 self.write_expr(module, arg, func_ctx)?;
4019 }
4020 write!(self.out, ")")?
4021 }
4022 Function::MissingIntOverload(fun_name) => {
4025 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4026 if let Some(Scalar::I32) = scalar_kind {
4027 write!(self.out, "asint({fun_name}(asuint(")?;
4028 self.write_expr(module, arg, func_ctx)?;
4029 write!(self.out, ")))")?;
4030 } else {
4031 write!(self.out, "{fun_name}(")?;
4032 self.write_expr(module, arg, func_ctx)?;
4033 write!(self.out, ")")?;
4034 }
4035 }
4036 Function::MissingIntReturnType(fun_name) => {
4039 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4040 if let Some(Scalar::I32) = scalar_kind {
4041 write!(self.out, "asint({fun_name}(")?;
4042 self.write_expr(module, arg, func_ctx)?;
4043 write!(self.out, "))")?;
4044 } else {
4045 write!(self.out, "{fun_name}(")?;
4046 self.write_expr(module, arg, func_ctx)?;
4047 write!(self.out, ")")?;
4048 }
4049 }
4050 Function::CountTrailingZeros => {
4051 match *func_ctx.resolve_type(arg, &module.types) {
4052 TypeInner::Vector { size, scalar } => {
4053 let s = match size {
4054 crate::VectorSize::Bi => ".xx",
4055 crate::VectorSize::Tri => ".xxx",
4056 crate::VectorSize::Quad => ".xxxx",
4057 };
4058
4059 let scalar_width_bits = scalar.width * 8;
4060
4061 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4062 write!(
4063 self.out,
4064 "min(({scalar_width_bits}u){s}, firstbitlow("
4065 )?;
4066 self.write_expr(module, arg, func_ctx)?;
4067 write!(self.out, "))")?;
4068 } else {
4069 write!(
4071 self.out,
4072 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4073 )?;
4074 self.write_expr(module, arg, func_ctx)?;
4075 write!(self.out, ")))")?;
4076 }
4077 }
4078 TypeInner::Scalar(scalar) => {
4079 let scalar_width_bits = scalar.width * 8;
4080
4081 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4082 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4083 self.write_expr(module, arg, func_ctx)?;
4084 write!(self.out, "))")?;
4085 } else {
4086 write!(
4088 self.out,
4089 "asint(min({scalar_width_bits}u, firstbitlow("
4090 )?;
4091 self.write_expr(module, arg, func_ctx)?;
4092 write!(self.out, ")))")?;
4093 }
4094 }
4095 _ => unreachable!(),
4096 }
4097
4098 return Ok(());
4099 }
4100 Function::CountLeadingZeros => {
4101 match *func_ctx.resolve_type(arg, &module.types) {
4102 TypeInner::Vector { size, scalar } => {
4103 let s = match size {
4104 crate::VectorSize::Bi => ".xx",
4105 crate::VectorSize::Tri => ".xxx",
4106 crate::VectorSize::Quad => ".xxxx",
4107 };
4108
4109 let constant = scalar.width * 8 - 1;
4111
4112 if scalar.kind == ScalarKind::Uint {
4113 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4114 self.write_expr(module, arg, func_ctx)?;
4115 write!(self.out, "))")?;
4116 } else {
4117 let conversion_func = match scalar.width {
4118 4 => "asint",
4119 _ => "",
4120 };
4121 write!(self.out, "(")?;
4122 self.write_expr(module, arg, func_ctx)?;
4123 write!(
4124 self.out,
4125 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4126 )?;
4127 self.write_expr(module, arg, func_ctx)?;
4128 write!(self.out, ")))")?;
4129 }
4130 }
4131 TypeInner::Scalar(scalar) => {
4132 let constant = scalar.width * 8 - 1;
4134
4135 if let ScalarKind::Uint = scalar.kind {
4136 write!(self.out, "({constant}u - firstbithigh(")?;
4137 self.write_expr(module, arg, func_ctx)?;
4138 write!(self.out, "))")?;
4139 } else {
4140 let conversion_func = match scalar.width {
4141 4 => "asint",
4142 _ => "",
4143 };
4144 write!(self.out, "(")?;
4145 self.write_expr(module, arg, func_ctx)?;
4146 write!(
4147 self.out,
4148 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4149 )?;
4150 self.write_expr(module, arg, func_ctx)?;
4151 write!(self.out, ")))")?;
4152 }
4153 }
4154 _ => unreachable!(),
4155 }
4156
4157 return Ok(());
4158 }
4159 }
4160 }
4161 Expression::Swizzle {
4162 size,
4163 vector,
4164 pattern,
4165 } => {
4166 self.write_expr(module, vector, func_ctx)?;
4167 write!(self.out, ".")?;
4168 for &sc in pattern[..size as usize].iter() {
4169 self.out.write_char(back::COMPONENTS[sc as usize])?;
4170 }
4171 }
4172 Expression::ArrayLength(expr) => {
4173 let var_handle = match func_ctx.expressions[expr] {
4174 Expression::AccessIndex { base, index: _ } => {
4175 match func_ctx.expressions[base] {
4176 Expression::GlobalVariable(handle) => handle,
4177 _ => unreachable!(),
4178 }
4179 }
4180 Expression::GlobalVariable(handle) => handle,
4181 _ => unreachable!(),
4182 };
4183
4184 let var = &module.global_variables[var_handle];
4185 let (offset, stride) = match module.types[var.ty].inner {
4186 TypeInner::Array { stride, .. } => (0, stride),
4187 TypeInner::Struct { ref members, .. } => {
4188 let last = members.last().unwrap();
4189 let stride = match module.types[last.ty].inner {
4190 TypeInner::Array { stride, .. } => stride,
4191 _ => unreachable!(),
4192 };
4193 (last.offset, stride)
4194 }
4195 _ => unreachable!(),
4196 };
4197
4198 let storage_access = match var.space {
4199 crate::AddressSpace::Storage { access } => access,
4200 _ => crate::StorageAccess::default(),
4201 };
4202 let wrapped_array_length = WrappedArrayLength {
4203 writable: storage_access.contains(crate::StorageAccess::STORE),
4204 };
4205
4206 write!(self.out, "((")?;
4207 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4208 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4209 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4210 }
4211 Expression::Derivative { axis, ctrl, expr } => {
4212 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4213 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4214 let tail = match ctrl {
4215 Ctrl::Coarse => "coarse",
4216 Ctrl::Fine => "fine",
4217 Ctrl::None => unreachable!(),
4218 };
4219 write!(self.out, "abs(ddx_{tail}(")?;
4220 self.write_expr(module, expr, func_ctx)?;
4221 write!(self.out, ")) + abs(ddy_{tail}(")?;
4222 self.write_expr(module, expr, func_ctx)?;
4223 write!(self.out, "))")?
4224 } else {
4225 let fun_str = match (axis, ctrl) {
4226 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4227 (Axis::X, Ctrl::Fine) => "ddx_fine",
4228 (Axis::X, Ctrl::None) => "ddx",
4229 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4230 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4231 (Axis::Y, Ctrl::None) => "ddy",
4232 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4233 (Axis::Width, Ctrl::None) => "fwidth",
4234 };
4235 write!(self.out, "{fun_str}(")?;
4236 self.write_expr(module, expr, func_ctx)?;
4237 write!(self.out, ")")?
4238 }
4239 }
4240 Expression::Relational { fun, argument } => {
4241 use crate::RelationalFunction as Rf;
4242
4243 let fun_str = match fun {
4244 Rf::All => "all",
4245 Rf::Any => "any",
4246 Rf::IsNan => "isnan",
4247 Rf::IsInf => "isinf",
4248 };
4249 write!(self.out, "{fun_str}(")?;
4250 self.write_expr(module, argument, func_ctx)?;
4251 write!(self.out, ")")?
4252 }
4253 Expression::Select {
4254 condition,
4255 accept,
4256 reject,
4257 } => {
4258 write!(self.out, "(")?;
4259 self.write_expr(module, condition, func_ctx)?;
4260 write!(self.out, " ? ")?;
4261 self.write_expr(module, accept, func_ctx)?;
4262 write!(self.out, " : ")?;
4263 self.write_expr(module, reject, func_ctx)?;
4264 write!(self.out, ")")?
4265 }
4266 Expression::RayQueryGetIntersection { query, committed } => {
4267 if committed {
4268 write!(self.out, "GetCommittedIntersection(")?;
4269 self.write_expr(module, query, func_ctx)?;
4270 write!(self.out, ")")?;
4271 } else {
4272 write!(self.out, "GetCandidateIntersection(")?;
4273 self.write_expr(module, query, func_ctx)?;
4274 write!(self.out, ")")?;
4275 }
4276 }
4277 Expression::RayQueryVertexPositions { .. } => unreachable!(),
4279 Expression::CallResult(_)
4281 | Expression::AtomicResult { .. }
4282 | Expression::WorkGroupUniformLoadResult { .. }
4283 | Expression::RayQueryProceedResult
4284 | Expression::SubgroupBallotResult
4285 | Expression::SubgroupOperationResult { .. } => {}
4286 }
4287
4288 if !closing_bracket.is_empty() {
4289 write!(self.out, "{closing_bracket}")?;
4290 }
4291 Ok(())
4292 }
4293
4294 #[allow(clippy::too_many_arguments)]
4295 fn write_image_load(
4296 &mut self,
4297 module: &&Module,
4298 expr: Handle<crate::Expression>,
4299 func_ctx: &back::FunctionCtx,
4300 image: Handle<crate::Expression>,
4301 coordinate: Handle<crate::Expression>,
4302 array_index: Option<Handle<crate::Expression>>,
4303 sample: Option<Handle<crate::Expression>>,
4304 level: Option<Handle<crate::Expression>>,
4305 ) -> Result<(), Error> {
4306 let mut wrapping_type = None;
4307 match *func_ctx.resolve_type(image, &module.types) {
4308 TypeInner::Image {
4309 class: crate::ImageClass::External,
4310 ..
4311 } => {
4312 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4313 self.write_expr(module, image, func_ctx)?;
4314 write!(self.out, ", ")?;
4315 self.write_expr(module, coordinate, func_ctx)?;
4316 write!(self.out, ")")?;
4317 return Ok(());
4318 }
4319 TypeInner::Image {
4320 class: crate::ImageClass::Storage { format, .. },
4321 ..
4322 } => {
4323 if format.single_component() {
4324 wrapping_type = Some(Scalar::from(format));
4325 }
4326 }
4327 _ => {}
4328 }
4329 if let Some(scalar) = wrapping_type {
4330 write!(
4331 self.out,
4332 "{}{}(",
4333 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4334 scalar.to_hlsl_str()?
4335 )?;
4336 }
4337 self.write_expr(module, image, func_ctx)?;
4339 write!(self.out, ".Load(")?;
4340
4341 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4342
4343 if let Some(sample) = sample {
4344 write!(self.out, ", ")?;
4345 self.write_expr(module, sample, func_ctx)?;
4346 }
4347
4348 write!(self.out, ")")?;
4350
4351 if wrapping_type.is_some() {
4352 write!(self.out, ")")?;
4353 }
4354
4355 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4357 write!(self.out, ".x")?;
4358 }
4359 Ok(())
4360 }
4361
4362 fn sampler_binding_array_info_from_expression(
4365 &mut self,
4366 module: &Module,
4367 func_ctx: &back::FunctionCtx<'_>,
4368 base: Handle<crate::Expression>,
4369 resolved: &TypeInner,
4370 ) -> Option<BindingArraySamplerInfo> {
4371 if let TypeInner::BindingArray {
4372 base: base_ty_handle,
4373 ..
4374 } = *resolved
4375 {
4376 let base_ty = &module.types[base_ty_handle].inner;
4377 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4378 let base = &func_ctx.expressions[base];
4379
4380 if let crate::Expression::GlobalVariable(handle) = *base {
4381 let variable = &module.global_variables[handle];
4382
4383 let sampler_heap_name = match comparison {
4384 true => COMPARISON_SAMPLER_HEAP_VAR,
4385 false => SAMPLER_HEAP_VAR,
4386 };
4387
4388 return Some(BindingArraySamplerInfo {
4389 sampler_heap_name,
4390 sampler_index_buffer_name: self
4391 .wrapped
4392 .sampler_index_buffers
4393 .get(&super::SamplerIndexBufferKey {
4394 group: variable.binding.unwrap().group,
4395 })
4396 .unwrap()
4397 .clone(),
4398 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4399 .clone(),
4400 });
4401 }
4402 }
4403 }
4404
4405 None
4406 }
4407
4408 fn write_named_expr(
4409 &mut self,
4410 module: &Module,
4411 handle: Handle<crate::Expression>,
4412 name: String,
4413 named: Handle<crate::Expression>,
4416 ctx: &back::FunctionCtx,
4417 ) -> BackendResult {
4418 match ctx.info[named].ty {
4419 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4420 TypeInner::Struct { .. } => {
4421 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4422 write!(self.out, "{ty_name}")?;
4423 }
4424 _ => {
4425 self.write_type(module, ty_handle)?;
4426 }
4427 },
4428 proc::TypeResolution::Value(ref inner) => {
4429 self.write_value_type(module, inner)?;
4430 }
4431 }
4432
4433 let resolved = ctx.resolve_type(named, &module.types);
4434
4435 write!(self.out, " {name}")?;
4436 if let TypeInner::Array { base, size, .. } = *resolved {
4438 self.write_array_size(module, base, size)?;
4439 }
4440 write!(self.out, " = ")?;
4441 self.write_expr(module, handle, ctx)?;
4442 writeln!(self.out, ";")?;
4443 self.named_expressions.insert(named, name);
4444
4445 Ok(())
4446 }
4447
4448 pub(super) fn write_default_init(
4450 &mut self,
4451 module: &Module,
4452 ty: Handle<crate::Type>,
4453 ) -> BackendResult {
4454 write!(self.out, "(")?;
4455 self.write_type(module, ty)?;
4456 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4457 self.write_array_size(module, base, size)?;
4458 }
4459 write!(self.out, ")0")?;
4460 Ok(())
4461 }
4462
4463 fn write_control_barrier(
4464 &mut self,
4465 barrier: crate::Barrier,
4466 level: back::Level,
4467 ) -> BackendResult {
4468 if barrier.contains(crate::Barrier::STORAGE) {
4469 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4470 }
4471 if barrier.contains(crate::Barrier::WORK_GROUP) {
4472 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4473 }
4474 if barrier.contains(crate::Barrier::SUB_GROUP) {
4475 }
4477 if barrier.contains(crate::Barrier::TEXTURE) {
4478 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4479 }
4480 Ok(())
4481 }
4482
4483 fn write_memory_barrier(
4484 &mut self,
4485 barrier: crate::Barrier,
4486 level: back::Level,
4487 ) -> BackendResult {
4488 if barrier.contains(crate::Barrier::STORAGE) {
4489 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4490 }
4491 if barrier.contains(crate::Barrier::WORK_GROUP) {
4492 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4493 }
4494 if barrier.contains(crate::Barrier::SUB_GROUP) {
4495 }
4497 if barrier.contains(crate::Barrier::TEXTURE) {
4498 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4499 }
4500 Ok(())
4501 }
4502
4503 fn emit_hlsl_atomic_tail(
4505 &mut self,
4506 module: &Module,
4507 func_ctx: &back::FunctionCtx<'_>,
4508 fun: &crate::AtomicFunction,
4509 compare_expr: Option<Handle<crate::Expression>>,
4510 value: Handle<crate::Expression>,
4511 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4512 ) -> BackendResult {
4513 if let Some(cmp) = compare_expr {
4514 write!(self.out, ", ")?;
4515 self.write_expr(module, cmp, func_ctx)?;
4516 }
4517 write!(self.out, ", ")?;
4518 if let crate::AtomicFunction::Subtract = *fun {
4519 write!(self.out, "-")?;
4521 }
4522 self.write_expr(module, value, func_ctx)?;
4523 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4524 write!(self.out, ", ")?;
4525 if compare_expr.is_some() {
4526 write!(self.out, "{res_name}.old_value")?;
4527 } else {
4528 write!(self.out, "{res_name}")?;
4529 }
4530 }
4531 writeln!(self.out, ");")?;
4532 Ok(())
4533 }
4534}
4535
4536pub(super) struct MatrixType {
4537 pub(super) columns: crate::VectorSize,
4538 pub(super) rows: crate::VectorSize,
4539 pub(super) width: crate::Bytes,
4540}
4541
4542pub(super) fn get_inner_matrix_data(
4543 module: &Module,
4544 handle: Handle<crate::Type>,
4545) -> Option<MatrixType> {
4546 match module.types[handle].inner {
4547 TypeInner::Matrix {
4548 columns,
4549 rows,
4550 scalar,
4551 } => Some(MatrixType {
4552 columns,
4553 rows,
4554 width: scalar.width,
4555 }),
4556 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4557 _ => None,
4558 }
4559}
4560
4561fn find_matrix_in_access_chain(
4565 module: &Module,
4566 base: Handle<crate::Expression>,
4567 func_ctx: &back::FunctionCtx<'_>,
4568) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4569 let mut current_base = base;
4570 let mut vector = None;
4571 let mut scalar = None;
4572 loop {
4573 let resolved_tr = func_ctx
4574 .resolve_type(current_base, &module.types)
4575 .pointer_base_type();
4576 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4577
4578 match *resolved {
4579 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4580 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4581 _ => return None,
4582 }
4583
4584 let index;
4585 (current_base, index) = match func_ctx.expressions[current_base] {
4586 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4587 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4588 _ => return None,
4589 };
4590
4591 match *resolved {
4592 TypeInner::Scalar(_) => scalar = Some(index),
4593 TypeInner::Vector { .. } => vector = Some(index),
4594 _ => unreachable!(),
4595 }
4596 }
4597}
4598
4599pub(super) fn get_inner_matrix_of_struct_array_member(
4604 module: &Module,
4605 base: Handle<crate::Expression>,
4606 func_ctx: &back::FunctionCtx<'_>,
4607 direct: bool,
4608) -> Option<MatrixType> {
4609 let mut mat_data = None;
4610 let mut array_base = None;
4611
4612 let mut current_base = base;
4613 loop {
4614 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4615 if let TypeInner::Pointer { base, .. } = *resolved {
4616 resolved = &module.types[base].inner;
4617 };
4618
4619 match *resolved {
4620 TypeInner::Matrix {
4621 columns,
4622 rows,
4623 scalar,
4624 } => {
4625 mat_data = Some(MatrixType {
4626 columns,
4627 rows,
4628 width: scalar.width,
4629 })
4630 }
4631 TypeInner::Array { base, .. } => {
4632 array_base = Some(base);
4633 }
4634 TypeInner::Struct { .. } => {
4635 if let Some(array_base) = array_base {
4636 if direct {
4637 return mat_data;
4638 } else {
4639 return get_inner_matrix_data(module, array_base);
4640 }
4641 }
4642
4643 break;
4644 }
4645 _ => break,
4646 }
4647
4648 current_base = match func_ctx.expressions[current_base] {
4649 crate::Expression::Access { base, .. } => base,
4650 crate::Expression::AccessIndex { base, .. } => base,
4651 _ => break,
4652 };
4653 }
4654 None
4655}
4656
4657fn get_global_uniform_matrix(
4660 module: &Module,
4661 base: Handle<crate::Expression>,
4662 func_ctx: &back::FunctionCtx<'_>,
4663) -> Option<MatrixType> {
4664 let base_tr = func_ctx
4665 .resolve_type(base, &module.types)
4666 .pointer_base_type();
4667 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4668 match (&func_ctx.expressions[base], base_ty) {
4669 (
4670 &crate::Expression::GlobalVariable(handle),
4671 Some(&TypeInner::Matrix {
4672 columns,
4673 rows,
4674 scalar,
4675 }),
4676 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4677 Some(MatrixType {
4678 columns,
4679 rows,
4680 width: scalar.width,
4681 })
4682 }
4683 _ => None,
4684 }
4685}
4686
4687fn get_inner_matrix_of_global_uniform(
4692 module: &Module,
4693 base: Handle<crate::Expression>,
4694 func_ctx: &back::FunctionCtx<'_>,
4695) -> Option<MatrixType> {
4696 let mut mat_data = None;
4697 let mut array_base = None;
4698
4699 let mut current_base = base;
4700 loop {
4701 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4702 if let TypeInner::Pointer { base, .. } = *resolved {
4703 resolved = &module.types[base].inner;
4704 };
4705
4706 match *resolved {
4707 TypeInner::Matrix {
4708 columns,
4709 rows,
4710 scalar,
4711 } => {
4712 mat_data = Some(MatrixType {
4713 columns,
4714 rows,
4715 width: scalar.width,
4716 })
4717 }
4718 TypeInner::Array { base, .. } => {
4719 array_base = Some(base);
4720 }
4721 _ => break,
4722 }
4723
4724 current_base = match func_ctx.expressions[current_base] {
4725 crate::Expression::Access { base, .. } => base,
4726 crate::Expression::AccessIndex { base, .. } => base,
4727 crate::Expression::GlobalVariable(handle)
4728 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4729 {
4730 return mat_data.or_else(|| {
4731 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4732 })
4733 }
4734 _ => break,
4735 };
4736 }
4737 None
4738}