1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9 help,
10 help::{
11 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12 WrappedZeroValue,
13 },
14 storage::StoreValue,
15 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18 back::{self, get_entry_points, Baked},
19 common,
20 proc::{self, index, ExternalTextureNameKey, NameKey},
21 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined `@location(n)` interface members that are
/// not fragment outputs; written as `LOC{n}`.
const LOCATION_SEMANTIC: &str = "LOC";
/// HLSL struct type of the special-constants cbuffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special-constants cbuffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` member of the special-constants cbuffer named `first_vertex`.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` member of the special-constants cbuffer named `first_instance`.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` member of the special-constants cbuffer named `other`.
/// NOTE(review): its meaning is not visible in this part of the file — confirm.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and variables generated by the backend.
// NOTE(review): most are only referenced outside this chunk; the
// descriptions below follow the names — confirm against the helper writers.

/// Generated helper implementing `modf`.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
/// Generated helper implementing `frexp`.
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
/// Generated helper implementing `extractBits`.
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
/// Generated helper implementing `insertBits`.
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Sampler heap indexed by non-comparison sampler globals.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Sampler heap indexed by comparison sampler globals.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
/// Generated helper for sampling an external texture.
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
/// Generated helper implementing `abs`.
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
/// Generated helper implementing division.
pub(crate) const DIV_FUNCTION: &str = "naga_div";
/// Generated helper implementing the remainder operator.
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
/// Generated helper implementing negation.
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
/// Generated float-to-`i32` conversion helper.
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
/// Generated float-to-`u32` conversion helper.
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
/// Generated float-to-`i64` conversion helper.
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
/// Generated float-to-`u64` conversion helper.
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Generated helper for `textureSampleBaseClampToEdge`.
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generated helper for loading from an external texture.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Name prefix for per-ray-query initialization tracking variables.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Common prefix of backend-internal identifiers.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
52
/// An index that is either a runtime IR expression or a compile-time constant.
/// NOTE(review): no use is visible in this chunk; presumably consumed by the
/// access/indexing helpers elsewhere in the backend — confirm.
enum Index {
    /// Index computed by an IR expression at runtime.
    Expression(Handle<crate::Expression>),
    /// Index known at translation time.
    Static(u32),
}
57
/// One member of a synthesized entry-point interface struct.
///
/// Built by `write_ep_input_struct` / `write_ep_output_struct` and consumed by
/// `write_interface_struct` and the entry-point argument initialization code.
struct EpStructMember {
    /// Member name in the generated HLSL struct (a fresh name from the namer).
    name: String,
    /// IR type of the member.
    ty: Handle<crate::Type>,
    /// Binding that determines the member's HLSL semantic; expected to be
    /// `Some` (checked by a `debug_assert!` when the struct is written).
    binding: Option<crate::Binding>,
    /// Original position of the member: for inputs, its index in the flattened
    /// list of entry-point arguments; for outputs, its index in the result
    /// struct. Input members are re-sorted by this after the struct is written
    /// so they line up with the arguments again.
    index: u32,
}
66
/// A generated interface struct for one direction (input or output) of an
/// entry point.
struct EntryPointBinding {
    /// Name of the variable holding the struct, allocated from the namer
    /// using the lowercased struct name.
    arg_name: String,
    /// Name of the generated HLSL struct type.
    ty_name: String,
    /// Struct members. Inputs are kept in original argument order; outputs in
    /// the order they were declared (`InterfaceKey` order).
    members: Vec<EpStructMember>,
}
78
/// Interface structs generated for a single entry point by `write_ep_interface`.
pub(super) struct EntryPointInterface {
    /// Input struct, generated for fragment shaders and for entry points taking
    /// subgroup built-ins (and only when there are arguments); `None` otherwise.
    /// It is `take`n when the entry point's arguments are initialized.
    input: Option<EntryPointBinding>,
    /// Output struct, generated for vertex shaders whose result has no binding
    /// of its own (i.e. a struct of bound members); `None` otherwise.
    output: Option<EntryPointBinding>,
}
90
91#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
92enum InterfaceKey {
93 Location(u32),
94 BuiltIn(crate::BuiltIn),
95 Other,
96}
97
98impl InterfaceKey {
99 const fn new(binding: Option<&crate::Binding>) -> Self {
100 match binding {
101 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
102 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
103 None => Self::Other,
104 }
105 }
106}
107
/// Direction of an entry-point interface struct: shader input or shader output.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}
113
114const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
115 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
116 return false;
117 };
118 matches!(
119 builtin,
120 crate::BuiltIn::SubgroupSize
121 | crate::BuiltIn::SubgroupInvocationId
122 | crate::BuiltIn::NumSubgroups
123 | crate::BuiltIn::SubgroupId
124 )
125}
126
/// Names needed to index a sampler heap through a binding array of samplers.
/// NOTE(review): not constructed in the visible part of this file; field
/// descriptions are inferred from their names and from `write_global_sampler`
/// — confirm.
struct BindingArraySamplerInfo {
    /// Name of the sampler heap variable; presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`, given the `&'static str` type.
    sampler_heap_name: &'static str,
    /// Name of the bind group's sampler index buffer.
    sampler_index_buffer_name: String,
    /// Name of the `static const uint` holding the binding array's base
    /// register (as written by `write_global_sampler`).
    binding_array_base_index_name: String,
}
136
137impl<'a, W: fmt::Write> super::Writer<'a, W> {
138 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
139 Self {
140 out,
141 names: crate::FastHashMap::default(),
142 namer: proc::Namer::default(),
143 options,
144 pipeline_options,
145 entry_point_io: crate::FastHashMap::default(),
146 named_expressions: crate::NamedExpressions::default(),
147 wrapped: super::Wrapped::default(),
148 written_committed_intersection: false,
149 written_candidate_intersection: false,
150 continue_ctx: back::continue_forward::ContinueCtx::default(),
151 temp_access_chain: Vec::new(),
152 need_bake_expressions: Default::default(),
153 }
154 }
155
156 fn reset(&mut self, module: &Module) {
157 self.names.clear();
158 self.namer.reset(
159 module,
160 &super::keywords::RESERVED_SET,
161 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
162 super::keywords::RESERVED_PREFIXES,
163 &mut self.names,
164 );
165 self.entry_point_io.clear();
166 self.named_expressions.clear();
167 self.wrapped.clear();
168 self.written_committed_intersection = false;
169 self.written_candidate_intersection = false;
170 self.continue_ctx.clear();
171 self.need_bake_expressions.clear();
172 }
173
174 fn gen_force_bounded_loop_statements(
182 &mut self,
183 level: back::Level,
184 ) -> Option<(String, String)> {
185 if !self.options.force_loop_bounding {
186 return None;
187 }
188
189 let loop_bound_name = self.namer.call("loop_bound");
190 let max = u32::MAX;
191 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
194 let level = level.next();
195 let break_and_inc = format!(
196 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
197{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
198 );
199
200 Some((decl, break_and_inc))
201 }
202
203 fn update_expressions_to_bake(
208 &mut self,
209 module: &Module,
210 func: &crate::Function,
211 info: &valid::FunctionInfo,
212 ) {
213 use crate::Expression;
214 self.need_bake_expressions.clear();
215 for (exp_handle, expr) in func.expressions.iter() {
216 let expr_info = &info[exp_handle];
217 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
218 if min_ref_count <= expr_info.ref_count {
219 self.need_bake_expressions.insert(exp_handle);
220 }
221
222 if let Expression::Math { fun, arg, arg1, .. } = *expr {
223 match fun {
224 crate::MathFunction::Asinh
225 | crate::MathFunction::Acosh
226 | crate::MathFunction::Atanh
227 | crate::MathFunction::Unpack2x16float
228 | crate::MathFunction::Unpack2x16snorm
229 | crate::MathFunction::Unpack2x16unorm
230 | crate::MathFunction::Unpack4x8snorm
231 | crate::MathFunction::Unpack4x8unorm
232 | crate::MathFunction::Unpack4xI8
233 | crate::MathFunction::Unpack4xU8
234 | crate::MathFunction::Pack2x16float
235 | crate::MathFunction::Pack2x16snorm
236 | crate::MathFunction::Pack2x16unorm
237 | crate::MathFunction::Pack4x8snorm
238 | crate::MathFunction::Pack4x8unorm
239 | crate::MathFunction::Pack4xI8
240 | crate::MathFunction::Pack4xU8
241 | crate::MathFunction::Pack4xI8Clamp
242 | crate::MathFunction::Pack4xU8Clamp => {
243 self.need_bake_expressions.insert(arg);
244 }
245 crate::MathFunction::CountLeadingZeros => {
246 let inner = info[exp_handle].ty.inner_with(&module.types);
247 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
248 self.need_bake_expressions.insert(arg);
249 }
250 }
251 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
252 self.need_bake_expressions.insert(arg);
253 self.need_bake_expressions.insert(arg1.unwrap());
254 }
255 _ => {}
256 }
257 }
258
259 if let Expression::Derivative { axis, ctrl, expr } = *expr {
260 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
261 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
262 self.need_bake_expressions.insert(expr);
263 }
264 }
265
266 if let Expression::GlobalVariable(_) = *expr {
267 let inner = info[exp_handle].ty.inner_with(&module.types);
268
269 if let TypeInner::Sampler { .. } = *inner {
270 self.need_bake_expressions.insert(exp_handle);
271 }
272 }
273 }
274 for statement in func.body.iter() {
275 match *statement {
276 crate::Statement::SubgroupCollectiveOperation {
277 op: _,
278 collective_op: crate::CollectiveOperation::InclusiveScan,
279 argument,
280 result: _,
281 } => {
282 self.need_bake_expressions.insert(argument);
283 }
284 crate::Statement::Atomic {
285 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
286 ..
287 } => {
288 self.need_bake_expressions.insert(cmp);
289 }
290 _ => {}
291 }
292 }
293 }
294
/// Writes the complete HLSL translation of `module` to `self.out`.
///
/// Output is emitted in this order:
/// 1. the special-constants cbuffer (if configured);
/// 2. the dynamic storage-buffer offset cbuffers;
/// 3. matCx2 helpers, then struct declarations;
/// 4. special and wrapped helper functions;
/// 5. named constants, then global variables;
/// 6. entry-point interface structs;
/// 7. regular functions, then the selected entry points.
///
/// Functions and entry points that use a global whose binding cannot be
/// resolved are skipped, unless `Options::fake_missing_bindings` is set. A
/// skipped entry point is reported as an `Err` in the returned reflection
/// info; each written entry point is reported as `Ok(name)`.
///
/// # Errors
/// Propagates formatting and backend errors from the helper writers. Returns
/// `Error::EntryPointNotFound` if `pipeline_options.entry_point` names an
/// entry point the module does not have.
pub fn write(
    &mut self,
    module: &Module,
    module_info: &valid::ModuleInfo,
    fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
) -> Result<super::ReflectionInfo, Error> {
    self.reset(module);

    // Special constants cbuffer (first_vertex, first_instance, other), only
    // when the options give it a bind target.
    if let Some(ref bt) = self.options.special_constants_binding {
        writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
        writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
        writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
        writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
        writeln!(self.out, "}};")?;
        write!(
            self.out,
            "ConstantBuffer<{}> {}: register(b{}",
            SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
        )?;
        // Space 0 is the default, so it is omitted.
        if bt.space != 0 {
            write!(self.out, ", space{}", bt.space)?;
        }
        writeln!(self.out, ");")?;

        writeln!(self.out)?;
    }

    // One cbuffer of `uint` offsets (`_0`, `_1`, ...) per configured bind group.
    for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
        writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
        for i in 0..bt.size {
            writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
        }
        writeln!(self.out, "}};")?;
        writeln!(
            self.out,
            "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
            group, group, bt.register, bt.space
        )?;

        writeln!(self.out)?;
    }

    // Stage and result of every entry point. A struct type used as an
    // entry-point result is written with that stage's output context.
    let ep_results = module
        .entry_points
        .iter()
        .map(|ep| (ep.stage, ep.function.result.clone()))
        .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

    self.write_all_mat_cx2_typedefs_and_functions(module)?;

    // Struct declarations.
    for (handle, ty) in module.types.iter() {
        if let TypeInner::Struct { ref members, span } = ty.inner {
            // Structs whose last member is dynamically sized are not declared.
            // NOTE(review): presumably because they can only back storage
            // buffers, which are emitted as `ByteAddressBuffer` — confirm.
            // `unwrap` assumes validation rejects empty structs.
            if module.types[members.last().unwrap().ty]
                .inner
                .is_dynamically_sized(&module.types)
            {
                continue;
            }

            let ep_result = ep_results.iter().find(|e| {
                if let Some(ref result) = e.1 {
                    result.ty == handle
                } else {
                    false
                }
            });

            self.write_struct(
                module,
                handle,
                members,
                span,
                ep_result.map(|r| (r.0, Io::Output)),
            )?;
            writeln!(self.out)?;
        }
    }

    self.write_special_functions(module)?;

    // Helpers needed by module-scope (constant) expressions.
    self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
    self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

    // Only named constants are declared; a blank line follows the last one.
    let mut constants = module
        .constants
        .iter()
        .filter(|&(_, c)| c.name.is_some())
        .peekable();
    while let Some((handle, _)) = constants.next() {
        self.write_global_constant(module, handle)?;
        if constants.peek().is_none() {
            writeln!(self.out)?;
        }
    }

    for (global, _) in module.global_variables.iter() {
        self.write_global(module, global)?;
    }

    if !module.global_variables.is_empty() {
        writeln!(self.out)?;
    }

    // Either all entry points, or just the one selected by the pipeline options.
    let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
        .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

    // Interface structs are written up front and recorded in `entry_point_io`.
    for index in ep_range.clone() {
        let ep = &module.entry_points[index];
        let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
        let ep_io = self.write_ep_interface(
            module,
            &ep.function,
            ep.stage,
            &ep_name,
            fragment_entry_point,
        )?;
        self.entry_point_io.insert(index, ep_io);
    }

    for (handle, function) in module.functions.iter() {
        let info = &module_info[handle];

        // Skip functions that use a global whose binding resolves neither as a
        // regular resource nor as an external texture.
        if !self.options.fake_missing_bindings {
            if let Some((var_handle, _)) =
                module
                    .global_variables
                    .iter()
                    .find(|&(var_handle, var)| match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            self.options.resolve_resource_binding(binding).is_err()
                                && self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                        }
                        _ => false,
                    })
            {
                log::debug!(
                    "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                    handle,
                    function.name,
                    var_handle
                );
                continue;
            }
        }

        let ctx = back::FunctionCtx {
            ty: back::FunctionType::Function(handle),
            info,
            expressions: &function.expressions,
            named_expressions: &function.named_expressions,
        };
        let name = self.names[&NameKey::Function(handle)].clone();

        self.write_wrapped_functions(module, &ctx)?;

        self.write_function(module, name.as_str(), function, &ctx, info)?;

        writeln!(self.out)?;
    }

    let mut translated_ep_names = Vec::with_capacity(ep_range.len());

    for index in ep_range {
        let ep = &module.entry_points[index];
        let info = module_info.get_entry_point(index);

        // Entry points using an unresolvable global are skipped. The
        // resolution error is reported in the reflection info instead.
        if !self.options.fake_missing_bindings {
            let mut ep_error = None;
            for (var_handle, var) in module.global_variables.iter() {
                match var.binding {
                    Some(ref binding) if !info[var_handle].is_empty() => {
                        if let Err(err) = self.options.resolve_resource_binding(binding) {
                            if self
                                .options
                                .resolve_external_texture_resource_binding(binding)
                                .is_err()
                            {
                                ep_error = Some(err);
                                break;
                            }
                        }
                    }
                    _ => {}
                }
            }
            if let Some(err) = ep_error {
                translated_ep_names.push(Err(err));
                continue;
            }
        }

        let ctx = back::FunctionCtx {
            ty: back::FunctionType::EntryPoint(index as u16),
            info,
            expressions: &ep.function.expressions,
            named_expressions: &ep.function.named_expressions,
        };

        self.write_wrapped_functions(module, &ctx)?;

        if ep.stage.compute_like() {
            let num_threads = ep.workgroup_size;
            writeln!(
                self.out,
                "[numthreads({}, {}, {})]",
                num_threads[0], num_threads[1], num_threads[2]
            )?;
        }

        let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
        self.write_function(module, &name, &ep.function, &ctx, info)?;

        // Blank line between entry points, but not after the module's last one.
        if index < module.entry_points.len() - 1 {
            writeln!(self.out)?;
        }

        translated_ep_names.push(Ok(name));
    }

    Ok(super::ReflectionInfo {
        entry_point_names: translated_ep_names,
    })
}
537
538 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
539 match *binding {
540 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
541 write!(self.out, "precise ")?;
542 }
543 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
544 write!(self.out, "noperspective ")?;
545 }
546 crate::Binding::Location {
547 interpolation,
548 sampling,
549 ..
550 } => {
551 if let Some(interpolation) = interpolation {
552 if let Some(string) = interpolation.to_hlsl_str() {
553 write!(self.out, "{string} ")?
554 }
555 }
556
557 if let Some(sampling) = sampling {
558 if let Some(string) = sampling.to_hlsl_str() {
559 write!(self.out, "{string} ")?
560 }
561 }
562 }
563 crate::Binding::BuiltIn(_) => {}
564 }
565
566 Ok(())
567 }
568
569 fn write_semantic(
572 &mut self,
573 binding: &Option<crate::Binding>,
574 stage: Option<(ShaderStage, Io)>,
575 ) -> BackendResult {
576 match *binding {
577 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
578 if builtin == crate::BuiltIn::ViewIndex
579 && self.options.shader_model < ShaderModel::V6_1
580 {
581 return Err(Error::ShaderModelTooLow(
582 "used @builtin(view_index) or SV_ViewID".to_string(),
583 ShaderModel::V6_1,
584 ));
585 }
586 let builtin_str = builtin.to_hlsl_str()?;
587 write!(self.out, " : {builtin_str}")?;
588 }
589 Some(crate::Binding::Location {
590 blend_src: Some(1), ..
591 }) => {
592 write!(self.out, " : SV_Target1")?;
593 }
594 Some(crate::Binding::Location { location, .. }) => {
595 if stage == Some((ShaderStage::Fragment, Io::Output)) {
596 write!(self.out, " : SV_Target{location}")?;
597 } else {
598 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
599 }
600 }
601 _ => {}
602 }
603
604 Ok(())
605 }
606
607 fn write_interface_struct(
608 &mut self,
609 module: &Module,
610 shader_stage: (ShaderStage, Io),
611 struct_name: String,
612 mut members: Vec<EpStructMember>,
613 ) -> Result<EntryPointBinding, Error> {
614 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
618
619 write!(self.out, "struct {struct_name}")?;
620 writeln!(self.out, " {{")?;
621 for m in members.iter() {
622 debug_assert!(m.binding.is_some());
625
626 if is_subgroup_builtin_binding(&m.binding) {
627 continue;
628 }
629 write!(self.out, "{}", back::INDENT)?;
630 if let Some(ref binding) = m.binding {
631 self.write_modifier(binding)?;
632 }
633 self.write_type(module, m.ty)?;
634 write!(self.out, " {}", &m.name)?;
635 self.write_semantic(&m.binding, Some(shader_stage))?;
636 writeln!(self.out, ";")?;
637 }
638 if members.iter().any(|arg| {
639 matches!(
640 arg.binding,
641 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
642 )
643 }) {
644 writeln!(
645 self.out,
646 "{}uint __local_invocation_index : SV_GroupIndex;",
647 back::INDENT
648 )?;
649 }
650 writeln!(self.out, "}};")?;
651 writeln!(self.out)?;
652
653 match shader_stage.1 {
655 Io::Input => {
656 members.sort_by_key(|m| m.index);
658 }
659 Io::Output => {
660 }
662 }
663
664 Ok(EntryPointBinding {
665 arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
666 ty_name: struct_name,
667 members,
668 })
669 }
670
671 fn write_ep_input_struct(
675 &mut self,
676 module: &Module,
677 func: &crate::Function,
678 stage: ShaderStage,
679 entry_point_name: &str,
680 ) -> Result<EntryPointBinding, Error> {
681 let struct_name = format!("{stage:?}Input_{entry_point_name}");
682
683 let mut fake_members = Vec::new();
684 for arg in func.arguments.iter() {
685 match module.types[arg.ty].inner {
690 TypeInner::Struct { ref members, .. } => {
691 for member in members.iter() {
692 let name = self.namer.call_or(&member.name, "member");
693 let index = fake_members.len() as u32;
694 fake_members.push(EpStructMember {
695 name,
696 ty: member.ty,
697 binding: member.binding.clone(),
698 index,
699 });
700 }
701 }
702 _ => {
703 let member_name = self.namer.call_or(&arg.name, "member");
704 let index = fake_members.len() as u32;
705 fake_members.push(EpStructMember {
706 name: member_name,
707 ty: arg.ty,
708 binding: arg.binding.clone(),
709 index,
710 });
711 }
712 }
713 }
714
715 self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
716 }
717
718 fn write_ep_output_struct(
722 &mut self,
723 module: &Module,
724 result: &crate::FunctionResult,
725 stage: ShaderStage,
726 entry_point_name: &str,
727 frag_ep: Option<&FragmentEntryPoint<'_>>,
728 ) -> Result<EntryPointBinding, Error> {
729 let struct_name = format!("{stage:?}Output_{entry_point_name}");
730
731 let empty = [];
732 let members = match module.types[result.ty].inner {
733 TypeInner::Struct { ref members, .. } => members,
734 ref other => {
735 log::error!("Unexpected {other:?} output type without a binding");
736 &empty[..]
737 }
738 };
739
740 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
745 let mut fs_input_locs = Vec::new();
746 for arg in frag_ep.func.arguments.iter() {
747 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
748 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
749 Some(crate::Binding::BuiltIn(_)) | None => {}
750 };
751
752 match frag_ep.module.types[arg.ty].inner {
755 TypeInner::Struct { ref members, .. } => {
756 for member in members.iter() {
757 push_if_location(&member.binding);
758 }
759 }
760 _ => push_if_location(&arg.binding),
761 }
762 }
763 fs_input_locs.sort();
764 Some(fs_input_locs)
765 } else {
766 None
767 };
768
769 let mut fake_members = Vec::new();
770 for (index, member) in members.iter().enumerate() {
771 if let Some(ref fs_input_locs) = fs_input_locs {
772 match member.binding {
773 Some(crate::Binding::Location { location, .. }) => {
774 if fs_input_locs.binary_search(&location).is_err() {
775 continue;
776 }
777 }
778 Some(crate::Binding::BuiltIn(_)) | None => {}
779 }
780 }
781
782 let member_name = self.namer.call_or(&member.name, "member");
783 fake_members.push(EpStructMember {
784 name: member_name,
785 ty: member.ty,
786 binding: member.binding.clone(),
787 index: index as u32,
788 });
789 }
790
791 self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
792 }
793
794 fn write_ep_interface(
798 &mut self,
799 module: &Module,
800 func: &crate::Function,
801 stage: ShaderStage,
802 ep_name: &str,
803 frag_ep: Option<&FragmentEntryPoint<'_>>,
804 ) -> Result<EntryPointInterface, Error> {
805 Ok(EntryPointInterface {
806 input: if !func.arguments.is_empty()
807 && (stage == ShaderStage::Fragment
808 || func
809 .arguments
810 .iter()
811 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
812 {
813 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
814 } else {
815 None
816 },
817 output: match func.result {
818 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
819 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
820 }
821 _ => None,
822 },
823 })
824 }
825
826 fn write_ep_argument_initialization(
827 &mut self,
828 ep: &crate::EntryPoint,
829 ep_input: &EntryPointBinding,
830 fake_member: &EpStructMember,
831 ) -> BackendResult {
832 match fake_member.binding {
833 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
834 write!(self.out, "WaveGetLaneCount()")?
835 }
836 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
837 write!(self.out, "WaveGetLaneIndex()")?
838 }
839 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
840 self.out,
841 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
842 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
843 )?,
844 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
845 write!(
846 self.out,
847 "{}.__local_invocation_index / WaveGetLaneCount()",
848 ep_input.arg_name
849 )?;
850 }
851 _ => {
852 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
853 }
854 }
855 Ok(())
856 }
857
858 fn write_ep_arguments_initialization(
860 &mut self,
861 module: &Module,
862 func: &crate::Function,
863 ep_index: u16,
864 ) -> BackendResult {
865 let ep = &module.entry_points[ep_index as usize];
866 let ep_input = match self
867 .entry_point_io
868 .get_mut(&(ep_index as usize))
869 .unwrap()
870 .input
871 .take()
872 {
873 Some(ep_input) => ep_input,
874 None => return Ok(()),
875 };
876 let mut fake_iter = ep_input.members.iter();
877 for (arg_index, arg) in func.arguments.iter().enumerate() {
878 write!(self.out, "{}", back::INDENT)?;
879 self.write_type(module, arg.ty)?;
880 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
881 write!(self.out, " {arg_name}")?;
882 match module.types[arg.ty].inner {
883 TypeInner::Array { base, size, .. } => {
884 self.write_array_size(module, base, size)?;
885 write!(self.out, " = ")?;
886 self.write_ep_argument_initialization(
887 ep,
888 &ep_input,
889 fake_iter.next().unwrap(),
890 )?;
891 writeln!(self.out, ";")?;
892 }
893 TypeInner::Struct { ref members, .. } => {
894 write!(self.out, " = {{ ")?;
895 for index in 0..members.len() {
896 if index != 0 {
897 write!(self.out, ", ")?;
898 }
899 self.write_ep_argument_initialization(
900 ep,
901 &ep_input,
902 fake_iter.next().unwrap(),
903 )?;
904 }
905 writeln!(self.out, " }};")?;
906 }
907 _ => {
908 write!(self.out, " = ")?;
909 self.write_ep_argument_initialization(
910 ep,
911 &ep_input,
912 fake_iter.next().unwrap(),
913 )?;
914 writeln!(self.out, ";")?;
915 }
916 }
917 }
918 assert!(fake_iter.next().is_none());
919 Ok(())
920 }
921
922 fn write_global(
926 &mut self,
927 module: &Module,
928 handle: Handle<crate::GlobalVariable>,
929 ) -> BackendResult {
930 let global = &module.global_variables[handle];
931 let inner = &module.types[global.ty].inner;
932
933 let handle_ty = match *inner {
934 TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
935 _ => inner,
936 };
937
938 let is_external_texture = matches!(
942 *handle_ty,
943 TypeInner::Image {
944 class: crate::ImageClass::External,
945 ..
946 }
947 );
948 if is_external_texture {
949 return self.write_global_external_texture(module, handle, global);
950 }
951
952 if let Some(ref binding) = global.binding {
953 if let Err(err) = self.options.resolve_resource_binding(binding) {
954 log::debug!(
955 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
956 handle,
957 global.name,
958 err,
959 );
960 return Ok(());
961 }
962 }
963
964 let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
966
967 if is_sampler {
968 return self.write_global_sampler(module, handle, global);
969 }
970
971 let register_ty = match global.space {
973 crate::AddressSpace::Function => unreachable!("Function address space"),
974 crate::AddressSpace::Private => {
975 write!(self.out, "static ")?;
976 self.write_type(module, global.ty)?;
977 ""
978 }
979 crate::AddressSpace::WorkGroup => {
980 write!(self.out, "groupshared ")?;
981 self.write_type(module, global.ty)?;
982 ""
983 }
984 crate::AddressSpace::TaskPayload => unimplemented!(),
985 crate::AddressSpace::Uniform => {
986 write!(self.out, "cbuffer")?;
989 "b"
990 }
991 crate::AddressSpace::Storage { access } => {
992 let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
993 ("RW", "u")
994 } else {
995 ("", "t")
996 };
997 write!(self.out, "{prefix}ByteAddressBuffer")?;
998 register
999 }
1000 crate::AddressSpace::Handle => {
1001 let register = match *handle_ty {
1002 TypeInner::Image {
1004 class: crate::ImageClass::Storage { .. },
1005 ..
1006 } => "u",
1007 _ => "t",
1008 };
1009 self.write_type(module, global.ty)?;
1010 register
1011 }
1012 crate::AddressSpace::Immediate => {
1013 write!(self.out, "ConstantBuffer<")?;
1015 "b"
1016 }
1017 };
1018
1019 if global.space == crate::AddressSpace::Immediate {
1022 self.write_global_type(module, global.ty)?;
1023
1024 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1026 self.write_array_size(module, base, size)?;
1027 }
1028
1029 write!(self.out, ">")?;
1031 }
1032
1033 let name = &self.names[&NameKey::GlobalVariable(handle)];
1034 write!(self.out, " {name}")?;
1035
1036 if global.space == crate::AddressSpace::Immediate {
1039 match module.types[global.ty].inner {
1040 TypeInner::Struct { .. } => {}
1041 _ => {
1042 return Err(Error::Unimplemented(format!(
1043 "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1044 )));
1045 }
1046 }
1047
1048 let target = self
1049 .options
1050 .immediates_target
1051 .as_ref()
1052 .expect("No bind target was defined for the immediates block");
1053 write!(self.out, ": register(b{}", target.register)?;
1054 if target.space != 0 {
1055 write!(self.out, ", space{}", target.space)?;
1056 }
1057 write!(self.out, ")")?;
1058 }
1059
1060 if let Some(ref binding) = global.binding {
1061 let bt = self.options.resolve_resource_binding(binding).unwrap();
1063
1064 if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1066 if let Some(overridden_size) = bt.binding_array_size {
1067 write!(self.out, "[{overridden_size}]")?;
1068 } else {
1069 self.write_array_size(module, base, size)?;
1070 }
1071 }
1072
1073 write!(self.out, " : register({}{}", register_ty, bt.register)?;
1074 if bt.space != 0 {
1075 write!(self.out, ", space{}", bt.space)?;
1076 }
1077 write!(self.out, ")")?;
1078 } else {
1079 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1081 self.write_array_size(module, base, size)?;
1082 }
1083 if global.space == crate::AddressSpace::Private {
1084 write!(self.out, " = ")?;
1085 if let Some(init) = global.init {
1086 self.write_const_expression(module, init, &module.global_expressions)?;
1087 } else {
1088 self.write_default_init(module, global.ty)?;
1089 }
1090 }
1091 }
1092
1093 if global.space == crate::AddressSpace::Uniform {
1094 write!(self.out, " {{ ")?;
1095
1096 self.write_global_type(module, global.ty)?;
1097
1098 write!(
1099 self.out,
1100 " {}",
1101 &self.names[&NameKey::GlobalVariable(handle)]
1102 )?;
1103
1104 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1106 self.write_array_size(module, base, size)?;
1107 }
1108
1109 writeln!(self.out, "; }}")?;
1110 } else {
1111 writeln!(self.out, ";")?;
1112 }
1113
1114 Ok(())
1115 }
1116
1117 fn write_global_sampler(
1118 &mut self,
1119 module: &Module,
1120 handle: Handle<crate::GlobalVariable>,
1121 global: &crate::GlobalVariable,
1122 ) -> BackendResult {
1123 let binding = *global.binding.as_ref().unwrap();
1124
1125 let key = super::SamplerIndexBufferKey {
1126 group: binding.group,
1127 };
1128 self.write_wrapped_sampler_buffer(key)?;
1129
1130 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1132
1133 match module.types[global.ty].inner {
1134 TypeInner::Sampler { comparison } => {
1135 write!(self.out, "static const ")?;
1142 self.write_type(module, global.ty)?;
1143
1144 let heap_var = if comparison {
1145 COMPARISON_SAMPLER_HEAP_VAR
1146 } else {
1147 SAMPLER_HEAP_VAR
1148 };
1149
1150 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1151 let name = &self.names[&NameKey::GlobalVariable(handle)];
1152 writeln!(
1153 self.out,
1154 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1155 register = bt.register
1156 )?;
1157 }
1158 TypeInner::BindingArray { .. } => {
1159 let name = &self.names[&NameKey::GlobalVariable(handle)];
1165 writeln!(
1166 self.out,
1167 "static const uint {name} = {register};",
1168 register = bt.register
1169 )?;
1170 }
1171 _ => unreachable!(),
1172 };
1173
1174 Ok(())
1175 }
1176
1177 fn write_global_external_texture(
1181 &mut self,
1182 module: &Module,
1183 handle: Handle<crate::GlobalVariable>,
1184 global: &crate::GlobalVariable,
1185 ) -> BackendResult {
1186 let res_binding = global
1187 .binding
1188 .as_ref()
1189 .expect("External texture global variables must have a resource binding");
1190 let ext_tex_bindings = match self
1191 .options
1192 .resolve_external_texture_resource_binding(res_binding)
1193 {
1194 Ok(bindings) => bindings,
1195 Err(err) => {
1196 log::debug!(
1197 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1198 handle,
1199 global.name,
1200 err,
1201 );
1202 return Ok(());
1203 }
1204 };
1205
1206 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1207 write!(
1208 self.out,
1209 "Texture2D<float4> {}: register(t{}",
1210 name, bt.register
1211 )?;
1212 if bt.space != 0 {
1213 write!(self.out, ", space{}", bt.space)?;
1214 }
1215 writeln!(self.out, ");")?;
1216 Ok(())
1217 };
1218 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1219 let plane_name = &self.names
1220 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1221 write_plane(bt, plane_name)?;
1222 }
1223
1224 let params_name = &self.names
1225 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1226 let params_ty_name =
1227 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1228 write!(
1229 self.out,
1230 "cbuffer {}: register(b{}",
1231 params_name, ext_tex_bindings.params.register
1232 )?;
1233 if ext_tex_bindings.params.space != 0 {
1234 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1235 }
1236 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1237
1238 Ok(())
1239 }
1240
1241 fn write_global_constant(
1246 &mut self,
1247 module: &Module,
1248 handle: Handle<crate::Constant>,
1249 ) -> BackendResult {
1250 write!(self.out, "static const ")?;
1251 let constant = &module.constants[handle];
1252 self.write_type(module, constant.ty)?;
1253 let name = &self.names[&NameKey::Constant(handle)];
1254 write!(self.out, " {name}")?;
1255 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1257 self.write_array_size(module, base, size)?;
1258 }
1259 write!(self.out, " = ")?;
1260 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1261 writeln!(self.out, ";")?;
1262 Ok(())
1263 }
1264
1265 pub(super) fn write_array_size(
1266 &mut self,
1267 module: &Module,
1268 base: Handle<crate::Type>,
1269 size: crate::ArraySize,
1270 ) -> BackendResult {
1271 write!(self.out, "[")?;
1272
1273 match size.resolve(module.to_ctx())? {
1274 proc::IndexableLength::Known(size) => {
1275 write!(self.out, "{size}")?;
1276 }
1277 proc::IndexableLength::Dynamic => unreachable!(),
1278 }
1279
1280 write!(self.out, "]")?;
1281
1282 if let TypeInner::Array {
1283 base: next_base,
1284 size: next_size,
1285 ..
1286 } = module.types[base].inner
1287 {
1288 self.write_array_size(module, next_base, next_size)?;
1289 }
1290
1291 Ok(())
1292 }
1293
1294 fn write_struct(
1299 &mut self,
1300 module: &Module,
1301 handle: Handle<crate::Type>,
1302 members: &[crate::StructMember],
1303 span: u32,
1304 shader_stage: Option<(ShaderStage, Io)>,
1305 ) -> BackendResult {
1306 let struct_name = &self.names[&NameKey::Type(handle)];
1308 writeln!(self.out, "struct {struct_name} {{")?;
1309
1310 let mut last_offset = 0;
1311 for (index, member) in members.iter().enumerate() {
1312 if member.binding.is_none() && member.offset > last_offset {
1313 let padding = (member.offset - last_offset) / 4;
1317 for i in 0..padding {
1318 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1319 }
1320 }
1321 let ty_inner = &module.types[member.ty].inner;
1322 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1323
1324 write!(self.out, "{}", back::INDENT)?;
1326
1327 match module.types[member.ty].inner {
1328 TypeInner::Array { base, size, .. } => {
1329 self.write_global_type(module, member.ty)?;
1332
1333 write!(
1335 self.out,
1336 " {}",
1337 &self.names[&NameKey::StructMember(handle, index as u32)]
1338 )?;
1339 self.write_array_size(module, base, size)?;
1341 }
1342 TypeInner::Matrix {
1345 rows,
1346 columns,
1347 scalar,
1348 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1349 let vec_ty = TypeInner::Vector { size: rows, scalar };
1350 let field_name_key = NameKey::StructMember(handle, index as u32);
1351
1352 for i in 0..columns as u8 {
1353 if i != 0 {
1354 write!(self.out, "; ")?;
1355 }
1356 self.write_value_type(module, &vec_ty)?;
1357 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1358 }
1359 }
1360 _ => {
1361 if let Some(ref binding) = member.binding {
1363 self.write_modifier(binding)?;
1364 }
1365
1366 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1370 write!(self.out, "row_major ")?;
1371 }
1372
1373 self.write_type(module, member.ty)?;
1375 write!(
1376 self.out,
1377 " {}",
1378 &self.names[&NameKey::StructMember(handle, index as u32)]
1379 )?;
1380 }
1381 }
1382
1383 self.write_semantic(&member.binding, shader_stage)?;
1384 writeln!(self.out, ";")?;
1385 }
1386
1387 if members.last().unwrap().binding.is_none() && span > last_offset {
1389 let padding = (span - last_offset) / 4;
1390 for i in 0..padding {
1391 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1392 }
1393 }
1394
1395 writeln!(self.out, "}};")?;
1396 Ok(())
1397 }
1398
1399 pub(super) fn write_global_type(
1404 &mut self,
1405 module: &Module,
1406 ty: Handle<crate::Type>,
1407 ) -> BackendResult {
1408 let matrix_data = get_inner_matrix_data(module, ty);
1409
1410 if let Some(MatrixType {
1413 columns,
1414 rows: crate::VectorSize::Bi,
1415 width: 4,
1416 }) = matrix_data
1417 {
1418 write!(self.out, "__mat{}x2", columns as u8)?;
1419 } else {
1420 if matrix_data.is_some() {
1424 write!(self.out, "row_major ")?;
1425 }
1426
1427 self.write_type(module, ty)?;
1428 }
1429
1430 Ok(())
1431 }
1432
1433 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1438 let inner = &module.types[ty].inner;
1439 match *inner {
1440 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1441 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1443 self.write_type(module, base)?
1444 }
1445 ref other => self.write_value_type(module, other)?,
1446 }
1447
1448 Ok(())
1449 }
1450
1451 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1456 match *inner {
1457 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1458 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1459 }
1460 TypeInner::Vector { size, scalar } => {
1461 write!(
1462 self.out,
1463 "{}{}",
1464 scalar.to_hlsl_str()?,
1465 common::vector_size_str(size)
1466 )?;
1467 }
1468 TypeInner::Matrix {
1469 columns,
1470 rows,
1471 scalar,
1472 } => {
1473 write!(
1478 self.out,
1479 "{}{}x{}",
1480 scalar.to_hlsl_str()?,
1481 common::vector_size_str(columns),
1482 common::vector_size_str(rows),
1483 )?;
1484 }
1485 TypeInner::Image {
1486 dim,
1487 arrayed,
1488 class,
1489 } => {
1490 self.write_image_type(dim, arrayed, class)?;
1491 }
1492 TypeInner::Sampler { comparison } => {
1493 let sampler = if comparison {
1494 "SamplerComparisonState"
1495 } else {
1496 "SamplerState"
1497 };
1498 write!(self.out, "{sampler}")?;
1499 }
1500 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1504 self.write_array_size(module, base, size)?;
1505 }
1506 TypeInner::AccelerationStructure { .. } => {
1507 write!(self.out, "RaytracingAccelerationStructure")?;
1508 }
1509 TypeInner::RayQuery { .. } => {
1510 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1512 }
1513 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1514 }
1515
1516 Ok(())
1517 }
1518
/// Emits the full HLSL definition of `func` under the name `name`.
///
/// The output consists of:
/// - the return type;
/// - the parameter list, including entry-point semantics and an optional
///   extra `SV_GroupThreadID` parameter used for workgroup zero-init;
/// - the output semantic;
/// - local variable declarations;
/// - the body statements.
///
/// `func_ctx.ty` selects between regular functions and entry points.
/// Entry points may use the synthesized input/output structs recorded in
/// `self.entry_point_io`.
fn write_function(
    &mut self,
    module: &Module,
    name: &str,
    func: &crate::Function,
    func_ctx: &back::FunctionCtx<'_>,
    info: &valid::FunctionInfo,
) -> BackendResult {
    // NOTE(review): presumably selects which expressions are baked into named
    // temporaries for this function; see `update_expressions_to_bake`.
    self.update_expressions_to_bake(module, func, info);

    if let Some(ref result) = func.result {
        // Array return types are emitted through a `typedef` named `ret_<name>`
        // so the return type can be spelled as a single identifier in front of
        // the function name.
        let array_return_type = match module.types[result.ty].inner {
            TypeInner::Array { base, size, .. } => {
                let array_return_type = self.namer.call(&format!("ret_{name}"));
                write!(self.out, "typedef ")?;
                self.write_type(module, result.ty)?;
                write!(self.out, " {array_return_type}")?;
                self.write_array_size(module, base, size)?;
                writeln!(self.out, ";")?;
                Some(array_return_type)
            }
            _ => None,
        };

        // An invariant `position` result gets its modifier before the type.
        if let Some(
            ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
        ) = result.binding
        {
            self.write_modifier(binding)?;
        }

        match func_ctx.ty {
            back::FunctionType::Function(_) => {
                if let Some(array_return_type) = array_return_type {
                    write!(self.out, "{array_return_type}")?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
            back::FunctionType::EntryPoint(index) => {
                // Entry points with a synthesized output struct return that struct.
                if let Some(ref ep_output) =
                    self.entry_point_io.get(&(index as usize)).unwrap().output
                {
                    write!(self.out, "{}", ep_output.ty_name)?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
        }
    } else {
        write!(self.out, "void")?;
    }

    write!(self.out, " {name}(")?;

    let need_workgroup_variables_initialization =
        self.need_workgroup_variables_initialization(func_ctx, module);

    match func_ctx.ty {
        back::FunctionType::Function(handle) => {
            for (index, arg) in func.arguments.iter().enumerate() {
                if index != 0 {
                    write!(self.out, ", ")?;
                }

                self.write_function_argument(module, handle, arg, index)?;
            }
        }
        back::FunctionType::EntryPoint(ep_index) => {
            // Either a single synthesized input-struct parameter, or one
            // parameter per argument, each with its own semantic.
            if let Some(ref ep_input) =
                self.entry_point_io.get(&(ep_index as usize)).unwrap().input
            {
                write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
            } else {
                let stage = module.entry_points[ep_index as usize].stage;
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_type(module, arg.ty)?;

                    let argument_name =
                        &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                    write!(self.out, " {argument_name}")?;
                    if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }

                    self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                }
            }
            // Extra parameter read by `write_workgroup_variables_initialization`;
            // a separating comma is needed only if something was written before it.
            if need_workgroup_variables_initialization {
                if self
                    .entry_point_io
                    .get(&(ep_index as usize))
                    .unwrap()
                    .input
                    .is_some()
                    || !func.arguments.is_empty()
                {
                    write!(self.out, ", ")?;
                }
                write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
            }
        }
    }
    write!(self.out, ")")?;

    // Entry-point result semantic goes after the parameter list.
    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        let stage = module.entry_points[index as usize].stage;
        if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
            self.write_semantic(binding, Some((stage, Io::Output)))?;
        }
    }

    writeln!(self.out)?;
    writeln!(self.out, "{{")?;

    if need_workgroup_variables_initialization {
        self.write_workgroup_variables_initialization(func_ctx, module)?;
    }

    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        self.write_ep_arguments_initialization(module, func, index)?;
    }

    // Local variables: every non-ray-query local gets an initializer (its own
    // init expression or the type's default value).
    for (handle, local) in func.local_variables.iter() {
        write!(self.out, "{}", back::INDENT)?;

        self.write_type(module, local.ty)?;
        write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
        if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        // Ray queries are declared without an initializer.
        let is_ray_query = match module.types[local.ty].inner {
            TypeInner::RayQuery { .. } => true,
            _ => {
                write!(self.out, " = ")?;
                if let Some(init) = local.init {
                    self.write_expr(module, init, func_ctx)?;
                } else {
                    self.write_default_init(module, local.ty)?;
                }
                false
            }
        };
        writeln!(self.out, ";")?;
        // Each ray query local gets a companion `uint` tracker variable,
        // initialized to 0.
        if is_ray_query {
            write!(self.out, "{}", back::INDENT)?;
            self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
            writeln!(
                self.out,
                " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
                self.names[&func_ctx.name_key(handle)]
            )?;
        }
    }

    if !func.local_variables.is_empty() {
        writeln!(self.out)?;
    }

    for sta in func.body.iter() {
        self.write_stmt(module, sta, func_ctx, back::Level(1))?;
    }

    writeln!(self.out, "}}")?;

    // Reset per-function expression naming state.
    self.named_expressions.clear();

    Ok(())
}
1718
1719 fn write_function_argument(
1720 &mut self,
1721 module: &Module,
1722 handle: Handle<crate::Function>,
1723 arg: &crate::FunctionArgument,
1724 index: usize,
1725 ) -> BackendResult {
1726 if let TypeInner::Image {
1729 class: crate::ImageClass::External,
1730 ..
1731 } = module.types[arg.ty].inner
1732 {
1733 return self.write_function_external_texture_argument(module, handle, index);
1734 }
1735
1736 let arg_ty = match module.types[arg.ty].inner {
1738 TypeInner::Pointer { base, .. } => {
1740 write!(self.out, "inout ")?;
1742 base
1743 }
1744 _ => arg.ty,
1745 };
1746 self.write_type(module, arg_ty)?;
1747
1748 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1749
1750 write!(self.out, " {argument_name}")?;
1752 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1753 self.write_array_size(module, base, size)?;
1754 }
1755
1756 Ok(())
1757 }
1758
1759 fn write_function_external_texture_argument(
1760 &mut self,
1761 module: &Module,
1762 handle: Handle<crate::Function>,
1763 index: usize,
1764 ) -> BackendResult {
1765 let plane_names = [0, 1, 2].map(|i| {
1766 &self.names[&NameKey::ExternalTextureFunctionArgument(
1767 handle,
1768 index as u32,
1769 ExternalTextureNameKey::Plane(i),
1770 )]
1771 });
1772 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1773 handle,
1774 index as u32,
1775 ExternalTextureNameKey::Params,
1776 )];
1777 let params_ty_name =
1778 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1779 write!(
1780 self.out,
1781 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1782 plane_names[0], plane_names[1], plane_names[2],
1783 )?;
1784 Ok(())
1785 }
1786
1787 fn need_workgroup_variables_initialization(
1788 &mut self,
1789 func_ctx: &back::FunctionCtx,
1790 module: &Module,
1791 ) -> bool {
1792 self.options.zero_initialize_workgroup_memory
1793 && func_ctx.ty.is_compute_like_entry_point(module)
1794 && module.global_variables.iter().any(|(handle, var)| {
1795 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1796 })
1797 }
1798
1799 fn write_workgroup_variables_initialization(
1800 &mut self,
1801 func_ctx: &back::FunctionCtx,
1802 module: &Module,
1803 ) -> BackendResult {
1804 let level = back::Level(1);
1805
1806 writeln!(
1807 self.out,
1808 "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1809 )?;
1810
1811 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1812 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1813 });
1814
1815 for (handle, var) in vars {
1816 let name = &self.names[&NameKey::GlobalVariable(handle)];
1817 write!(self.out, "{}{} = ", level.next(), name)?;
1818 self.write_default_init(module, var.ty)?;
1819 writeln!(self.out, ";")?;
1820 }
1821
1822 writeln!(self.out, "{level}}}")?;
1823 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1824 }
1825
/// Emits a naga `Switch` statement at indentation `level`.
///
/// - If every case except the last is an empty fall-through, the switch is
///   emitted as `do { <last case body> } while(false);`. A `break` in the
///   body then still has a target. The selector is not evaluated in this
///   form.
/// - Otherwise a real HLSL `switch` is emitted. A case that falls through
///   with a non-empty body has its own body and the bodies of all following
///   cases, up to the first non-fall-through case, copied into it. Each copy
///   goes in its own `{}` block, so HLSL never sees an implicit
///   fall-through from a non-empty case.
///
/// `continue_ctx` may hand back a flag variable. It is declared `false`
/// before the switch and checked afterwards to re-issue `continue`/`break`
/// outside the switch, according to `exit_switch`.
fn write_switch(
    &mut self,
    module: &Module,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
    selector: Handle<crate::Expression>,
    cases: &[crate::SwitchCase],
) -> BackendResult {
    let indent_level_1 = level.next();
    let indent_level_2 = indent_level_1.next();

    if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
        writeln!(self.out, "{level}bool {variable} = false;",)?;
    };

    // Degenerate switch: all cases before the last are empty fall-throughs.
    let one_body = cases
        .iter()
        .rev()
        .skip(1)
        .all(|case| case.fall_through && case.body.is_empty());
    if one_body {
        writeln!(self.out, "{level}do {{")?;
        if let Some(case) = cases.last() {
            for sta in case.body.iter() {
                self.write_stmt(module, sta, func_ctx, indent_level_1)?;
            }
        }
        writeln!(self.out, "{level}}} while(false);")?;
    } else {
        write!(self.out, "{level}")?;
        write!(self.out, "switch(")?;
        self.write_expr(module, selector, func_ctx)?;
        writeln!(self.out, ") {{")?;

        for (i, case) in cases.iter().enumerate() {
            match case.value {
                crate::SwitchValue::I32(value) => {
                    write!(self.out, "{indent_level_1}case {value}:")?
                }
                crate::SwitchValue::U32(value) => {
                    write!(self.out, "{indent_level_1}case {value}u:")?
                }
                crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
            }

            // Empty fall-through cases are bare labels stacked onto the next case.
            let write_block_braces = !(case.fall_through && case.body.is_empty());
            if write_block_braces {
                writeln!(self.out, " {{")?;
            } else {
                writeln!(self.out)?;
            }

            if case.fall_through && !case.body.is_empty() {
                // Find the first later case that does not fall through; its
                // index ends the chain whose bodies are copied here. The
                // `unwrap` presumes validation guarantees such a case exists.
                let curr_len = i + 1;
                let end_case_idx = curr_len
                    + cases
                        .iter()
                        .skip(curr_len)
                        .position(|case| !case.fall_through)
                        .unwrap();
                let indent_level_3 = indent_level_2.next();
                for case in &cases[i..=end_case_idx] {
                    writeln!(self.out, "{indent_level_2}{{")?;
                    // Names introduced while emitting this copy are dropped
                    // afterwards, matching the `{}` scope just opened.
                    let prev_len = self.named_expressions.len();
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                    }
                    self.named_expressions.truncate(prev_len);
                    writeln!(self.out, "{indent_level_2}}}")?;
                }

                let last_case = &cases[end_case_idx];
                if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            } else {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                }
                // Add an explicit `break` unless the case falls through or
                // already ends in a terminator statement.
                if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            }

            if write_block_braces {
                writeln!(self.out, "{indent_level_1}}}")?;
            }
        }

        writeln!(self.out, "{level}}}")?;
    }

    // Re-issue any `continue`/`break` that was recorded in the flag variable
    // declared above.
    use back::continue_forward::ExitControlFlow;
    let op = match self.continue_ctx.exit_switch() {
        ExitControlFlow::None => None,
        ExitControlFlow::Continue { variable } => Some(("continue", variable)),
        ExitControlFlow::Break { variable } => Some(("break", variable)),
    };
    if let Some((control_flow, variable)) = op {
        writeln!(self.out, "{level}if ({variable}) {{")?;
        writeln!(self.out, "{indent_level_1}{control_flow};")?;
        writeln!(self.out, "{level}}}")?;
    }

    Ok(())
}
1970
1971 fn write_index(
1972 &mut self,
1973 module: &Module,
1974 index: Index,
1975 func_ctx: &back::FunctionCtx<'_>,
1976 ) -> BackendResult {
1977 match index {
1978 Index::Static(index) => {
1979 write!(self.out, "{index}")?;
1980 }
1981 Index::Expression(index) => {
1982 self.write_expr(module, index, func_ctx)?;
1983 }
1984 }
1985 Ok(())
1986 }
1987
1988 fn write_stmt(
1993 &mut self,
1994 module: &Module,
1995 stmt: &crate::Statement,
1996 func_ctx: &back::FunctionCtx<'_>,
1997 level: back::Level,
1998 ) -> BackendResult {
1999 use crate::Statement;
2000
2001 match *stmt {
2002 Statement::Emit(ref range) => {
2003 for handle in range.clone() {
2004 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2005 let expr_name = if ptr_class.is_some() {
2006 None
2010 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2011 Some(self.namer.call(name))
2016 } else if self.need_bake_expressions.contains(&handle) {
2017 Some(Baked(handle).to_string())
2018 } else {
2019 None
2020 };
2021
2022 if let Some(name) = expr_name {
2023 write!(self.out, "{level}")?;
2024 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2025 }
2026 }
2027 }
2028 Statement::Block(ref block) => {
2030 write!(self.out, "{level}")?;
2031 writeln!(self.out, "{{")?;
2032 for sta in block.iter() {
2033 self.write_stmt(module, sta, func_ctx, level.next())?
2035 }
2036 writeln!(self.out, "{level}}}")?
2037 }
2038 Statement::If {
2040 condition,
2041 ref accept,
2042 ref reject,
2043 } => {
2044 write!(self.out, "{level}")?;
2045 write!(self.out, "if (")?;
2046 self.write_expr(module, condition, func_ctx)?;
2047 writeln!(self.out, ") {{")?;
2048
2049 let l2 = level.next();
2050 for sta in accept {
2051 self.write_stmt(module, sta, func_ctx, l2)?;
2053 }
2054
2055 if !reject.is_empty() {
2058 writeln!(self.out, "{level}}} else {{")?;
2059
2060 for sta in reject {
2061 self.write_stmt(module, sta, func_ctx, l2)?;
2063 }
2064 }
2065
2066 writeln!(self.out, "{level}}}")?
2067 }
2068 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2070 Statement::Return { value: None } => {
2071 writeln!(self.out, "{level}return;")?;
2072 }
2073 Statement::Return { value: Some(expr) } => {
2074 let base_ty_res = &func_ctx.info[expr].ty;
2075 let mut resolved = base_ty_res.inner_with(&module.types);
2076 if let TypeInner::Pointer { base, space: _ } = *resolved {
2077 resolved = &module.types[base].inner;
2078 }
2079
2080 if let TypeInner::Struct { .. } = *resolved {
2081 let ty = base_ty_res.handle().unwrap();
2083 let struct_name = &self.names[&NameKey::Type(ty)];
2084 let variable_name = self.namer.call(&struct_name.to_lowercase());
2085 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2086 self.write_expr(module, expr, func_ctx)?;
2087 writeln!(self.out, ";")?;
2088
2089 let ep_output = match func_ctx.ty {
2091 back::FunctionType::Function(_) => None,
2092 back::FunctionType::EntryPoint(index) => self
2093 .entry_point_io
2094 .get(&(index as usize))
2095 .unwrap()
2096 .output
2097 .as_ref(),
2098 };
2099 let final_name = match ep_output {
2100 Some(ep_output) => {
2101 let final_name = self.namer.call(&variable_name);
2102 write!(
2103 self.out,
2104 "{}const {} {} = {{ ",
2105 level, ep_output.ty_name, final_name,
2106 )?;
2107 for (index, m) in ep_output.members.iter().enumerate() {
2108 if index != 0 {
2109 write!(self.out, ", ")?;
2110 }
2111 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2112 write!(self.out, "{variable_name}.{member_name}")?;
2113 }
2114 writeln!(self.out, " }};")?;
2115 final_name
2116 }
2117 None => variable_name,
2118 };
2119 writeln!(self.out, "{level}return {final_name};")?;
2120 } else {
2121 write!(self.out, "{level}return ")?;
2122 self.write_expr(module, expr, func_ctx)?;
2123 writeln!(self.out, ";")?
2124 }
2125 }
2126 Statement::Store { pointer, value } => {
2127 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2128 if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2129 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2130 self.write_storage_store(
2131 module,
2132 var_handle,
2133 StoreValue::Expression(value),
2134 func_ctx,
2135 level,
2136 None,
2137 )?;
2138 } else {
2139 enum MatrixAccess {
2145 Direct {
2146 base: Handle<crate::Expression>,
2147 index: u32,
2148 },
2149 Struct {
2150 columns: crate::VectorSize,
2151 base: Handle<crate::Expression>,
2152 },
2153 }
2154
2155 let get_members = |expr: Handle<crate::Expression>| {
2156 let resolved = func_ctx.resolve_type(expr, &module.types);
2157 match *resolved {
2158 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2159 TypeInner::Struct { ref members, .. } => Some(members),
2160 _ => None,
2161 },
2162 _ => None,
2163 }
2164 };
2165
2166 write!(self.out, "{level}")?;
2167
2168 let matrix_access_on_lhs =
2169 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2170 |(matrix_expr, vector, scalar)| match (
2171 func_ctx.resolve_type(matrix_expr, &module.types),
2172 &func_ctx.expressions[matrix_expr],
2173 ) {
2174 (
2175 &TypeInner::Pointer { base: ty, .. },
2176 &crate::Expression::AccessIndex { base, index },
2177 ) if matches!(
2178 module.types[ty].inner,
2179 TypeInner::Matrix {
2180 rows: crate::VectorSize::Bi,
2181 ..
2182 }
2183 ) && get_members(base)
2184 .map(|members| members[index as usize].binding.is_none())
2185 == Some(true) =>
2186 {
2187 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2188 }
2189 _ => {
2190 if let Some(MatrixType {
2191 columns,
2192 rows: crate::VectorSize::Bi,
2193 width: 4,
2194 }) = get_inner_matrix_of_struct_array_member(
2195 module,
2196 matrix_expr,
2197 func_ctx,
2198 true,
2199 ) {
2200 Some((
2201 MatrixAccess::Struct {
2202 columns,
2203 base: matrix_expr,
2204 },
2205 vector,
2206 scalar,
2207 ))
2208 } else {
2209 None
2210 }
2211 }
2212 },
2213 );
2214
2215 match matrix_access_on_lhs {
2216 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2217 let base_ty_res = &func_ctx.info[base].ty;
2218 let resolved = base_ty_res.inner_with(&module.types);
2219 let ty = match *resolved {
2220 TypeInner::Pointer { base, .. } => base,
2221 _ => base_ty_res.handle().unwrap(),
2222 };
2223
2224 if let Some(Index::Static(vec_index)) = vector {
2225 self.write_expr(module, base, func_ctx)?;
2226 write!(
2227 self.out,
2228 ".{}_{}",
2229 &self.names[&NameKey::StructMember(ty, index)],
2230 vec_index
2231 )?;
2232
2233 if let Some(scalar_index) = scalar {
2234 write!(self.out, "[")?;
2235 self.write_index(module, scalar_index, func_ctx)?;
2236 write!(self.out, "]")?;
2237 }
2238
2239 write!(self.out, " = ")?;
2240 self.write_expr(module, value, func_ctx)?;
2241 writeln!(self.out, ";")?;
2242 } else {
2243 let access = WrappedStructMatrixAccess { ty, index };
2244 match (&vector, &scalar) {
2245 (&Some(_), &Some(_)) => {
2246 self.write_wrapped_struct_matrix_set_scalar_function_name(
2247 access,
2248 )?;
2249 }
2250 (&Some(_), &None) => {
2251 self.write_wrapped_struct_matrix_set_vec_function_name(
2252 access,
2253 )?;
2254 }
2255 (&None, _) => {
2256 self.write_wrapped_struct_matrix_set_function_name(access)?;
2257 }
2258 }
2259
2260 write!(self.out, "(")?;
2261 self.write_expr(module, base, func_ctx)?;
2262 write!(self.out, ", ")?;
2263 self.write_expr(module, value, func_ctx)?;
2264
2265 if let Some(Index::Expression(vec_index)) = vector {
2266 write!(self.out, ", ")?;
2267 self.write_expr(module, vec_index, func_ctx)?;
2268
2269 if let Some(scalar_index) = scalar {
2270 write!(self.out, ", ")?;
2271 self.write_index(module, scalar_index, func_ctx)?;
2272 }
2273 }
2274 writeln!(self.out, ");")?;
2275 }
2276 }
2277 Some((
2278 MatrixAccess::Struct { columns, base },
2279 Some(Index::Expression(vec_index)),
2280 scalar,
2281 )) => {
2282 if scalar.is_some() {
2286 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2287 } else {
2288 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2289 }
2290 write!(self.out, "(")?;
2291 self.write_expr(module, base, func_ctx)?;
2292 write!(self.out, ", ")?;
2293 self.write_expr(module, vec_index, func_ctx)?;
2294
2295 if let Some(scalar_index) = scalar {
2296 write!(self.out, ", ")?;
2297 self.write_index(module, scalar_index, func_ctx)?;
2298 }
2299
2300 write!(self.out, ", ")?;
2301 self.write_expr(module, value, func_ctx)?;
2302
2303 writeln!(self.out, ");")?;
2304 }
2305 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2306 | Some((MatrixAccess::Struct { .. }, None, _))
2307 | None => {
2308 self.write_expr(module, pointer, func_ctx)?;
2309 write!(self.out, " = ")?;
2310
2311 if let Some(MatrixType {
2316 columns,
2317 rows: crate::VectorSize::Bi,
2318 width: 4,
2319 }) = get_inner_matrix_of_struct_array_member(
2320 module, pointer, func_ctx, false,
2321 ) {
2322 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2323 if let TypeInner::Pointer { base, .. } = *resolved {
2324 resolved = &module.types[base].inner;
2325 }
2326
2327 write!(self.out, "(__mat{}x2", columns as u8)?;
2328 if let TypeInner::Array { base, size, .. } = *resolved {
2329 self.write_array_size(module, base, size)?;
2330 }
2331 write!(self.out, ")")?;
2332 }
2333
2334 self.write_expr(module, value, func_ctx)?;
2335 writeln!(self.out, ";")?
2336 }
2337 }
2338 }
2339 }
2340 Statement::Loop {
2341 ref body,
2342 ref continuing,
2343 break_if,
2344 } => {
2345 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2346 let gate_name = (!continuing.is_empty() || break_if.is_some())
2347 .then(|| self.namer.call("loop_init"));
2348
2349 if let Some((ref decl, _)) = force_loop_bound_statements {
2350 writeln!(self.out, "{decl}")?;
2351 }
2352 if let Some(ref gate_name) = gate_name {
2353 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2354 }
2355
2356 self.continue_ctx.enter_loop();
2357 writeln!(self.out, "{level}while(true) {{")?;
2358 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2359 writeln!(self.out, "{break_and_inc}")?;
2360 }
2361 let l2 = level.next();
2362 if let Some(gate_name) = gate_name {
2363 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2364 let l3 = l2.next();
2365 for sta in continuing.iter() {
2366 self.write_stmt(module, sta, func_ctx, l3)?;
2367 }
2368 if let Some(condition) = break_if {
2369 write!(self.out, "{l3}if (")?;
2370 self.write_expr(module, condition, func_ctx)?;
2371 writeln!(self.out, ") {{")?;
2372 writeln!(self.out, "{}break;", l3.next())?;
2373 writeln!(self.out, "{l3}}}")?;
2374 }
2375 writeln!(self.out, "{l2}}}")?;
2376 writeln!(self.out, "{l2}{gate_name} = false;")?;
2377 }
2378
2379 for sta in body.iter() {
2380 self.write_stmt(module, sta, func_ctx, l2)?;
2381 }
2382
2383 writeln!(self.out, "{level}}}")?;
2384 self.continue_ctx.exit_loop();
2385 }
2386 Statement::Break => writeln!(self.out, "{level}break;")?,
2387 Statement::Continue => {
2388 if let Some(variable) = self.continue_ctx.continue_encountered() {
2389 writeln!(self.out, "{level}{variable} = true;")?;
2390 writeln!(self.out, "{level}break;")?
2391 } else {
2392 writeln!(self.out, "{level}continue;")?
2393 }
2394 }
2395 Statement::ControlBarrier(barrier) => {
2396 self.write_control_barrier(barrier, level)?;
2397 }
2398 Statement::MemoryBarrier(barrier) => {
2399 self.write_memory_barrier(barrier, level)?;
2400 }
2401 Statement::ImageStore {
2402 image,
2403 coordinate,
2404 array_index,
2405 value,
2406 } => {
2407 write!(self.out, "{level}")?;
2408 self.write_expr(module, image, func_ctx)?;
2409
2410 write!(self.out, "[")?;
2411 if let Some(index) = array_index {
2412 write!(self.out, "int3(")?;
2414 self.write_expr(module, coordinate, func_ctx)?;
2415 write!(self.out, ", ")?;
2416 self.write_expr(module, index, func_ctx)?;
2417 write!(self.out, ")")?;
2418 } else {
2419 self.write_expr(module, coordinate, func_ctx)?;
2420 }
2421 write!(self.out, "]")?;
2422
2423 write!(self.out, " = ")?;
2424 self.write_expr(module, value, func_ctx)?;
2425 writeln!(self.out, ";")?;
2426 }
2427 Statement::Call {
2428 function,
2429 ref arguments,
2430 result,
2431 } => {
2432 write!(self.out, "{level}")?;
2433 if let Some(expr) = result {
2434 write!(self.out, "const ")?;
2435 let name = Baked(expr).to_string();
2436 let expr_ty = &func_ctx.info[expr].ty;
2437 let ty_inner = match *expr_ty {
2438 proc::TypeResolution::Handle(handle) => {
2439 self.write_type(module, handle)?;
2440 &module.types[handle].inner
2441 }
2442 proc::TypeResolution::Value(ref value) => {
2443 self.write_value_type(module, value)?;
2444 value
2445 }
2446 };
2447 write!(self.out, " {name}")?;
2448 if let TypeInner::Array { base, size, .. } = *ty_inner {
2449 self.write_array_size(module, base, size)?;
2450 }
2451 write!(self.out, " = ")?;
2452 self.named_expressions.insert(expr, name);
2453 }
2454 let func_name = &self.names[&NameKey::Function(function)];
2455 write!(self.out, "{func_name}(")?;
2456 for (index, argument) in arguments.iter().enumerate() {
2457 if index != 0 {
2458 write!(self.out, ", ")?;
2459 }
2460 self.write_expr(module, *argument, func_ctx)?;
2461 }
2462 writeln!(self.out, ");")?
2463 }
2464 Statement::Atomic {
2465 pointer,
2466 ref fun,
2467 value,
2468 result,
2469 } => {
2470 write!(self.out, "{level}")?;
2471 let res_var_info = if let Some(res_handle) = result {
2472 let name = Baked(res_handle).to_string();
2473 match func_ctx.info[res_handle].ty {
2474 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2475 proc::TypeResolution::Value(ref value) => {
2476 self.write_value_type(module, value)?
2477 }
2478 };
2479 write!(self.out, " {name}; ")?;
2480 self.named_expressions.insert(res_handle, name.clone());
2481 Some((res_handle, name))
2482 } else {
2483 None
2484 };
2485 let pointer_space = func_ctx
2486 .resolve_type(pointer, &module.types)
2487 .pointer_space()
2488 .unwrap();
2489 let fun_str = fun.to_hlsl_suffix();
2490 let compare_expr = match *fun {
2491 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2492 _ => None,
2493 };
2494 match pointer_space {
2495 crate::AddressSpace::WorkGroup => {
2496 write!(self.out, "Interlocked{fun_str}(")?;
2497 self.write_expr(module, pointer, func_ctx)?;
2498 self.emit_hlsl_atomic_tail(
2499 module,
2500 func_ctx,
2501 fun,
2502 compare_expr,
2503 value,
2504 &res_var_info,
2505 )?;
2506 }
2507 crate::AddressSpace::Storage { .. } => {
2508 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2509 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2510 let width = match func_ctx.resolve_type(value, &module.types) {
2511 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2512 _ => "",
2513 };
2514 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2515 let chain = mem::take(&mut self.temp_access_chain);
2516 self.write_storage_address(module, &chain, func_ctx)?;
2517 self.temp_access_chain = chain;
2518 self.emit_hlsl_atomic_tail(
2519 module,
2520 func_ctx,
2521 fun,
2522 compare_expr,
2523 value,
2524 &res_var_info,
2525 )?;
2526 }
2527 ref other => {
2528 return Err(Error::Custom(format!(
2529 "invalid address space {other:?} for atomic statement"
2530 )))
2531 }
2532 }
2533 if let Some(cmp) = compare_expr {
2534 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2535 write!(
2536 self.out,
2537 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2538 )?;
2539 self.write_expr(module, cmp, func_ctx)?;
2540 writeln!(self.out, ");")?;
2541 }
2542 }
2543 }
2544 Statement::ImageAtomic {
2545 image,
2546 coordinate,
2547 array_index,
2548 fun,
2549 value,
2550 } => {
2551 write!(self.out, "{level}")?;
2552
2553 let fun_str = fun.to_hlsl_suffix();
2554 write!(self.out, "Interlocked{fun_str}(")?;
2555 self.write_expr(module, image, func_ctx)?;
2556 write!(self.out, "[")?;
2557 self.write_texture_coordinates(
2558 "int",
2559 coordinate,
2560 array_index,
2561 None,
2562 module,
2563 func_ctx,
2564 )?;
2565 write!(self.out, "],")?;
2566
2567 self.write_expr(module, value, func_ctx)?;
2568 writeln!(self.out, ");")?;
2569 }
2570 Statement::WorkGroupUniformLoad { pointer, result } => {
2571 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2572 write!(self.out, "{level}")?;
2573 let name = Baked(result).to_string();
2574 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2575
2576 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2577 }
2578 Statement::Switch {
2579 selector,
2580 ref cases,
2581 } => {
2582 self.write_switch(module, func_ctx, level, selector, cases)?;
2583 }
2584 Statement::RayQuery { query, ref fun } => {
2585 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2597 else {
2598 unreachable!()
2599 };
2600
2601 let tracker_expr_name = format!(
2602 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2603 self.names[&func_ctx.name_key(query_var)]
2604 );
2605
2606 match *fun {
2607 RayQueryFunction::Initialize {
2608 acceleration_structure,
2609 descriptor,
2610 } => {
2611 self.write_initialize_function(
2612 module,
2613 level,
2614 query,
2615 acceleration_structure,
2616 descriptor,
2617 &tracker_expr_name,
2618 func_ctx,
2619 )?;
2620 }
2621 RayQueryFunction::Proceed { result } => {
2622 self.write_proceed(
2623 module,
2624 level,
2625 query,
2626 result,
2627 &tracker_expr_name,
2628 func_ctx,
2629 )?;
2630 }
2631 RayQueryFunction::GenerateIntersection { hit_t } => {
2632 self.write_generate_intersection(
2633 module,
2634 level,
2635 query,
2636 hit_t,
2637 &tracker_expr_name,
2638 func_ctx,
2639 )?;
2640 }
2641 RayQueryFunction::ConfirmIntersection => {
2642 self.write_confirm_intersection(
2643 module,
2644 level,
2645 query,
2646 &tracker_expr_name,
2647 func_ctx,
2648 )?;
2649 }
2650 RayQueryFunction::Terminate => {
2651 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2652 }
2653 }
2654 }
2655 Statement::SubgroupBallot { result, predicate } => {
2656 write!(self.out, "{level}")?;
2657 let name = Baked(result).to_string();
2658 write!(self.out, "const uint4 {name} = ")?;
2659 self.named_expressions.insert(result, name);
2660
2661 write!(self.out, "WaveActiveBallot(")?;
2662 match predicate {
2663 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2664 None => write!(self.out, "true")?,
2665 }
2666 writeln!(self.out, ");")?;
2667 }
2668 Statement::SubgroupCollectiveOperation {
2669 op,
2670 collective_op,
2671 argument,
2672 result,
2673 } => {
2674 write!(self.out, "{level}")?;
2675 write!(self.out, "const ")?;
2676 let name = Baked(result).to_string();
2677 match func_ctx.info[result].ty {
2678 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2679 proc::TypeResolution::Value(ref value) => {
2680 self.write_value_type(module, value)?
2681 }
2682 };
2683 write!(self.out, " {name} = ")?;
2684 self.named_expressions.insert(result, name);
2685
2686 match (collective_op, op) {
2687 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2688 write!(self.out, "WaveActiveAllTrue(")?
2689 }
2690 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2691 write!(self.out, "WaveActiveAnyTrue(")?
2692 }
2693 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2694 write!(self.out, "WaveActiveSum(")?
2695 }
2696 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2697 write!(self.out, "WaveActiveProduct(")?
2698 }
2699 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2700 write!(self.out, "WaveActiveMax(")?
2701 }
2702 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2703 write!(self.out, "WaveActiveMin(")?
2704 }
2705 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2706 write!(self.out, "WaveActiveBitAnd(")?
2707 }
2708 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2709 write!(self.out, "WaveActiveBitOr(")?
2710 }
2711 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2712 write!(self.out, "WaveActiveBitXor(")?
2713 }
2714 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2715 write!(self.out, "WavePrefixSum(")?
2716 }
2717 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2718 write!(self.out, "WavePrefixProduct(")?
2719 }
2720 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2721 self.write_expr(module, argument, func_ctx)?;
2722 write!(self.out, " + WavePrefixSum(")?;
2723 }
2724 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2725 self.write_expr(module, argument, func_ctx)?;
2726 write!(self.out, " * WavePrefixProduct(")?;
2727 }
2728 _ => unimplemented!(),
2729 }
2730 self.write_expr(module, argument, func_ctx)?;
2731 writeln!(self.out, ");")?;
2732 }
2733 Statement::SubgroupGather {
2734 mode,
2735 argument,
2736 result,
2737 } => {
2738 write!(self.out, "{level}")?;
2739 write!(self.out, "const ")?;
2740 let name = Baked(result).to_string();
2741 match func_ctx.info[result].ty {
2742 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2743 proc::TypeResolution::Value(ref value) => {
2744 self.write_value_type(module, value)?
2745 }
2746 };
2747 write!(self.out, " {name} = ")?;
2748 self.named_expressions.insert(result, name);
2749 match mode {
2750 crate::GatherMode::BroadcastFirst => {
2751 write!(self.out, "WaveReadLaneFirst(")?;
2752 self.write_expr(module, argument, func_ctx)?;
2753 }
2754 crate::GatherMode::QuadBroadcast(index) => {
2755 write!(self.out, "QuadReadLaneAt(")?;
2756 self.write_expr(module, argument, func_ctx)?;
2757 write!(self.out, ", ")?;
2758 self.write_expr(module, index, func_ctx)?;
2759 }
2760 crate::GatherMode::QuadSwap(direction) => {
2761 match direction {
2762 crate::Direction::X => {
2763 write!(self.out, "QuadReadAcrossX(")?;
2764 }
2765 crate::Direction::Y => {
2766 write!(self.out, "QuadReadAcrossY(")?;
2767 }
2768 crate::Direction::Diagonal => {
2769 write!(self.out, "QuadReadAcrossDiagonal(")?;
2770 }
2771 }
2772 self.write_expr(module, argument, func_ctx)?;
2773 }
2774 _ => {
2775 write!(self.out, "WaveReadLaneAt(")?;
2776 self.write_expr(module, argument, func_ctx)?;
2777 write!(self.out, ", ")?;
2778 match mode {
2779 crate::GatherMode::BroadcastFirst => unreachable!(),
2780 crate::GatherMode::Broadcast(index)
2781 | crate::GatherMode::Shuffle(index) => {
2782 self.write_expr(module, index, func_ctx)?;
2783 }
2784 crate::GatherMode::ShuffleDown(index) => {
2785 write!(self.out, "WaveGetLaneIndex() + ")?;
2786 self.write_expr(module, index, func_ctx)?;
2787 }
2788 crate::GatherMode::ShuffleUp(index) => {
2789 write!(self.out, "WaveGetLaneIndex() - ")?;
2790 self.write_expr(module, index, func_ctx)?;
2791 }
2792 crate::GatherMode::ShuffleXor(index) => {
2793 write!(self.out, "WaveGetLaneIndex() ^ ")?;
2794 self.write_expr(module, index, func_ctx)?;
2795 }
2796 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2797 crate::GatherMode::QuadSwap(_) => unreachable!(),
2798 }
2799 }
2800 }
2801 writeln!(self.out, ");")?;
2802 }
2803 Statement::CooperativeStore { .. } => unimplemented!(),
2804 }
2805
2806 Ok(())
2807 }
2808
2809 fn write_const_expression(
2810 &mut self,
2811 module: &Module,
2812 expr: Handle<crate::Expression>,
2813 arena: &crate::Arena<crate::Expression>,
2814 ) -> BackendResult {
2815 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2816 writer.write_const_expression(module, expr, arena)
2817 })
2818 }
2819
2820 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2821 match literal {
2822 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2823 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2824 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2825 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2826 crate::Literal::I32(value) if value == i32::MIN => {
2832 write!(self.out, "int({} - 1)", value + 1)?
2833 }
2834 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2838 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2839 crate::Literal::I64(value) if value == i64::MIN => {
2841 write!(self.out, "({}L - 1L)", value + 1)?;
2842 }
2843 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2844 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2845 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2846 return Err(Error::Custom(
2847 "Abstract types should not appear in IR presented to backends".into(),
2848 ));
2849 }
2850 }
2851 Ok(())
2852 }
2853
2854 fn write_possibly_const_expression<E>(
2855 &mut self,
2856 module: &Module,
2857 expr: Handle<crate::Expression>,
2858 expressions: &crate::Arena<crate::Expression>,
2859 write_expression: E,
2860 ) -> BackendResult
2861 where
2862 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2863 {
2864 use crate::Expression;
2865
2866 match expressions[expr] {
2867 Expression::Literal(literal) => {
2868 self.write_literal(literal)?;
2869 }
2870 Expression::Constant(handle) => {
2871 let constant = &module.constants[handle];
2872 if constant.name.is_some() {
2873 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2874 } else {
2875 self.write_const_expression(module, constant.init, &module.global_expressions)?;
2876 }
2877 }
2878 Expression::ZeroValue(ty) => {
2879 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2880 write!(self.out, "()")?;
2881 }
2882 Expression::Compose { ty, ref components } => {
2883 match module.types[ty].inner {
2884 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2885 self.write_wrapped_constructor_function_name(
2886 module,
2887 WrappedConstructor { ty },
2888 )?;
2889 }
2890 _ => {
2891 self.write_type(module, ty)?;
2892 }
2893 };
2894 write!(self.out, "(")?;
2895 for (index, component) in components.iter().enumerate() {
2896 if index != 0 {
2897 write!(self.out, ", ")?;
2898 }
2899 write_expression(self, *component)?;
2900 }
2901 write!(self.out, ")")?;
2902 }
2903 Expression::Splat { size, value } => {
2904 let number_of_components = match size {
2908 crate::VectorSize::Bi => "xx",
2909 crate::VectorSize::Tri => "xxx",
2910 crate::VectorSize::Quad => "xxxx",
2911 };
2912 write!(self.out, "(")?;
2913 write_expression(self, value)?;
2914 write!(self.out, ").{number_of_components}")?
2915 }
2916 _ => {
2917 return Err(Error::Override);
2918 }
2919 }
2920
2921 Ok(())
2922 }
2923
2924 pub(super) fn write_expr(
2929 &mut self,
2930 module: &Module,
2931 expr: Handle<crate::Expression>,
2932 func_ctx: &back::FunctionCtx<'_>,
2933 ) -> BackendResult {
2934 use crate::Expression;
2935
2936 let ff_input = if self.options.special_constants_binding.is_some() {
2938 func_ctx.is_fixed_function_input(expr, module)
2939 } else {
2940 None
2941 };
2942 let closing_bracket = match ff_input {
2943 Some(crate::BuiltIn::VertexIndex) => {
2944 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2945 ")"
2946 }
2947 Some(crate::BuiltIn::InstanceIndex) => {
2948 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2949 ")"
2950 }
2951 Some(crate::BuiltIn::NumWorkGroups) => {
2952 write!(
2956 self.out,
2957 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2958 )?;
2959 return Ok(());
2960 }
2961 _ => "",
2962 };
2963
2964 if let Some(name) = self.named_expressions.get(&expr) {
2965 write!(self.out, "{name}{closing_bracket}")?;
2966 return Ok(());
2967 }
2968
2969 let expression = &func_ctx.expressions[expr];
2970
2971 match *expression {
2972 Expression::Literal(_)
2973 | Expression::Constant(_)
2974 | Expression::ZeroValue(_)
2975 | Expression::Compose { .. }
2976 | Expression::Splat { .. } => {
2977 self.write_possibly_const_expression(
2978 module,
2979 expr,
2980 func_ctx.expressions,
2981 |writer, expr| writer.write_expr(module, expr, func_ctx),
2982 )?;
2983 }
2984 Expression::Override(_) => return Err(Error::Override),
2985 Expression::Binary {
2992 op:
2993 op @ crate::BinaryOperator::Add
2994 | op @ crate::BinaryOperator::Subtract
2995 | op @ crate::BinaryOperator::Multiply,
2996 left,
2997 right,
2998 } if matches!(
2999 func_ctx.resolve_type(expr, &module.types).scalar(),
3000 Some(Scalar::I32)
3001 ) =>
3002 {
3003 write!(self.out, "asint(asuint(",)?;
3004 self.write_expr(module, left, func_ctx)?;
3005 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3006 self.write_expr(module, right, func_ctx)?;
3007 write!(self.out, "))")?;
3008 }
3009 Expression::Binary {
3012 op: crate::BinaryOperator::Multiply,
3013 left,
3014 right,
3015 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3016 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3017 {
3018 write!(self.out, "mul(")?;
3020 self.write_expr(module, right, func_ctx)?;
3021 write!(self.out, ", ")?;
3022 self.write_expr(module, left, func_ctx)?;
3023 write!(self.out, ")")?;
3024 }
3025
3026 Expression::Binary {
3038 op: crate::BinaryOperator::Divide,
3039 left,
3040 right,
3041 } if matches!(
3042 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3043 Some(ScalarKind::Sint | ScalarKind::Uint)
3044 ) =>
3045 {
3046 write!(self.out, "{DIV_FUNCTION}(")?;
3047 self.write_expr(module, left, func_ctx)?;
3048 write!(self.out, ", ")?;
3049 self.write_expr(module, right, func_ctx)?;
3050 write!(self.out, ")")?;
3051 }
3052
3053 Expression::Binary {
3054 op: crate::BinaryOperator::Modulo,
3055 left,
3056 right,
3057 } if matches!(
3058 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3059 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3060 ) =>
3061 {
3062 write!(self.out, "{MOD_FUNCTION}(")?;
3063 self.write_expr(module, left, func_ctx)?;
3064 write!(self.out, ", ")?;
3065 self.write_expr(module, right, func_ctx)?;
3066 write!(self.out, ")")?;
3067 }
3068
3069 Expression::Binary { op, left, right } => {
3070 write!(self.out, "(")?;
3071 self.write_expr(module, left, func_ctx)?;
3072 write!(self.out, " {} ", back::binary_operation_str(op))?;
3073 self.write_expr(module, right, func_ctx)?;
3074 write!(self.out, ")")?;
3075 }
3076 Expression::Access { base, index } => {
3077 if let Some(crate::AddressSpace::Storage { .. }) =
3078 func_ctx.resolve_type(expr, &module.types).pointer_space()
3079 {
3080 } else {
3082 if let Some(MatrixType {
3089 columns,
3090 rows: crate::VectorSize::Bi,
3091 width: 4,
3092 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3093 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3094 {
3095 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3096 self.write_expr(module, base, func_ctx)?;
3097 write!(self.out, ", ")?;
3098 self.write_expr(module, index, func_ctx)?;
3099 write!(self.out, ")")?;
3100 return Ok(());
3101 }
3102
3103 let resolved = func_ctx.resolve_type(base, &module.types);
3104
3105 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3106 TypeInner::BindingArray { .. } => {
3107 let uniformity = &func_ctx.info[index].uniformity;
3108
3109 (true, uniformity.non_uniform_result.is_some())
3110 }
3111 _ => (false, false),
3112 };
3113
3114 self.write_expr(module, base, func_ctx)?;
3115
3116 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3117 module, func_ctx, base, resolved,
3118 );
3119
3120 if let Some(ref info) = array_sampler_info {
3121 write!(self.out, "{}[", info.sampler_heap_name)?;
3122 } else {
3123 write!(self.out, "[")?;
3124 }
3125
3126 let needs_bound_check = self.options.restrict_indexing
3127 && !indexing_binding_array
3128 && match resolved.pointer_space() {
3129 Some(
3130 crate::AddressSpace::Function
3131 | crate::AddressSpace::Private
3132 | crate::AddressSpace::WorkGroup
3133 | crate::AddressSpace::Immediate
3134 | crate::AddressSpace::TaskPayload,
3135 )
3136 | None => true,
3137 Some(crate::AddressSpace::Uniform) => {
3138 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3140 let bind_target = self
3141 .options
3142 .resolve_resource_binding(
3143 module.global_variables[var_handle]
3144 .binding
3145 .as_ref()
3146 .unwrap(),
3147 )
3148 .unwrap();
3149 bind_target.restrict_indexing
3150 }
3151 Some(
3152 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3153 ) => unreachable!(),
3154 };
3155 let restriction_needed = if needs_bound_check {
3157 index::access_needs_check(
3158 base,
3159 index::GuardedIndex::Expression(index),
3160 module,
3161 func_ctx.expressions,
3162 func_ctx.info,
3163 )
3164 } else {
3165 None
3166 };
3167 if let Some(limit) = restriction_needed {
3168 write!(self.out, "min(uint(")?;
3169 self.write_expr(module, index, func_ctx)?;
3170 write!(self.out, "), ")?;
3171 match limit {
3172 index::IndexableLength::Known(limit) => {
3173 write!(self.out, "{}u", limit - 1)?;
3174 }
3175 index::IndexableLength::Dynamic => unreachable!(),
3176 }
3177 write!(self.out, ")")?;
3178 } else {
3179 if non_uniform_qualifier {
3180 write!(self.out, "NonUniformResourceIndex(")?;
3181 }
3182 if let Some(ref info) = array_sampler_info {
3183 write!(
3184 self.out,
3185 "{}[{} + ",
3186 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3187 )?;
3188 }
3189 self.write_expr(module, index, func_ctx)?;
3190 if array_sampler_info.is_some() {
3191 write!(self.out, "]")?;
3192 }
3193 if non_uniform_qualifier {
3194 write!(self.out, ")")?;
3195 }
3196 }
3197
3198 write!(self.out, "]")?;
3199 }
3200 }
3201 Expression::AccessIndex { base, index } => {
3202 if let Some(crate::AddressSpace::Storage { .. }) =
3203 func_ctx.resolve_type(expr, &module.types).pointer_space()
3204 {
3205 } else {
3207 if let Some(MatrixType {
3211 rows: crate::VectorSize::Bi,
3212 width: 4,
3213 ..
3214 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3215 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3216 {
3217 self.write_expr(module, base, func_ctx)?;
3218 write!(self.out, "._{index}")?;
3219 return Ok(());
3220 }
3221
3222 let base_ty_res = &func_ctx.info[base].ty;
3223 let mut resolved = base_ty_res.inner_with(&module.types);
3224 let base_ty_handle = match *resolved {
3225 TypeInner::Pointer { base, .. } => {
3226 resolved = &module.types[base].inner;
3227 Some(base)
3228 }
3229 _ => base_ty_res.handle(),
3230 };
3231
3232 if let TypeInner::Struct { ref members, .. } = *resolved {
3238 let member = &members[index as usize];
3239
3240 match module.types[member.ty].inner {
3241 TypeInner::Matrix {
3242 rows: crate::VectorSize::Bi,
3243 ..
3244 } if member.binding.is_none() => {
3245 let ty = base_ty_handle.unwrap();
3246 self.write_wrapped_struct_matrix_get_function_name(
3247 WrappedStructMatrixAccess { ty, index },
3248 )?;
3249 write!(self.out, "(")?;
3250 self.write_expr(module, base, func_ctx)?;
3251 write!(self.out, ")")?;
3252 return Ok(());
3253 }
3254 _ => {}
3255 }
3256 }
3257
3258 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3259 module, func_ctx, base, resolved,
3260 );
3261
3262 if let Some(ref info) = array_sampler_info {
3263 write!(
3264 self.out,
3265 "{}[{}",
3266 info.sampler_heap_name, info.sampler_index_buffer_name
3267 )?;
3268 }
3269
3270 self.write_expr(module, base, func_ctx)?;
3271
3272 match *resolved {
3273 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3279 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3281 }
3282 TypeInner::Matrix { .. }
3283 | TypeInner::Array { .. }
3284 | TypeInner::BindingArray { .. } => {
3285 if let Some(ref info) = array_sampler_info {
3286 write!(
3287 self.out,
3288 "[{} + {index}]",
3289 info.binding_array_base_index_name
3290 )?;
3291 } else {
3292 write!(self.out, "[{index}]")?;
3293 }
3294 }
3295 TypeInner::Struct { .. } => {
3296 let ty = base_ty_handle.unwrap();
3299
3300 write!(
3301 self.out,
3302 ".{}",
3303 &self.names[&NameKey::StructMember(ty, index)]
3304 )?
3305 }
3306 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3307 }
3308
3309 if array_sampler_info.is_some() {
3310 write!(self.out, "]")?;
3311 }
3312 }
3313 }
3314 Expression::FunctionArgument(pos) => {
3315 let ty = func_ctx.resolve_type(expr, &module.types);
3316
3317 if let TypeInner::Image {
3323 class: crate::ImageClass::External,
3324 ..
3325 } = *ty
3326 {
3327 let plane_names = [0, 1, 2].map(|i| {
3328 &self.names[&func_ctx
3329 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3330 });
3331 let params_name = &self.names[&func_ctx
3332 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3333 write!(
3334 self.out,
3335 "{}, {}, {}, {}",
3336 plane_names[0], plane_names[1], plane_names[2], params_name
3337 )?;
3338 } else {
3339 let key = func_ctx.argument_key(pos);
3340 let name = &self.names[&key];
3341 write!(self.out, "{name}")?;
3342 }
3343 }
3344 Expression::ImageSample {
3345 coordinate,
3346 image,
3347 sampler,
3348 clamp_to_edge: true,
3349 gather: None,
3350 array_index: None,
3351 offset: None,
3352 level: crate::SampleLevel::Zero,
3353 depth_ref: None,
3354 } => {
3355 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3356 self.write_expr(module, image, func_ctx)?;
3357 write!(self.out, ", ")?;
3358 self.write_expr(module, sampler, func_ctx)?;
3359 write!(self.out, ", ")?;
3360 self.write_expr(module, coordinate, func_ctx)?;
3361 write!(self.out, ")")?;
3362 }
3363 Expression::ImageSample {
3364 image,
3365 sampler,
3366 gather,
3367 coordinate,
3368 array_index,
3369 offset,
3370 level,
3371 depth_ref,
3372 clamp_to_edge,
3373 } => {
3374 if clamp_to_edge {
3375 return Err(Error::Custom(
3376 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3377 ));
3378 }
3379
3380 use crate::SampleLevel as Sl;
3381 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3382
3383 let (base_str, component_str) = match gather {
3384 Some(component) => ("Gather", COMPONENTS[component as usize]),
3385 None => ("Sample", ""),
3386 };
3387 let cmp_str = match depth_ref {
3388 Some(_) => "Cmp",
3389 None => "",
3390 };
3391 let level_str = match level {
3392 Sl::Zero if gather.is_none() => "LevelZero",
3393 Sl::Auto | Sl::Zero => "",
3394 Sl::Exact(_) => "Level",
3395 Sl::Bias(_) => "Bias",
3396 Sl::Gradient { .. } => "Grad",
3397 };
3398
3399 self.write_expr(module, image, func_ctx)?;
3400 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3401 self.write_expr(module, sampler, func_ctx)?;
3402 write!(self.out, ", ")?;
3403 self.write_texture_coordinates(
3404 "float",
3405 coordinate,
3406 array_index,
3407 None,
3408 module,
3409 func_ctx,
3410 )?;
3411
3412 if let Some(depth_ref) = depth_ref {
3413 write!(self.out, ", ")?;
3414 self.write_expr(module, depth_ref, func_ctx)?;
3415 }
3416
3417 match level {
3418 Sl::Auto | Sl::Zero => {}
3419 Sl::Exact(expr) => {
3420 write!(self.out, ", ")?;
3421 self.write_expr(module, expr, func_ctx)?;
3422 }
3423 Sl::Bias(expr) => {
3424 write!(self.out, ", ")?;
3425 self.write_expr(module, expr, func_ctx)?;
3426 }
3427 Sl::Gradient { x, y } => {
3428 write!(self.out, ", ")?;
3429 self.write_expr(module, x, func_ctx)?;
3430 write!(self.out, ", ")?;
3431 self.write_expr(module, y, func_ctx)?;
3432 }
3433 }
3434
3435 if let Some(offset) = offset {
3436 write!(self.out, ", ")?;
3437 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3439 write!(self.out, ")")?;
3440 }
3441
3442 write!(self.out, ")")?;
3443 }
3444 Expression::ImageQuery { image, query } => {
3445 if let TypeInner::Image {
3447 dim,
3448 arrayed,
3449 class,
3450 } = *func_ctx.resolve_type(image, &module.types)
3451 {
3452 let wrapped_image_query = WrappedImageQuery {
3453 dim,
3454 arrayed,
3455 class,
3456 query: query.into(),
3457 };
3458
3459 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3460 write!(self.out, "(")?;
3461 self.write_expr(module, image, func_ctx)?;
3463 if let crate::ImageQuery::Size { level: Some(level) } = query {
3464 write!(self.out, ", ")?;
3465 self.write_expr(module, level, func_ctx)?;
3466 }
3467 write!(self.out, ")")?;
3468 }
3469 }
3470 Expression::ImageLoad {
3471 image,
3472 coordinate,
3473 array_index,
3474 sample,
3475 level,
3476 } => self.write_image_load(
3477 &module,
3478 expr,
3479 func_ctx,
3480 image,
3481 coordinate,
3482 array_index,
3483 sample,
3484 level,
3485 )?,
3486 Expression::GlobalVariable(handle) => {
3487 let global_variable = &module.global_variables[handle];
3488 let ty = &module.types[global_variable.ty].inner;
3489
3490 let is_binding_array_of_samplers = match *ty {
3495 TypeInner::BindingArray { base, .. } => {
3496 let base_ty = &module.types[base].inner;
3497 matches!(*base_ty, TypeInner::Sampler { .. })
3498 }
3499 _ => false,
3500 };
3501
3502 let is_storage_space =
3503 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3504
3505 if let TypeInner::Image {
3513 class: crate::ImageClass::External,
3514 ..
3515 } = *ty
3516 {
3517 let plane_names = [0, 1, 2].map(|i| {
3518 &self.names[&NameKey::ExternalTextureGlobalVariable(
3519 handle,
3520 ExternalTextureNameKey::Plane(i),
3521 )]
3522 });
3523 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3524 handle,
3525 ExternalTextureNameKey::Params,
3526 )];
3527 write!(
3528 self.out,
3529 "{}, {}, {}, {}",
3530 plane_names[0], plane_names[1], plane_names[2], params_name
3531 )?;
3532 } else if !is_binding_array_of_samplers && !is_storage_space {
3533 let name = &self.names[&NameKey::GlobalVariable(handle)];
3534 write!(self.out, "{name}")?;
3535 }
3536 }
3537 Expression::LocalVariable(handle) => {
3538 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3539 }
3540 Expression::Load { pointer } => {
3541 match func_ctx
3542 .resolve_type(pointer, &module.types)
3543 .pointer_space()
3544 {
3545 Some(crate::AddressSpace::Storage { .. }) => {
3546 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3547 let result_ty = func_ctx.info[expr].ty.clone();
3548 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3549 }
3550 _ => {
3551 let mut close_paren = false;
3552
3553 if let Some(MatrixType {
3558 rows: crate::VectorSize::Bi,
3559 width: 4,
3560 ..
3561 }) = get_inner_matrix_of_struct_array_member(
3562 module, pointer, func_ctx, false,
3563 )
3564 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3565 {
3566 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3567 let ptr_tr = resolved.pointer_base_type();
3568 if let Some(ptr_ty) =
3569 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3570 {
3571 resolved = ptr_ty;
3572 }
3573
3574 write!(self.out, "((")?;
3575 if let TypeInner::Array { base, size, .. } = *resolved {
3576 self.write_type(module, base)?;
3577 self.write_array_size(module, base, size)?;
3578 } else {
3579 self.write_value_type(module, resolved)?;
3580 }
3581 write!(self.out, ")")?;
3582 close_paren = true;
3583 }
3584
3585 self.write_expr(module, pointer, func_ctx)?;
3586
3587 if close_paren {
3588 write!(self.out, ")")?;
3589 }
3590 }
3591 }
3592 }
3593 Expression::Unary { op, expr } => {
3594 let op_str = match op {
3596 crate::UnaryOperator::Negate => {
3597 match func_ctx.resolve_type(expr, &module.types).scalar() {
3598 Some(Scalar::I32) => NEG_FUNCTION,
3599 _ => "-",
3600 }
3601 }
3602 crate::UnaryOperator::LogicalNot => "!",
3603 crate::UnaryOperator::BitwiseNot => "~",
3604 };
3605 write!(self.out, "{op_str}(")?;
3606 self.write_expr(module, expr, func_ctx)?;
3607 write!(self.out, ")")?;
3608 }
3609 Expression::As {
3610 expr,
3611 kind,
3612 convert,
3613 } => {
3614 let inner = func_ctx.resolve_type(expr, &module.types);
3615 if inner.scalar_kind() == Some(ScalarKind::Float)
3616 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3617 && convert.is_some()
3618 {
3619 let fun_name = match (kind, convert) {
3623 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3624 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3625 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3626 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3627 _ => unreachable!(),
3628 };
3629 write!(self.out, "{fun_name}(")?;
3630 self.write_expr(module, expr, func_ctx)?;
3631 write!(self.out, ")")?;
3632 } else {
3633 let close_paren = match convert {
3634 Some(dst_width) => {
3635 let scalar = Scalar {
3636 kind,
3637 width: dst_width,
3638 };
3639 match *inner {
3640 TypeInner::Vector { size, .. } => {
3641 write!(
3642 self.out,
3643 "{}{}(",
3644 scalar.to_hlsl_str()?,
3645 common::vector_size_str(size)
3646 )?;
3647 }
3648 TypeInner::Scalar(_) => {
3649 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3650 }
3651 TypeInner::Matrix { columns, rows, .. } => {
3652 write!(
3653 self.out,
3654 "{}{}x{}(",
3655 scalar.to_hlsl_str()?,
3656 common::vector_size_str(columns),
3657 common::vector_size_str(rows)
3658 )?;
3659 }
3660 _ => {
3661 return Err(Error::Unimplemented(format!(
3662 "write_expr expression::as {inner:?}"
3663 )));
3664 }
3665 };
3666 true
3667 }
3668 None => {
3669 if inner.scalar_width() == Some(8) {
3670 false
3671 } else {
3672 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3673 true
3674 }
3675 }
3676 };
3677 self.write_expr(module, expr, func_ctx)?;
3678 if close_paren {
3679 write!(self.out, ")")?;
3680 }
3681 }
3682 }
3683 Expression::Math {
3684 fun,
3685 arg,
3686 arg1,
3687 arg2,
3688 arg3,
3689 } => {
3690 use crate::MathFunction as Mf;
3691
3692 enum Function {
3693 Asincosh { is_sin: bool },
3694 Atanh,
3695 Pack2x16float,
3696 Pack2x16snorm,
3697 Pack2x16unorm,
3698 Pack4x8snorm,
3699 Pack4x8unorm,
3700 Pack4xI8,
3701 Pack4xU8,
3702 Pack4xI8Clamp,
3703 Pack4xU8Clamp,
3704 Unpack2x16float,
3705 Unpack2x16snorm,
3706 Unpack2x16unorm,
3707 Unpack4x8snorm,
3708 Unpack4x8unorm,
3709 Unpack4xI8,
3710 Unpack4xU8,
3711 Dot4I8Packed,
3712 Dot4U8Packed,
3713 QuantizeToF16,
3714 Regular(&'static str),
3715 MissingIntOverload(&'static str),
3716 MissingIntReturnType(&'static str),
3717 CountTrailingZeros,
3718 CountLeadingZeros,
3719 }
3720
3721 let fun = match fun {
3722 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3724 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3725 _ => Function::Regular("abs"),
3726 },
3727 Mf::Min => Function::Regular("min"),
3728 Mf::Max => Function::Regular("max"),
3729 Mf::Clamp => Function::Regular("clamp"),
3730 Mf::Saturate => Function::Regular("saturate"),
3731 Mf::Cos => Function::Regular("cos"),
3733 Mf::Cosh => Function::Regular("cosh"),
3734 Mf::Sin => Function::Regular("sin"),
3735 Mf::Sinh => Function::Regular("sinh"),
3736 Mf::Tan => Function::Regular("tan"),
3737 Mf::Tanh => Function::Regular("tanh"),
3738 Mf::Acos => Function::Regular("acos"),
3739 Mf::Asin => Function::Regular("asin"),
3740 Mf::Atan => Function::Regular("atan"),
3741 Mf::Atan2 => Function::Regular("atan2"),
3742 Mf::Asinh => Function::Asincosh { is_sin: true },
3743 Mf::Acosh => Function::Asincosh { is_sin: false },
3744 Mf::Atanh => Function::Atanh,
3745 Mf::Radians => Function::Regular("radians"),
3746 Mf::Degrees => Function::Regular("degrees"),
3747 Mf::Ceil => Function::Regular("ceil"),
3749 Mf::Floor => Function::Regular("floor"),
3750 Mf::Round => Function::Regular("round"),
3751 Mf::Fract => Function::Regular("frac"),
3752 Mf::Trunc => Function::Regular("trunc"),
3753 Mf::Modf => Function::Regular(MODF_FUNCTION),
3754 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3755 Mf::Ldexp => Function::Regular("ldexp"),
3756 Mf::Exp => Function::Regular("exp"),
3758 Mf::Exp2 => Function::Regular("exp2"),
3759 Mf::Log => Function::Regular("log"),
3760 Mf::Log2 => Function::Regular("log2"),
3761 Mf::Pow => Function::Regular("pow"),
3762 Mf::Dot => Function::Regular("dot"),
3764 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3765 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3766 Mf::Cross => Function::Regular("cross"),
3768 Mf::Distance => Function::Regular("distance"),
3769 Mf::Length => Function::Regular("length"),
3770 Mf::Normalize => Function::Regular("normalize"),
3771 Mf::FaceForward => Function::Regular("faceforward"),
3772 Mf::Reflect => Function::Regular("reflect"),
3773 Mf::Refract => Function::Regular("refract"),
3774 Mf::Sign => Function::Regular("sign"),
3776 Mf::Fma => Function::Regular("mad"),
3777 Mf::Mix => Function::Regular("lerp"),
3778 Mf::Step => Function::Regular("step"),
3779 Mf::SmoothStep => Function::Regular("smoothstep"),
3780 Mf::Sqrt => Function::Regular("sqrt"),
3781 Mf::InverseSqrt => Function::Regular("rsqrt"),
3782 Mf::Transpose => Function::Regular("transpose"),
3784 Mf::Determinant => Function::Regular("determinant"),
3785 Mf::QuantizeToF16 => Function::QuantizeToF16,
3786 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3788 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3789 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3790 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3791 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3792 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3793 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3794 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3795 Mf::Pack2x16float => Function::Pack2x16float,
3797 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3798 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3799 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3800 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3801 Mf::Pack4xI8 => Function::Pack4xI8,
3802 Mf::Pack4xU8 => Function::Pack4xU8,
3803 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3804 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3805 Mf::Unpack2x16float => Function::Unpack2x16float,
3807 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3808 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3809 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3810 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3811 Mf::Unpack4xI8 => Function::Unpack4xI8,
3812 Mf::Unpack4xU8 => Function::Unpack4xU8,
3813 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3814 };
3815
3816 match fun {
3817 Function::Asincosh { is_sin } => {
3818 write!(self.out, "log(")?;
3819 self.write_expr(module, arg, func_ctx)?;
3820 write!(self.out, " + sqrt(")?;
3821 self.write_expr(module, arg, func_ctx)?;
3822 write!(self.out, " * ")?;
3823 self.write_expr(module, arg, func_ctx)?;
3824 match is_sin {
3825 true => write!(self.out, " + 1.0))")?,
3826 false => write!(self.out, " - 1.0))")?,
3827 }
3828 }
3829 Function::Atanh => {
3830 write!(self.out, "0.5 * log((1.0 + ")?;
3831 self.write_expr(module, arg, func_ctx)?;
3832 write!(self.out, ") / (1.0 - ")?;
3833 self.write_expr(module, arg, func_ctx)?;
3834 write!(self.out, "))")?;
3835 }
3836 Function::Pack2x16float => {
3837 write!(self.out, "(f32tof16(")?;
3838 self.write_expr(module, arg, func_ctx)?;
3839 write!(self.out, "[0]) | f32tof16(")?;
3840 self.write_expr(module, arg, func_ctx)?;
3841 write!(self.out, "[1]) << 16)")?;
3842 }
3843 Function::Pack2x16snorm => {
3844 let scale = 32767;
3845
3846 write!(self.out, "uint((int(round(clamp(")?;
3847 self.write_expr(module, arg, func_ctx)?;
3848 write!(
3849 self.out,
3850 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3851 )?;
3852 self.write_expr(module, arg, func_ctx)?;
3853 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3854 }
3855 Function::Pack2x16unorm => {
3856 let scale = 65535;
3857
3858 write!(self.out, "(uint(round(clamp(")?;
3859 self.write_expr(module, arg, func_ctx)?;
3860 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3861 self.write_expr(module, arg, func_ctx)?;
3862 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3863 }
3864 Function::Pack4x8snorm => {
3865 let scale = 127;
3866
3867 write!(self.out, "uint((int(round(clamp(")?;
3868 self.write_expr(module, arg, func_ctx)?;
3869 write!(
3870 self.out,
3871 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3872 )?;
3873 self.write_expr(module, arg, func_ctx)?;
3874 write!(
3875 self.out,
3876 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3877 )?;
3878 self.write_expr(module, arg, func_ctx)?;
3879 write!(
3880 self.out,
3881 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3882 )?;
3883 self.write_expr(module, arg, func_ctx)?;
3884 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3885 }
3886 Function::Pack4x8unorm => {
3887 let scale = 255;
3888
3889 write!(self.out, "(uint(round(clamp(")?;
3890 self.write_expr(module, arg, func_ctx)?;
3891 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3892 self.write_expr(module, arg, func_ctx)?;
3893 write!(
3894 self.out,
3895 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3896 )?;
3897 self.write_expr(module, arg, func_ctx)?;
3898 write!(
3899 self.out,
3900 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3901 )?;
3902 self.write_expr(module, arg, func_ctx)?;
3903 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3904 }
3905 fun @ (Function::Pack4xI8
3906 | Function::Pack4xU8
3907 | Function::Pack4xI8Clamp
3908 | Function::Pack4xU8Clamp) => {
3909 let was_signed =
3910 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3911 let clamp_bounds = match fun {
3912 Function::Pack4xI8Clamp => Some(("-128", "127")),
3913 Function::Pack4xU8Clamp => Some(("0", "255")),
3914 _ => None,
3915 };
3916 if was_signed {
3917 write!(self.out, "uint(")?;
3918 }
3919 let write_arg = |this: &mut Self| -> BackendResult {
3920 if let Some((min, max)) = clamp_bounds {
3921 write!(this.out, "clamp(")?;
3922 this.write_expr(module, arg, func_ctx)?;
3923 write!(this.out, ", {min}, {max})")?;
3924 } else {
3925 this.write_expr(module, arg, func_ctx)?;
3926 }
3927 Ok(())
3928 };
3929 write!(self.out, "(")?;
3930 write_arg(self)?;
3931 write!(self.out, "[0] & 0xFF) | ((")?;
3932 write_arg(self)?;
3933 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3934 write_arg(self)?;
3935 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3936 write_arg(self)?;
3937 write!(self.out, "[3] & 0xFF) << 24)")?;
3938 if was_signed {
3939 write!(self.out, ")")?;
3940 }
3941 }
3942
3943 Function::Unpack2x16float => {
3944 write!(self.out, "float2(f16tof32(")?;
3945 self.write_expr(module, arg, func_ctx)?;
3946 write!(self.out, "), f16tof32((")?;
3947 self.write_expr(module, arg, func_ctx)?;
3948 write!(self.out, ") >> 16))")?;
3949 }
3950 Function::Unpack2x16snorm => {
3951 let scale = 32767;
3952
3953 write!(self.out, "(float2(int2(")?;
3954 self.write_expr(module, arg, func_ctx)?;
3955 write!(self.out, " << 16, ")?;
3956 self.write_expr(module, arg, func_ctx)?;
3957 write!(self.out, ") >> 16) / {scale}.0)")?;
3958 }
3959 Function::Unpack2x16unorm => {
3960 let scale = 65535;
3961
3962 write!(self.out, "(float2(")?;
3963 self.write_expr(module, arg, func_ctx)?;
3964 write!(self.out, " & 0xFFFF, ")?;
3965 self.write_expr(module, arg, func_ctx)?;
3966 write!(self.out, " >> 16) / {scale}.0)")?;
3967 }
3968 Function::Unpack4x8snorm => {
3969 let scale = 127;
3970
3971 write!(self.out, "(float4(int4(")?;
3972 self.write_expr(module, arg, func_ctx)?;
3973 write!(self.out, " << 24, ")?;
3974 self.write_expr(module, arg, func_ctx)?;
3975 write!(self.out, " << 16, ")?;
3976 self.write_expr(module, arg, func_ctx)?;
3977 write!(self.out, " << 8, ")?;
3978 self.write_expr(module, arg, func_ctx)?;
3979 write!(self.out, ") >> 24) / {scale}.0)")?;
3980 }
3981 Function::Unpack4x8unorm => {
3982 let scale = 255;
3983
3984 write!(self.out, "(float4(")?;
3985 self.write_expr(module, arg, func_ctx)?;
3986 write!(self.out, " & 0xFF, ")?;
3987 self.write_expr(module, arg, func_ctx)?;
3988 write!(self.out, " >> 8 & 0xFF, ")?;
3989 self.write_expr(module, arg, func_ctx)?;
3990 write!(self.out, " >> 16 & 0xFF, ")?;
3991 self.write_expr(module, arg, func_ctx)?;
3992 write!(self.out, " >> 24) / {scale}.0)")?;
3993 }
3994 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3995 write!(self.out, "(")?;
3996 if matches!(fun, Function::Unpack4xU8) {
3997 write!(self.out, "u")?;
3998 }
3999 write!(self.out, "int4(")?;
4000 self.write_expr(module, arg, func_ctx)?;
4001 write!(self.out, ", ")?;
4002 self.write_expr(module, arg, func_ctx)?;
4003 write!(self.out, " >> 8, ")?;
4004 self.write_expr(module, arg, func_ctx)?;
4005 write!(self.out, " >> 16, ")?;
4006 self.write_expr(module, arg, func_ctx)?;
4007 write!(self.out, " >> 24) << 24 >> 24)")?;
4008 }
4009 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4010 let arg1 = arg1.unwrap();
4011
4012 if self.options.shader_model >= ShaderModel::V6_4 {
4013 let function_name = match fun {
4015 Function::Dot4I8Packed => "dot4add_i8packed",
4016 Function::Dot4U8Packed => "dot4add_u8packed",
4017 _ => unreachable!(),
4018 };
4019 write!(self.out, "{function_name}(")?;
4020 self.write_expr(module, arg, func_ctx)?;
4021 write!(self.out, ", ")?;
4022 self.write_expr(module, arg1, func_ctx)?;
4023 write!(self.out, ", 0)")?;
4024 } else {
4025 write!(self.out, "dot(")?;
4027
4028 if matches!(fun, Function::Dot4U8Packed) {
4029 write!(self.out, "u")?;
4030 }
4031 write!(self.out, "int4(")?;
4032 self.write_expr(module, arg, func_ctx)?;
4033 write!(self.out, ", ")?;
4034 self.write_expr(module, arg, func_ctx)?;
4035 write!(self.out, " >> 8, ")?;
4036 self.write_expr(module, arg, func_ctx)?;
4037 write!(self.out, " >> 16, ")?;
4038 self.write_expr(module, arg, func_ctx)?;
4039 write!(self.out, " >> 24) << 24 >> 24, ")?;
4040
4041 if matches!(fun, Function::Dot4U8Packed) {
4042 write!(self.out, "u")?;
4043 }
4044 write!(self.out, "int4(")?;
4045 self.write_expr(module, arg1, func_ctx)?;
4046 write!(self.out, ", ")?;
4047 self.write_expr(module, arg1, func_ctx)?;
4048 write!(self.out, " >> 8, ")?;
4049 self.write_expr(module, arg1, func_ctx)?;
4050 write!(self.out, " >> 16, ")?;
4051 self.write_expr(module, arg1, func_ctx)?;
4052 write!(self.out, " >> 24) << 24 >> 24)")?;
4053 }
4054 }
4055 Function::QuantizeToF16 => {
4056 write!(self.out, "f16tof32(f32tof16(")?;
4057 self.write_expr(module, arg, func_ctx)?;
4058 write!(self.out, "))")?;
4059 }
4060 Function::Regular(fun_name) => {
4061 write!(self.out, "{fun_name}(")?;
4062 self.write_expr(module, arg, func_ctx)?;
4063 if let Some(arg) = arg1 {
4064 write!(self.out, ", ")?;
4065 self.write_expr(module, arg, func_ctx)?;
4066 }
4067 if let Some(arg) = arg2 {
4068 write!(self.out, ", ")?;
4069 self.write_expr(module, arg, func_ctx)?;
4070 }
4071 if let Some(arg) = arg3 {
4072 write!(self.out, ", ")?;
4073 self.write_expr(module, arg, func_ctx)?;
4074 }
4075 write!(self.out, ")")?
4076 }
4077 Function::MissingIntOverload(fun_name) => {
4080 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4081 if let Some(Scalar::I32) = scalar_kind {
4082 write!(self.out, "asint({fun_name}(asuint(")?;
4083 self.write_expr(module, arg, func_ctx)?;
4084 write!(self.out, ")))")?;
4085 } else {
4086 write!(self.out, "{fun_name}(")?;
4087 self.write_expr(module, arg, func_ctx)?;
4088 write!(self.out, ")")?;
4089 }
4090 }
4091 Function::MissingIntReturnType(fun_name) => {
4094 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4095 if let Some(Scalar::I32) = scalar_kind {
4096 write!(self.out, "asint({fun_name}(")?;
4097 self.write_expr(module, arg, func_ctx)?;
4098 write!(self.out, "))")?;
4099 } else {
4100 write!(self.out, "{fun_name}(")?;
4101 self.write_expr(module, arg, func_ctx)?;
4102 write!(self.out, ")")?;
4103 }
4104 }
4105 Function::CountTrailingZeros => {
4106 match *func_ctx.resolve_type(arg, &module.types) {
4107 TypeInner::Vector { size, scalar } => {
4108 let s = match size {
4109 crate::VectorSize::Bi => ".xx",
4110 crate::VectorSize::Tri => ".xxx",
4111 crate::VectorSize::Quad => ".xxxx",
4112 };
4113
4114 let scalar_width_bits = scalar.width * 8;
4115
4116 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4117 write!(
4118 self.out,
4119 "min(({scalar_width_bits}u){s}, firstbitlow("
4120 )?;
4121 self.write_expr(module, arg, func_ctx)?;
4122 write!(self.out, "))")?;
4123 } else {
4124 write!(
4126 self.out,
4127 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4128 )?;
4129 self.write_expr(module, arg, func_ctx)?;
4130 write!(self.out, ")))")?;
4131 }
4132 }
4133 TypeInner::Scalar(scalar) => {
4134 let scalar_width_bits = scalar.width * 8;
4135
4136 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4137 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4138 self.write_expr(module, arg, func_ctx)?;
4139 write!(self.out, "))")?;
4140 } else {
4141 write!(
4143 self.out,
4144 "asint(min({scalar_width_bits}u, firstbitlow("
4145 )?;
4146 self.write_expr(module, arg, func_ctx)?;
4147 write!(self.out, ")))")?;
4148 }
4149 }
4150 _ => unreachable!(),
4151 }
4152
4153 return Ok(());
4154 }
4155 Function::CountLeadingZeros => {
4156 match *func_ctx.resolve_type(arg, &module.types) {
4157 TypeInner::Vector { size, scalar } => {
4158 let s = match size {
4159 crate::VectorSize::Bi => ".xx",
4160 crate::VectorSize::Tri => ".xxx",
4161 crate::VectorSize::Quad => ".xxxx",
4162 };
4163
4164 let constant = scalar.width * 8 - 1;
4166
4167 if scalar.kind == ScalarKind::Uint {
4168 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4169 self.write_expr(module, arg, func_ctx)?;
4170 write!(self.out, "))")?;
4171 } else {
4172 let conversion_func = match scalar.width {
4173 4 => "asint",
4174 _ => "",
4175 };
4176 write!(self.out, "(")?;
4177 self.write_expr(module, arg, func_ctx)?;
4178 write!(
4179 self.out,
4180 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4181 )?;
4182 self.write_expr(module, arg, func_ctx)?;
4183 write!(self.out, ")))")?;
4184 }
4185 }
4186 TypeInner::Scalar(scalar) => {
4187 let constant = scalar.width * 8 - 1;
4189
4190 if let ScalarKind::Uint = scalar.kind {
4191 write!(self.out, "({constant}u - firstbithigh(")?;
4192 self.write_expr(module, arg, func_ctx)?;
4193 write!(self.out, "))")?;
4194 } else {
4195 let conversion_func = match scalar.width {
4196 4 => "asint",
4197 _ => "",
4198 };
4199 write!(self.out, "(")?;
4200 self.write_expr(module, arg, func_ctx)?;
4201 write!(
4202 self.out,
4203 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4204 )?;
4205 self.write_expr(module, arg, func_ctx)?;
4206 write!(self.out, ")))")?;
4207 }
4208 }
4209 _ => unreachable!(),
4210 }
4211
4212 return Ok(());
4213 }
4214 }
4215 }
4216 Expression::Swizzle {
4217 size,
4218 vector,
4219 pattern,
4220 } => {
4221 self.write_expr(module, vector, func_ctx)?;
4222 write!(self.out, ".")?;
4223 for &sc in pattern[..size as usize].iter() {
4224 self.out.write_char(back::COMPONENTS[sc as usize])?;
4225 }
4226 }
4227 Expression::ArrayLength(expr) => {
4228 let var_handle = match func_ctx.expressions[expr] {
4229 Expression::AccessIndex { base, index: _ } => {
4230 match func_ctx.expressions[base] {
4231 Expression::GlobalVariable(handle) => handle,
4232 _ => unreachable!(),
4233 }
4234 }
4235 Expression::GlobalVariable(handle) => handle,
4236 _ => unreachable!(),
4237 };
4238
4239 let var = &module.global_variables[var_handle];
4240 let (offset, stride) = match module.types[var.ty].inner {
4241 TypeInner::Array { stride, .. } => (0, stride),
4242 TypeInner::Struct { ref members, .. } => {
4243 let last = members.last().unwrap();
4244 let stride = match module.types[last.ty].inner {
4245 TypeInner::Array { stride, .. } => stride,
4246 _ => unreachable!(),
4247 };
4248 (last.offset, stride)
4249 }
4250 _ => unreachable!(),
4251 };
4252
4253 let storage_access = match var.space {
4254 crate::AddressSpace::Storage { access } => access,
4255 _ => crate::StorageAccess::default(),
4256 };
4257 let wrapped_array_length = WrappedArrayLength {
4258 writable: storage_access.contains(crate::StorageAccess::STORE),
4259 };
4260
4261 write!(self.out, "((")?;
4262 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4263 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4264 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4265 }
4266 Expression::Derivative { axis, ctrl, expr } => {
4267 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4268 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4269 let tail = match ctrl {
4270 Ctrl::Coarse => "coarse",
4271 Ctrl::Fine => "fine",
4272 Ctrl::None => unreachable!(),
4273 };
4274 write!(self.out, "abs(ddx_{tail}(")?;
4275 self.write_expr(module, expr, func_ctx)?;
4276 write!(self.out, ")) + abs(ddy_{tail}(")?;
4277 self.write_expr(module, expr, func_ctx)?;
4278 write!(self.out, "))")?
4279 } else {
4280 let fun_str = match (axis, ctrl) {
4281 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4282 (Axis::X, Ctrl::Fine) => "ddx_fine",
4283 (Axis::X, Ctrl::None) => "ddx",
4284 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4285 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4286 (Axis::Y, Ctrl::None) => "ddy",
4287 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4288 (Axis::Width, Ctrl::None) => "fwidth",
4289 };
4290 write!(self.out, "{fun_str}(")?;
4291 self.write_expr(module, expr, func_ctx)?;
4292 write!(self.out, ")")?
4293 }
4294 }
4295 Expression::Relational { fun, argument } => {
4296 use crate::RelationalFunction as Rf;
4297
4298 let fun_str = match fun {
4299 Rf::All => "all",
4300 Rf::Any => "any",
4301 Rf::IsNan => "isnan",
4302 Rf::IsInf => "isinf",
4303 };
4304 write!(self.out, "{fun_str}(")?;
4305 self.write_expr(module, argument, func_ctx)?;
4306 write!(self.out, ")")?
4307 }
4308 Expression::Select {
4309 condition,
4310 accept,
4311 reject,
4312 } => {
4313 write!(self.out, "(")?;
4314 self.write_expr(module, condition, func_ctx)?;
4315 write!(self.out, " ? ")?;
4316 self.write_expr(module, accept, func_ctx)?;
4317 write!(self.out, " : ")?;
4318 self.write_expr(module, reject, func_ctx)?;
4319 write!(self.out, ")")?
4320 }
4321 Expression::RayQueryGetIntersection { query, committed } => {
4322 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4324 unreachable!()
4325 };
4326
4327 let tracker_expr_name = format!(
4328 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4329 self.names[&func_ctx.name_key(query_var)]
4330 );
4331
4332 if committed {
4333 write!(self.out, "GetCommittedIntersection(")?;
4334 self.write_expr(module, query, func_ctx)?;
4335 write!(self.out, ", {tracker_expr_name})")?;
4336 } else {
4337 write!(self.out, "GetCandidateIntersection(")?;
4338 self.write_expr(module, query, func_ctx)?;
4339 write!(self.out, ", {tracker_expr_name})")?;
4340 }
4341 }
4342 Expression::RayQueryVertexPositions { .. }
4344 | Expression::CooperativeLoad { .. }
4345 | Expression::CooperativeMultiplyAdd { .. } => {
4346 unreachable!()
4347 }
4348 Expression::CallResult(_)
4350 | Expression::AtomicResult { .. }
4351 | Expression::WorkGroupUniformLoadResult { .. }
4352 | Expression::RayQueryProceedResult
4353 | Expression::SubgroupBallotResult
4354 | Expression::SubgroupOperationResult { .. } => {}
4355 }
4356
4357 if !closing_bracket.is_empty() {
4358 write!(self.out, "{closing_bracket}")?;
4359 }
4360 Ok(())
4361 }
4362
4363 #[allow(clippy::too_many_arguments)]
4364 fn write_image_load(
4365 &mut self,
4366 module: &&Module,
4367 expr: Handle<crate::Expression>,
4368 func_ctx: &back::FunctionCtx,
4369 image: Handle<crate::Expression>,
4370 coordinate: Handle<crate::Expression>,
4371 array_index: Option<Handle<crate::Expression>>,
4372 sample: Option<Handle<crate::Expression>>,
4373 level: Option<Handle<crate::Expression>>,
4374 ) -> Result<(), Error> {
4375 let mut wrapping_type = None;
4376 match *func_ctx.resolve_type(image, &module.types) {
4377 TypeInner::Image {
4378 class: crate::ImageClass::External,
4379 ..
4380 } => {
4381 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4382 self.write_expr(module, image, func_ctx)?;
4383 write!(self.out, ", ")?;
4384 self.write_expr(module, coordinate, func_ctx)?;
4385 write!(self.out, ")")?;
4386 return Ok(());
4387 }
4388 TypeInner::Image {
4389 class: crate::ImageClass::Storage { format, .. },
4390 ..
4391 } => {
4392 if format.single_component() {
4393 wrapping_type = Some(Scalar::from(format));
4394 }
4395 }
4396 _ => {}
4397 }
4398 if let Some(scalar) = wrapping_type {
4399 write!(
4400 self.out,
4401 "{}{}(",
4402 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4403 scalar.to_hlsl_str()?
4404 )?;
4405 }
4406 self.write_expr(module, image, func_ctx)?;
4408 write!(self.out, ".Load(")?;
4409
4410 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4411
4412 if let Some(sample) = sample {
4413 write!(self.out, ", ")?;
4414 self.write_expr(module, sample, func_ctx)?;
4415 }
4416
4417 write!(self.out, ")")?;
4419
4420 if wrapping_type.is_some() {
4421 write!(self.out, ")")?;
4422 }
4423
4424 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4426 write!(self.out, ".x")?;
4427 }
4428 Ok(())
4429 }
4430
4431 fn sampler_binding_array_info_from_expression(
4434 &mut self,
4435 module: &Module,
4436 func_ctx: &back::FunctionCtx<'_>,
4437 base: Handle<crate::Expression>,
4438 resolved: &TypeInner,
4439 ) -> Option<BindingArraySamplerInfo> {
4440 if let TypeInner::BindingArray {
4441 base: base_ty_handle,
4442 ..
4443 } = *resolved
4444 {
4445 let base_ty = &module.types[base_ty_handle].inner;
4446 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4447 let base = &func_ctx.expressions[base];
4448
4449 if let crate::Expression::GlobalVariable(handle) = *base {
4450 let variable = &module.global_variables[handle];
4451
4452 let sampler_heap_name = match comparison {
4453 true => COMPARISON_SAMPLER_HEAP_VAR,
4454 false => SAMPLER_HEAP_VAR,
4455 };
4456
4457 return Some(BindingArraySamplerInfo {
4458 sampler_heap_name,
4459 sampler_index_buffer_name: self
4460 .wrapped
4461 .sampler_index_buffers
4462 .get(&super::SamplerIndexBufferKey {
4463 group: variable.binding.unwrap().group,
4464 })
4465 .unwrap()
4466 .clone(),
4467 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4468 .clone(),
4469 });
4470 }
4471 }
4472 }
4473
4474 None
4475 }
4476
4477 fn write_named_expr(
4478 &mut self,
4479 module: &Module,
4480 handle: Handle<crate::Expression>,
4481 name: String,
4482 named: Handle<crate::Expression>,
4485 ctx: &back::FunctionCtx,
4486 ) -> BackendResult {
4487 match ctx.info[named].ty {
4488 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4489 TypeInner::Struct { .. } => {
4490 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4491 write!(self.out, "{ty_name}")?;
4492 }
4493 _ => {
4494 self.write_type(module, ty_handle)?;
4495 }
4496 },
4497 proc::TypeResolution::Value(ref inner) => {
4498 self.write_value_type(module, inner)?;
4499 }
4500 }
4501
4502 let resolved = ctx.resolve_type(named, &module.types);
4503
4504 write!(self.out, " {name}")?;
4505 if let TypeInner::Array { base, size, .. } = *resolved {
4507 self.write_array_size(module, base, size)?;
4508 }
4509 write!(self.out, " = ")?;
4510 self.write_expr(module, handle, ctx)?;
4511 writeln!(self.out, ";")?;
4512 self.named_expressions.insert(named, name);
4513
4514 Ok(())
4515 }
4516
4517 pub(super) fn write_default_init(
4519 &mut self,
4520 module: &Module,
4521 ty: Handle<crate::Type>,
4522 ) -> BackendResult {
4523 write!(self.out, "(")?;
4524 self.write_type(module, ty)?;
4525 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4526 self.write_array_size(module, base, size)?;
4527 }
4528 write!(self.out, ")0")?;
4529 Ok(())
4530 }
4531
4532 fn write_control_barrier(
4533 &mut self,
4534 barrier: crate::Barrier,
4535 level: back::Level,
4536 ) -> BackendResult {
4537 if barrier.contains(crate::Barrier::STORAGE) {
4538 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4539 }
4540 if barrier.contains(crate::Barrier::WORK_GROUP) {
4541 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4542 }
4543 if barrier.contains(crate::Barrier::SUB_GROUP) {
4544 }
4546 if barrier.contains(crate::Barrier::TEXTURE) {
4547 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4548 }
4549 Ok(())
4550 }
4551
4552 fn write_memory_barrier(
4553 &mut self,
4554 barrier: crate::Barrier,
4555 level: back::Level,
4556 ) -> BackendResult {
4557 if barrier.contains(crate::Barrier::STORAGE) {
4558 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4559 }
4560 if barrier.contains(crate::Barrier::WORK_GROUP) {
4561 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4562 }
4563 if barrier.contains(crate::Barrier::SUB_GROUP) {
4564 }
4566 if barrier.contains(crate::Barrier::TEXTURE) {
4567 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4568 }
4569 Ok(())
4570 }
4571
4572 fn emit_hlsl_atomic_tail(
4574 &mut self,
4575 module: &Module,
4576 func_ctx: &back::FunctionCtx<'_>,
4577 fun: &crate::AtomicFunction,
4578 compare_expr: Option<Handle<crate::Expression>>,
4579 value: Handle<crate::Expression>,
4580 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4581 ) -> BackendResult {
4582 if let Some(cmp) = compare_expr {
4583 write!(self.out, ", ")?;
4584 self.write_expr(module, cmp, func_ctx)?;
4585 }
4586 write!(self.out, ", ")?;
4587 if let crate::AtomicFunction::Subtract = *fun {
4588 write!(self.out, "-")?;
4590 }
4591 self.write_expr(module, value, func_ctx)?;
4592 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4593 write!(self.out, ", ")?;
4594 if compare_expr.is_some() {
4595 write!(self.out, "{res_name}.old_value")?;
4596 } else {
4597 write!(self.out, "{res_name}")?;
4598 }
4599 }
4600 writeln!(self.out, ");")?;
4601 Ok(())
4602 }
4603}
4604
4605pub(super) struct MatrixType {
4606 pub(super) columns: crate::VectorSize,
4607 pub(super) rows: crate::VectorSize,
4608 pub(super) width: crate::Bytes,
4609}
4610
4611pub(super) fn get_inner_matrix_data(
4612 module: &Module,
4613 handle: Handle<crate::Type>,
4614) -> Option<MatrixType> {
4615 match module.types[handle].inner {
4616 TypeInner::Matrix {
4617 columns,
4618 rows,
4619 scalar,
4620 } => Some(MatrixType {
4621 columns,
4622 rows,
4623 width: scalar.width,
4624 }),
4625 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4626 _ => None,
4627 }
4628}
4629
4630fn find_matrix_in_access_chain(
4634 module: &Module,
4635 base: Handle<crate::Expression>,
4636 func_ctx: &back::FunctionCtx<'_>,
4637) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4638 let mut current_base = base;
4639 let mut vector = None;
4640 let mut scalar = None;
4641 loop {
4642 let resolved_tr = func_ctx
4643 .resolve_type(current_base, &module.types)
4644 .pointer_base_type();
4645 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4646
4647 match *resolved {
4648 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4649 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4650 _ => return None,
4651 }
4652
4653 let index;
4654 (current_base, index) = match func_ctx.expressions[current_base] {
4655 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4656 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4657 _ => return None,
4658 };
4659
4660 match *resolved {
4661 TypeInner::Scalar(_) => scalar = Some(index),
4662 TypeInner::Vector { .. } => vector = Some(index),
4663 _ => unreachable!(),
4664 }
4665 }
4666}
4667
4668pub(super) fn get_inner_matrix_of_struct_array_member(
4673 module: &Module,
4674 base: Handle<crate::Expression>,
4675 func_ctx: &back::FunctionCtx<'_>,
4676 direct: bool,
4677) -> Option<MatrixType> {
4678 let mut mat_data = None;
4679 let mut array_base = None;
4680
4681 let mut current_base = base;
4682 loop {
4683 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4684 if let TypeInner::Pointer { base, .. } = *resolved {
4685 resolved = &module.types[base].inner;
4686 };
4687
4688 match *resolved {
4689 TypeInner::Matrix {
4690 columns,
4691 rows,
4692 scalar,
4693 } => {
4694 mat_data = Some(MatrixType {
4695 columns,
4696 rows,
4697 width: scalar.width,
4698 })
4699 }
4700 TypeInner::Array { base, .. } => {
4701 array_base = Some(base);
4702 }
4703 TypeInner::Struct { .. } => {
4704 if let Some(array_base) = array_base {
4705 if direct {
4706 return mat_data;
4707 } else {
4708 return get_inner_matrix_data(module, array_base);
4709 }
4710 }
4711
4712 break;
4713 }
4714 _ => break,
4715 }
4716
4717 current_base = match func_ctx.expressions[current_base] {
4718 crate::Expression::Access { base, .. } => base,
4719 crate::Expression::AccessIndex { base, .. } => base,
4720 _ => break,
4721 };
4722 }
4723 None
4724}
4725
4726fn get_global_uniform_matrix(
4729 module: &Module,
4730 base: Handle<crate::Expression>,
4731 func_ctx: &back::FunctionCtx<'_>,
4732) -> Option<MatrixType> {
4733 let base_tr = func_ctx
4734 .resolve_type(base, &module.types)
4735 .pointer_base_type();
4736 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4737 match (&func_ctx.expressions[base], base_ty) {
4738 (
4739 &crate::Expression::GlobalVariable(handle),
4740 Some(&TypeInner::Matrix {
4741 columns,
4742 rows,
4743 scalar,
4744 }),
4745 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4746 Some(MatrixType {
4747 columns,
4748 rows,
4749 width: scalar.width,
4750 })
4751 }
4752 _ => None,
4753 }
4754}
4755
4756fn get_inner_matrix_of_global_uniform(
4761 module: &Module,
4762 base: Handle<crate::Expression>,
4763 func_ctx: &back::FunctionCtx<'_>,
4764) -> Option<MatrixType> {
4765 let mut mat_data = None;
4766 let mut array_base = None;
4767
4768 let mut current_base = base;
4769 loop {
4770 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4771 if let TypeInner::Pointer { base, .. } = *resolved {
4772 resolved = &module.types[base].inner;
4773 };
4774
4775 match *resolved {
4776 TypeInner::Matrix {
4777 columns,
4778 rows,
4779 scalar,
4780 } => {
4781 mat_data = Some(MatrixType {
4782 columns,
4783 rows,
4784 width: scalar.width,
4785 })
4786 }
4787 TypeInner::Array { base, .. } => {
4788 array_base = Some(base);
4789 }
4790 _ => break,
4791 }
4792
4793 current_base = match func_ctx.expressions[current_base] {
4794 crate::Expression::Access { base, .. } => base,
4795 crate::Expression::AccessIndex { base, .. } => base,
4796 crate::Expression::GlobalVariable(handle)
4797 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4798 {
4799 return mat_data.or_else(|| {
4800 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4801 })
4802 }
4803 _ => break,
4804 };
4805 }
4806 None
4807}