1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9 help,
10 help::{
11 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12 WrappedZeroValue,
13 },
14 storage::StoreValue,
15 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18 back::{self, get_entry_points, Baked},
19 common,
20 proc::{self, index, NameKey},
21 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Prefix of the user-defined semantic given to `@location(N)` members that are
/// not fragment outputs (written as `LOC0`, `LOC1`, ...).
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special constants cbuffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable holding the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the `int` field `first_vertex` in the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the `int` field `first_instance` in the special constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the extra `uint` field in the special constants struct.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and globals emitted into the generated HLSL.
// They are `pub(crate)` so other parts of the backend can refer to them.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Heap indexed by `write_global_sampler` for non-comparison samplers.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Heap indexed by `write_global_sampler` for comparison samplers.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
47
/// An index into a composite value: either a runtime expression or a
/// constant known at translation time.
// NOTE(review): not referenced in this part of the file; presumably used by
// the access-chain writing code elsewhere — confirm.
enum Index {
    /// Index computed by an IR expression.
    Expression(Handle<crate::Expression>),
    /// Constant index.
    Static(u32),
}
52
/// One field of a synthesized entry point interface struct.
struct EpStructMember {
    /// Field name, produced by the namer.
    name: String,
    /// Field type.
    ty: Handle<crate::Type>,
    /// Binding (location or built-in) copied from the originating argument,
    /// argument struct member, or result struct member.
    binding: Option<crate::Binding>,
    /// Position of the field in its original order. Interface structs are
    /// sorted by binding for declaration, and input structs are re-sorted by
    /// this value afterwards to restore argument order.
    index: u32,
}
61
/// An interface struct that was written for an entry point's inputs or outputs.
struct EntryPointBinding {
    /// Name for the variable/argument of this struct type. It is obtained from
    /// the namer using the lowercased struct type name.
    arg_name: String,
    /// Name of the HLSL struct type.
    ty_name: String,
    /// Fields of the struct. Input structs hold them in original argument
    /// order; output structs keep them sorted by binding.
    members: Vec<EpStructMember>,
}
73
/// The interface structs generated for a single entry point.
pub(super) struct EntryPointInterface {
    /// Input struct. It is only generated when the entry point has arguments
    /// and is either a fragment shader or uses a subgroup built-in argument.
    input: Option<EntryPointBinding>,
    /// Output struct. It is only generated for vertex shaders whose result
    /// carries no binding of its own.
    output: Option<EntryPointBinding>,
}
85
/// Sort key for interface struct members.
///
/// The derived `Ord` compares variants in declaration order, so sorting by this
/// key places `Location` members first (ascending by location number), then
/// built-ins, then members without a binding.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    /// A user-defined `@location(N)`.
    Location(u32),
    /// A built-in value.
    BuiltIn(crate::BuiltIn),
    /// No binding.
    Other,
}
92
93impl InterfaceKey {
94 const fn new(binding: Option<&crate::Binding>) -> Self {
95 match binding {
96 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
97 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
98 None => Self::Other,
99 }
100 }
101}
102
/// Direction of an entry point interface: its inputs or its outputs.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}
108
109const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
110 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
111 return false;
112 };
113 matches!(
114 builtin,
115 crate::BuiltIn::SubgroupSize
116 | crate::BuiltIn::SubgroupInvocationId
117 | crate::BuiltIn::NumSubgroups
118 | crate::BuiltIn::SubgroupId
119 )
120}
121
/// Names needed to emit accesses to a binding array of samplers.
// NOTE(review): not constructed in this part of the file; the field
// descriptions below are inferred from their names — confirm with the users.
struct BindingArraySamplerInfo {
    /// Name of the sampler heap global (presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`).
    sampler_heap_name: &'static str,
    /// Presumably the name of the sampler index buffer for the bind group.
    sampler_index_buffer_name: String,
    /// Presumably the name of the `static const uint` base index that
    /// `write_global_sampler` emits for a sampler binding array.
    binding_array_base_index_name: String,
}
131
132impl<'a, W: fmt::Write> super::Writer<'a, W> {
133 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
134 Self {
135 out,
136 names: crate::FastHashMap::default(),
137 namer: proc::Namer::default(),
138 options,
139 pipeline_options,
140 entry_point_io: crate::FastHashMap::default(),
141 named_expressions: crate::NamedExpressions::default(),
142 wrapped: super::Wrapped::default(),
143 written_committed_intersection: false,
144 written_candidate_intersection: false,
145 continue_ctx: back::continue_forward::ContinueCtx::default(),
146 temp_access_chain: Vec::new(),
147 need_bake_expressions: Default::default(),
148 }
149 }
150
151 fn reset(&mut self, module: &Module) {
152 self.names.clear();
153 self.namer.reset(
154 module,
155 &super::keywords::RESERVED_SET,
156 super::keywords::RESERVED_CASE_INSENSITIVE,
157 super::keywords::RESERVED_PREFIXES,
158 &mut self.names,
159 );
160 self.entry_point_io.clear();
161 self.named_expressions.clear();
162 self.wrapped.clear();
163 self.written_committed_intersection = false;
164 self.written_candidate_intersection = false;
165 self.continue_ctx.clear();
166 self.need_bake_expressions.clear();
167 }
168
169 fn gen_force_bounded_loop_statements(
177 &mut self,
178 level: back::Level,
179 ) -> Option<(String, String)> {
180 if !self.options.force_loop_bounding {
181 return None;
182 }
183
184 let loop_bound_name = self.namer.call("loop_bound");
185 let max = u32::MAX;
186 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
189 let level = level.next();
190 let break_and_inc = format!(
191 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
192{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
193 );
194
195 Some((decl, break_and_inc))
196 }
197
198 fn update_expressions_to_bake(
203 &mut self,
204 module: &Module,
205 func: &crate::Function,
206 info: &valid::FunctionInfo,
207 ) {
208 use crate::Expression;
209 self.need_bake_expressions.clear();
210 for (exp_handle, expr) in func.expressions.iter() {
211 let expr_info = &info[exp_handle];
212 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
213 if min_ref_count <= expr_info.ref_count {
214 self.need_bake_expressions.insert(exp_handle);
215 }
216
217 if let Expression::Math { fun, arg, arg1, .. } = *expr {
218 match fun {
219 crate::MathFunction::Asinh
220 | crate::MathFunction::Acosh
221 | crate::MathFunction::Atanh
222 | crate::MathFunction::Unpack2x16float
223 | crate::MathFunction::Unpack2x16snorm
224 | crate::MathFunction::Unpack2x16unorm
225 | crate::MathFunction::Unpack4x8snorm
226 | crate::MathFunction::Unpack4x8unorm
227 | crate::MathFunction::Unpack4xI8
228 | crate::MathFunction::Unpack4xU8
229 | crate::MathFunction::Pack2x16float
230 | crate::MathFunction::Pack2x16snorm
231 | crate::MathFunction::Pack2x16unorm
232 | crate::MathFunction::Pack4x8snorm
233 | crate::MathFunction::Pack4x8unorm
234 | crate::MathFunction::Pack4xI8
235 | crate::MathFunction::Pack4xU8
236 | crate::MathFunction::Pack4xI8Clamp
237 | crate::MathFunction::Pack4xU8Clamp => {
238 self.need_bake_expressions.insert(arg);
239 }
240 crate::MathFunction::CountLeadingZeros => {
241 let inner = info[exp_handle].ty.inner_with(&module.types);
242 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
243 self.need_bake_expressions.insert(arg);
244 }
245 }
246 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
247 self.need_bake_expressions.insert(arg);
248 self.need_bake_expressions.insert(arg1.unwrap());
249 }
250 _ => {}
251 }
252 }
253
254 if let Expression::Derivative { axis, ctrl, expr } = *expr {
255 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
256 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
257 self.need_bake_expressions.insert(expr);
258 }
259 }
260
261 if let Expression::GlobalVariable(_) = *expr {
262 let inner = info[exp_handle].ty.inner_with(&module.types);
263
264 if let TypeInner::Sampler { .. } = *inner {
265 self.need_bake_expressions.insert(exp_handle);
266 }
267 }
268 }
269 for statement in func.body.iter() {
270 match *statement {
271 crate::Statement::SubgroupCollectiveOperation {
272 op: _,
273 collective_op: crate::CollectiveOperation::InclusiveScan,
274 argument,
275 result: _,
276 } => {
277 self.need_bake_expressions.insert(argument);
278 }
279 crate::Statement::Atomic {
280 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
281 ..
282 } => {
283 self.need_bake_expressions.insert(cmp);
284 }
285 _ => {}
286 }
287 }
288 }
289
    /// Writes the complete HLSL translation of `module` to `self.out`.
    ///
    /// Output order:
    /// 1. the special constants cbuffer, if `special_constants_binding` is set;
    /// 2. one offsets cbuffer per entry in `dynamic_storage_buffer_offsets_targets`;
    /// 3. `matCx2` helper typedefs and functions;
    /// 4. struct declarations (skipping structs whose last member is
    ///    dynamically sized);
    /// 5. special and wrapped helper functions for global expressions;
    /// 6. named constants, then global variables;
    /// 7. interface structs for the selected entry points;
    /// 8. regular functions;
    /// 9. the selected entry points.
    ///
    /// Unless `fake_missing_bindings` is set, any function or entry point that
    /// uses a global whose resource binding cannot be resolved is skipped. A
    /// skipped function is only logged. For a skipped entry point, the
    /// resolution error is recorded in the returned `ReflectionInfo` in place
    /// of its name.
    ///
    /// # Errors
    /// Returns `Error::EntryPointNotFound` if the requested pipeline entry point
    /// does not exist, and propagates any formatting or sub-writer errors.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Special constants cbuffer: struct declaration followed by the
        // `ConstantBuffer` binding.
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // Space 0 is omitted; only a non-zero space is written.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Blank line for readability.
            writeln!(self.out)?;
        }

        // For each configured bind group, a cbuffer holding `bt.size` `uint`
        // fields named `_0`, `_1`, ...
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Blank line for readability.
            writeln!(self.out)?;
        }

        // Stage and result of every entry point. This lets struct declarations
        // below tell whether a struct is an entry point's result type.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // Assumes every struct has at least one member (`unwrap`).
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // NOTE(review): presumably such structs only back storage
                    // buffers, which `write_global` emits as `ByteAddressBuffer`,
                    // so no HLSL struct is needed — confirm.
                    continue;
                }

                // If this struct is some entry point's result, its members get
                // that stage's output semantics.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Only named constants are declared.
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Blank line after the last constant.
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Blank line after the globals.
            writeln!(self.out)?;
        }

        // Either every entry point, or only the one selected by the pipeline options.
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Interface structs are written before any function body and recorded
        // for later use by the entry point writer.
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Skip functions that use a global whose binding cannot be resolved.
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                            }
                            _ => false,
                        })
                {
                    log::info!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            // Helper functions this function needs are emitted just before it.
            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Record the first unresolvable binding used by this entry point
            // instead of translating it.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                ep_error = Some(err);
                                break;
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Compute shaders carry their workgroup size as an attribute.
            if ep.stage == ShaderStage::Compute {
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Blank line between entry points, none after the module's last one.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
522
523 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
524 match *binding {
525 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
526 write!(self.out, "precise ")?;
527 }
528 crate::Binding::Location {
529 interpolation,
530 sampling,
531 ..
532 } => {
533 if let Some(interpolation) = interpolation {
534 if let Some(string) = interpolation.to_hlsl_str() {
535 write!(self.out, "{string} ")?
536 }
537 }
538
539 if let Some(sampling) = sampling {
540 if let Some(string) = sampling.to_hlsl_str() {
541 write!(self.out, "{string} ")?
542 }
543 }
544 }
545 crate::Binding::BuiltIn(_) => {}
546 }
547
548 Ok(())
549 }
550
551 fn write_semantic(
554 &mut self,
555 binding: &Option<crate::Binding>,
556 stage: Option<(ShaderStage, Io)>,
557 ) -> BackendResult {
558 match *binding {
559 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
560 let builtin_str = builtin.to_hlsl_str()?;
561 write!(self.out, " : {builtin_str}")?;
562 }
563 Some(crate::Binding::Location {
564 blend_src: Some(1), ..
565 }) => {
566 write!(self.out, " : SV_Target1")?;
567 }
568 Some(crate::Binding::Location { location, .. }) => {
569 if stage == Some((ShaderStage::Fragment, Io::Output)) {
570 write!(self.out, " : SV_Target{location}")?;
571 } else {
572 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
573 }
574 }
575 _ => {}
576 }
577
578 Ok(())
579 }
580
    /// Declares an entry point interface struct named `struct_name` with the
    /// given members, and returns its `EntryPointBinding`.
    ///
    /// Members are declared sorted by `InterfaceKey`: locations ascending, then
    /// built-ins, then unbound members. NOTE(review): presumably this is so
    /// vertex outputs and fragment inputs line up — confirm. Subgroup built-ins
    /// get no field. If a `SubgroupId` member is present, an extra
    /// `uint __local_invocation_index : SV_GroupIndex` field is added. Once the
    /// struct is written, input members are re-sorted back into their original
    /// `index` order, while output members stay sorted by binding. The returned
    /// `arg_name` is reserved from the namer using the lowercased struct name.
    fn write_interface_struct(
        &mut self,
        module: &Module,
        shader_stage: (ShaderStage, Io),
        struct_name: String,
        mut members: Vec<EpStructMember>,
    ) -> Result<EntryPointBinding, Error> {
        // Stable sort: members with equal keys keep their relative order.
        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));

        write!(self.out, "struct {struct_name}")?;
        writeln!(self.out, " {{")?;
        for m in members.iter() {
            // Every interface member is expected to carry a binding.
            debug_assert!(m.binding.is_some());

            // Subgroup built-ins are computed in the function body instead.
            if is_subgroup_builtin_binding(&m.binding) {
                continue;
            }
            write!(self.out, "{}", back::INDENT)?;
            if let Some(ref binding) = m.binding {
                self.write_modifier(binding)?;
            }
            self.write_type(module, m.ty)?;
            write!(self.out, " {}", &m.name)?;
            self.write_semantic(&m.binding, Some(shader_stage))?;
            writeln!(self.out, ";")?;
        }
        // `SubgroupId` is derived from the flat local invocation index.
        if members.iter().any(|arg| {
            matches!(
                arg.binding,
                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
            )
        }) {
            writeln!(
                self.out,
                "{}uint __local_invocation_index : SV_GroupIndex;",
                back::INDENT
            )?;
        }
        writeln!(self.out, "}};")?;
        writeln!(self.out)?;

        match shader_stage.1 {
            Io::Input => {
                // Restore the original argument order for argument initialization.
                members.sort_by_key(|m| m.index);
            }
            Io::Output => {
                // Keep the binding-sorted order.
            }
        }

        Ok(EntryPointBinding {
            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
            ty_name: struct_name,
            members,
        })
    }
644
645 fn write_ep_input_struct(
649 &mut self,
650 module: &Module,
651 func: &crate::Function,
652 stage: ShaderStage,
653 entry_point_name: &str,
654 ) -> Result<EntryPointBinding, Error> {
655 let struct_name = format!("{stage:?}Input_{entry_point_name}");
656
657 let mut fake_members = Vec::new();
658 for arg in func.arguments.iter() {
659 match module.types[arg.ty].inner {
664 TypeInner::Struct { ref members, .. } => {
665 for member in members.iter() {
666 let name = self.namer.call_or(&member.name, "member");
667 let index = fake_members.len() as u32;
668 fake_members.push(EpStructMember {
669 name,
670 ty: member.ty,
671 binding: member.binding.clone(),
672 index,
673 });
674 }
675 }
676 _ => {
677 let member_name = self.namer.call_or(&arg.name, "member");
678 let index = fake_members.len() as u32;
679 fake_members.push(EpStructMember {
680 name: member_name,
681 ty: arg.ty,
682 binding: arg.binding.clone(),
683 index,
684 });
685 }
686 }
687 }
688
689 self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
690 }
691
692 fn write_ep_output_struct(
696 &mut self,
697 module: &Module,
698 result: &crate::FunctionResult,
699 stage: ShaderStage,
700 entry_point_name: &str,
701 frag_ep: Option<&FragmentEntryPoint<'_>>,
702 ) -> Result<EntryPointBinding, Error> {
703 let struct_name = format!("{stage:?}Output_{entry_point_name}");
704
705 let empty = [];
706 let members = match module.types[result.ty].inner {
707 TypeInner::Struct { ref members, .. } => members,
708 ref other => {
709 log::error!("Unexpected {other:?} output type without a binding");
710 &empty[..]
711 }
712 };
713
714 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
719 let mut fs_input_locs = Vec::new();
720 for arg in frag_ep.func.arguments.iter() {
721 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
722 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
723 Some(crate::Binding::BuiltIn(_)) | None => {}
724 };
725
726 match frag_ep.module.types[arg.ty].inner {
729 TypeInner::Struct { ref members, .. } => {
730 for member in members.iter() {
731 push_if_location(&member.binding);
732 }
733 }
734 _ => push_if_location(&arg.binding),
735 }
736 }
737 fs_input_locs.sort();
738 Some(fs_input_locs)
739 } else {
740 None
741 };
742
743 let mut fake_members = Vec::new();
744 for (index, member) in members.iter().enumerate() {
745 if let Some(ref fs_input_locs) = fs_input_locs {
746 match member.binding {
747 Some(crate::Binding::Location { location, .. }) => {
748 if fs_input_locs.binary_search(&location).is_err() {
749 continue;
750 }
751 }
752 Some(crate::Binding::BuiltIn(_)) | None => {}
753 }
754 }
755
756 let member_name = self.namer.call_or(&member.name, "member");
757 fake_members.push(EpStructMember {
758 name: member_name,
759 ty: member.ty,
760 binding: member.binding.clone(),
761 index: index as u32,
762 });
763 }
764
765 self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
766 }
767
768 fn write_ep_interface(
772 &mut self,
773 module: &Module,
774 func: &crate::Function,
775 stage: ShaderStage,
776 ep_name: &str,
777 frag_ep: Option<&FragmentEntryPoint<'_>>,
778 ) -> Result<EntryPointInterface, Error> {
779 Ok(EntryPointInterface {
780 input: if !func.arguments.is_empty()
781 && (stage == ShaderStage::Fragment
782 || func
783 .arguments
784 .iter()
785 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
786 {
787 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
788 } else {
789 None
790 },
791 output: match func.result {
792 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
793 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
794 }
795 _ => None,
796 },
797 })
798 }
799
800 fn write_ep_argument_initialization(
801 &mut self,
802 ep: &crate::EntryPoint,
803 ep_input: &EntryPointBinding,
804 fake_member: &EpStructMember,
805 ) -> BackendResult {
806 match fake_member.binding {
807 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
808 write!(self.out, "WaveGetLaneCount()")?
809 }
810 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
811 write!(self.out, "WaveGetLaneIndex()")?
812 }
813 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
814 self.out,
815 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
816 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
817 )?,
818 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
819 write!(
820 self.out,
821 "{}.__local_invocation_index / WaveGetLaneCount()",
822 ep_input.arg_name
823 )?;
824 }
825 _ => {
826 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
827 }
828 }
829 Ok(())
830 }
831
    /// Declares the entry point's arguments as local variables at the top of its
    /// body, initializing them from the input interface struct.
    ///
    /// Does nothing if the entry point has no input struct. The input binding is
    /// moved out of `self.entry_point_io`, so a second call for the same entry
    /// point also does nothing. Array and scalar/vector arguments use one
    /// interface member each. Struct arguments use one member per struct member
    /// and are initialized with a `{ ... }` initializer list. This relies on the
    /// input members being in original argument order, which
    /// `write_interface_struct` restores for inputs.
    ///
    /// # Panics
    /// Panics if `ep_index` has no `entry_point_io` entry, or if the number of
    /// interface members does not match the flattened argument count.
    fn write_ep_arguments_initialization(
        &mut self,
        module: &Module,
        func: &crate::Function,
        ep_index: u16,
    ) -> BackendResult {
        let ep = &module.entry_points[ep_index as usize];
        let ep_input = match self
            .entry_point_io
            .get_mut(&(ep_index as usize))
            .unwrap()
            .input
            .take()
        {
            Some(ep_input) => ep_input,
            None => return Ok(()),
        };
        // Interface members are consumed in step with the (flattened) arguments.
        let mut fake_iter = ep_input.members.iter();
        for (arg_index, arg) in func.arguments.iter().enumerate() {
            write!(self.out, "{}", back::INDENT)?;
            self.write_type(module, arg.ty)?;
            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
            write!(self.out, " {arg_name}")?;
            match module.types[arg.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    // HLSL puts array dimensions after the variable name.
                    self.write_array_size(module, base, size)?;
                    write!(self.out, " = ")?;
                    self.write_ep_argument_initialization(
                        ep,
                        &ep_input,
                        fake_iter.next().unwrap(),
                    )?;
                    writeln!(self.out, ";")?;
                }
                TypeInner::Struct { ref members, .. } => {
                    // One initializer element per struct member.
                    write!(self.out, " = {{ ")?;
                    for index in 0..members.len() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_ep_argument_initialization(
                            ep,
                            &ep_input,
                            fake_iter.next().unwrap(),
                        )?;
                    }
                    writeln!(self.out, " }};")?;
                }
                _ => {
                    write!(self.out, " = ")?;
                    self.write_ep_argument_initialization(
                        ep,
                        &ep_input,
                        fake_iter.next().unwrap(),
                    )?;
                    writeln!(self.out, ";")?;
                }
            }
        }
        // Every interface member must have been consumed.
        assert!(fake_iter.next().is_none());
        Ok(())
    }
895
    /// Writes the declaration of global variable `handle`.
    ///
    /// Globals whose resource binding cannot be resolved are skipped; this is
    /// only logged. Samplers and sampler binding arrays go to
    /// `write_global_sampler`. Otherwise the declaration depends on the
    /// address space:
    /// - `Private`: `static T name = init;`, using the default initializer when
    ///   there is no `init`;
    /// - `WorkGroup`: `groupshared T name;`;
    /// - `Uniform`: `cbuffer name : register(bN) { T name; }`;
    /// - `Storage`: `[RW]ByteAddressBuffer name : register(uN/tN)`, with `RW`
    ///   and `u` when the access includes `STORE`;
    /// - `Handle`: the resource type, registered as `u` for storage images and
    ///   `t` otherwise;
    /// - `PushConstant`: `ConstantBuffer<T> name: register(bN)`, using
    ///   `options.push_constants_target`.
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for push constants of non-struct type.
    ///
    /// # Panics
    /// Panics on the `Function` address space, and when a push-constant global
    /// is written without `push_constants_target` configured.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // For binding arrays, the element type decides how the global is declared.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Writes the declaration prefix and yields the register class letter
        // ("" for globals that are not bound to a register).
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // The member declaration is written inside `{ }` further below.
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The template argument is completed just below.
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        if global.space == crate::AddressSpace::PushConstant {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            // Space 0 is omitted; only a non-zero space is written.
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Cannot fail: unresolvable bindings returned early above.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // A configured `binding_array_size` overrides the IR array size.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // Unbound globals: array dimensions follow the name, and private
            // globals always get an initializer.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // cbuffer body: a single member with the same name as the cbuffer.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1075
    /// Writes a sampler global or a binding array of samplers.
    ///
    /// First this makes sure the sampler index buffer for the global's bind
    /// group has been emitted (`write_wrapped_sampler_buffer`). Then:
    /// - a single sampler becomes a `static const` copy read from the
    ///   (comparison) sampler heap, indexed through that group's sampler index
    ///   buffer at the resolved register;
    /// - a binding array of samplers becomes a `static const uint` holding the
    ///   resolved register. NOTE(review): this is presumably added to the
    ///   runtime index when the array is accessed — confirm.
    ///
    /// # Panics
    /// Panics if `global` has no binding, if the binding cannot be resolved, or
    /// if its type is neither a sampler nor a binding array.
    fn write_global_sampler(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
        global: &crate::GlobalVariable,
    ) -> BackendResult {
        let binding = *global.binding.as_ref().unwrap();

        let key = super::SamplerIndexBufferKey {
            group: binding.group,
        };
        self.write_wrapped_sampler_buffer(key)?;

        let bt = self.options.resolve_resource_binding(&binding).unwrap();

        match module.types[global.ty].inner {
            TypeInner::Sampler { comparison } => {
                write!(self.out, "static const ")?;
                self.write_type(module, global.ty)?;

                // Comparison and non-comparison samplers live in separate heaps.
                let heap_var = if comparison {
                    COMPARISON_SAMPLER_HEAP_VAR
                } else {
                    SAMPLER_HEAP_VAR
                };

                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                writeln!(
                    self.out,
                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
                    register = bt.register
                )?;
            }
            TypeInner::BindingArray { .. } => {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                writeln!(
                    self.out,
                    "static const uint {name} = {register};",
                    register = bt.register
                )?;
            }
            _ => unreachable!(),
        };

        Ok(())
    }
1135
1136 fn write_global_constant(
1141 &mut self,
1142 module: &Module,
1143 handle: Handle<crate::Constant>,
1144 ) -> BackendResult {
1145 write!(self.out, "static const ")?;
1146 let constant = &module.constants[handle];
1147 self.write_type(module, constant.ty)?;
1148 let name = &self.names[&NameKey::Constant(handle)];
1149 write!(self.out, " {name}")?;
1150 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1152 self.write_array_size(module, base, size)?;
1153 }
1154 write!(self.out, " = ")?;
1155 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1156 writeln!(self.out, ";")?;
1157 Ok(())
1158 }
1159
1160 pub(super) fn write_array_size(
1161 &mut self,
1162 module: &Module,
1163 base: Handle<crate::Type>,
1164 size: crate::ArraySize,
1165 ) -> BackendResult {
1166 write!(self.out, "[")?;
1167
1168 match size.resolve(module.to_ctx())? {
1169 proc::IndexableLength::Known(size) => {
1170 write!(self.out, "{size}")?;
1171 }
1172 proc::IndexableLength::Dynamic => unreachable!(),
1173 }
1174
1175 write!(self.out, "]")?;
1176
1177 if let TypeInner::Array {
1178 base: next_base,
1179 size: next_size,
1180 ..
1181 } = module.types[base].inner
1182 {
1183 self.write_array_size(module, next_base, next_size)?;
1184 }
1185
1186 Ok(())
1187 }
1188
    /// Declares the HLSL struct for IR struct type `handle`.
    ///
    /// Members without a binding are laid out to match their IR offsets. Gaps
    /// before a member are filled with `int _pad{index}_{i}` fields, and
    /// trailing space up to `span` with `int _end_pad_{i}` fields. Padding is
    /// counted in 4-byte units (integer division), so gaps that are not a
    /// multiple of 4 are not fully padded. Unbound `matCx2` members are split
    /// into C two-component vector fields named `{member}_{i}`. Other unbound
    /// matrices are declared `row_major`. `shader_stage`, when given, selects
    /// the semantics written for bound members.
    ///
    /// # Panics
    /// Panics if `members` is empty.
    fn write_struct(
        &mut self,
        module: &Module,
        handle: Handle<crate::Type>,
        members: &[crate::StructMember],
        span: u32,
        shader_stage: Option<(ShaderStage, Io)>,
    ) -> BackendResult {
        let struct_name = &self.names[&NameKey::Type(handle)];
        writeln!(self.out, "struct {struct_name} {{")?;

        // End offset of the previous member, in bytes.
        let mut last_offset = 0;
        for (index, member) in members.iter().enumerate() {
            if member.binding.is_none() && member.offset > last_offset {
                // Fill the gap with `int` fields (4 bytes each).
                let padding = (member.offset - last_offset) / 4;
                for i in 0..padding {
                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
                }
            }
            let ty_inner = &module.types[member.ty].inner;
            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

            write!(self.out, "{}", back::INDENT)?;

            match module.types[member.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    self.write_global_type(module, member.ty)?;

                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                    self.write_array_size(module, base, size)?;
                }
                TypeInner::Matrix {
                    rows,
                    columns,
                    scalar,
                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                    // One two-component vector field per column:
                    // `vecT name_0; vecT name_1; ...`.
                    let vec_ty = TypeInner::Vector { size: rows, scalar };
                    let field_name_key = NameKey::StructMember(handle, index as u32);

                    for i in 0..columns as u8 {
                        if i != 0 {
                            write!(self.out, "; ")?;
                        }
                        self.write_value_type(module, &vec_ty)?;
                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                    }
                }
                _ => {
                    if let Some(ref binding) = member.binding {
                        self.write_modifier(binding)?;
                    }

                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
                        write!(self.out, "row_major ")?;
                    }

                    self.write_type(module, member.ty)?;
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                }
            }

            self.write_semantic(&member.binding, shader_stage)?;
            writeln!(self.out, ";")?;
        }

        // Trailing padding up to the struct's span (unbound layouts only).
        if members.last().unwrap().binding.is_none() && span > last_offset {
            let padding = (span - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
            }
        }

        writeln!(self.out, "}};")?;
        Ok(())
    }
1293
1294 pub(super) fn write_global_type(
1299 &mut self,
1300 module: &Module,
1301 ty: Handle<crate::Type>,
1302 ) -> BackendResult {
1303 let matrix_data = get_inner_matrix_data(module, ty);
1304
1305 if let Some(MatrixType {
1308 columns,
1309 rows: crate::VectorSize::Bi,
1310 width: 4,
1311 }) = matrix_data
1312 {
1313 write!(self.out, "__mat{}x2", columns as u8)?;
1314 } else {
1315 if matrix_data.is_some() {
1319 write!(self.out, "row_major ")?;
1320 }
1321
1322 self.write_type(module, ty)?;
1323 }
1324
1325 Ok(())
1326 }
1327
1328 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1333 let inner = &module.types[ty].inner;
1334 match *inner {
1335 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1336 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1338 self.write_type(module, base)?
1339 }
1340 ref other => self.write_value_type(module, other)?,
1341 }
1342
1343 Ok(())
1344 }
1345
1346 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1351 match *inner {
1352 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1353 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1354 }
1355 TypeInner::Vector { size, scalar } => {
1356 write!(
1357 self.out,
1358 "{}{}",
1359 scalar.to_hlsl_str()?,
1360 common::vector_size_str(size)
1361 )?;
1362 }
1363 TypeInner::Matrix {
1364 columns,
1365 rows,
1366 scalar,
1367 } => {
1368 write!(
1373 self.out,
1374 "{}{}x{}",
1375 scalar.to_hlsl_str()?,
1376 common::vector_size_str(columns),
1377 common::vector_size_str(rows),
1378 )?;
1379 }
1380 TypeInner::Image {
1381 dim,
1382 arrayed,
1383 class,
1384 } => {
1385 self.write_image_type(dim, arrayed, class)?;
1386 }
1387 TypeInner::Sampler { comparison } => {
1388 let sampler = if comparison {
1389 "SamplerComparisonState"
1390 } else {
1391 "SamplerState"
1392 };
1393 write!(self.out, "{sampler}")?;
1394 }
1395 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1399 self.write_array_size(module, base, size)?;
1400 }
1401 TypeInner::AccelerationStructure { .. } => {
1402 write!(self.out, "RaytracingAccelerationStructure")?;
1403 }
1404 TypeInner::RayQuery { .. } => {
1405 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1407 }
1408 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1409 }
1410
1411 Ok(())
1412 }
1413
/// Writes the full HLSL definition of the function or entry point `name`.
///
/// The definition is built in this order:
/// 1. Return type. A regular function returning an array uses a `typedef ret_{name}`.
/// 2. Name and parameters.
/// 3. Output semantic, for entry points only.
/// 4. Optional workgroup zero-initialization prologue.
/// 5. Entry-point argument initialization.
/// 6. Local variables.
/// 7. Body statements.
///
/// After the closing brace, `self.named_expressions` is cleared.
fn write_function(
    &mut self,
    module: &Module,
    name: &str,
    func: &crate::Function,
    func_ctx: &back::FunctionCtx<'_>,
    info: &valid::FunctionInfo,
) -> BackendResult {
    // Compute which expressions of this function get baked into temporaries.
    self.update_expressions_to_bake(module, func, info);

    if let Some(ref result) = func.result {
        // An array return type is first declared as `typedef <elem> ret_{name}[size];`.
        // Regular functions then use that alias as their return type.
        let array_return_type = match module.types[result.ty].inner {
            TypeInner::Array { base, size, .. } => {
                let array_return_type = self.namer.call(&format!("ret_{name}"));
                write!(self.out, "typedef ")?;
                self.write_type(module, result.ty)?;
                write!(self.out, " {array_return_type}")?;
                self.write_array_size(module, base, size)?;
                writeln!(self.out, ";")?;
                Some(array_return_type)
            }
            _ => None,
        };

        // An invariant `position` output gets its modifier written before the return type.
        if let Some(
            ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
        ) = result.binding
        {
            self.write_modifier(binding)?;
        }

        // Write the return type.
        match func_ctx.ty {
            back::FunctionType::Function(_) => {
                if let Some(array_return_type) = array_return_type {
                    write!(self.out, "{array_return_type}")?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
            back::FunctionType::EntryPoint(index) => {
                // Entry points with a generated output struct return that struct type.
                if let Some(ref ep_output) =
                    self.entry_point_io.get(&(index as usize)).unwrap().output
                {
                    write!(self.out, "{}", ep_output.ty_name)?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
        }
    } else {
        write!(self.out, "void")?;
    }

    // Write the function name and open the parameter list.
    write!(self.out, " {name}(")?;

    let need_workgroup_variables_initialization =
        self.need_workgroup_variables_initialization(func_ctx, module);

    // Write the parameters.
    match func_ctx.ty {
        back::FunctionType::Function(handle) => {
            for (index, arg) in func.arguments.iter().enumerate() {
                if index != 0 {
                    write!(self.out, ", ")?;
                }
                // Pointer arguments become `inout` parameters of the pointee type.
                let arg_ty = match module.types[arg.ty].inner {
                    TypeInner::Pointer { base, .. } => {
                        write!(self.out, "inout ")?;
                        base
                    }
                    _ => arg.ty,
                };
                self.write_type(module, arg_ty)?;

                let argument_name =
                    &self.names[&NameKey::FunctionArgument(handle, index as u32)];

                // Write the argument name; the leading space separates it from the type.
                write!(self.out, " {argument_name}")?;
                if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
                    self.write_array_size(module, base, size)?;
                }
            }
        }
        back::FunctionType::EntryPoint(ep_index) => {
            // With a generated input struct, the entry point takes one argument of that
            // type. Otherwise each argument is written with its input semantic.
            if let Some(ref ep_input) =
                self.entry_point_io.get(&(ep_index as usize)).unwrap().input
            {
                write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
            } else {
                let stage = module.entry_points[ep_index as usize].stage;
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_type(module, arg.ty)?;

                    let argument_name =
                        &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                    write!(self.out, " {argument_name}")?;
                    if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }

                    self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                }
            }
            // The workgroup zero-initialization prologue needs the local invocation id.
            // Append it as an extra parameter, preceded by ", " if anything was written.
            if need_workgroup_variables_initialization {
                if self
                    .entry_point_io
                    .get(&(ep_index as usize))
                    .unwrap()
                    .input
                    .is_some()
                    || !func.arguments.is_empty()
                {
                    write!(self.out, ", ")?;
                }
                write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
            }
        }
    }
    // Close the parameter list.
    write!(self.out, ")")?;

    // Entry points get the output semantic of their result, if any.
    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        let stage = module.entry_points[index as usize].stage;
        if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
            self.write_semantic(binding, Some((stage, Io::Output)))?;
        }
    }

    // Open the function body.
    writeln!(self.out)?;
    writeln!(self.out, "{{")?;

    if need_workgroup_variables_initialization {
        self.write_workgroup_variables_initialization(func_ctx, module)?;
    }

    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        self.write_ep_arguments_initialization(module, func, index)?;
    }

    // Declare the function's local variables.
    for (handle, local) in func.local_variables.iter() {
        // Indentation is only for readability.
        write!(self.out, "{}", back::INDENT)?;

        // Type, then name; the leading space separates them.
        self.write_type(module, local.ty)?;
        write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
        // Array locals need their `[size]` suffix after the name.
        if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        match module.types[local.ty].inner {
            // Ray query locals are declared without an initializer.
            TypeInner::RayQuery { .. } => {}
            _ => {
                write!(self.out, " = ")?;
                // Use the declared initializer, or the type's default initializer otherwise.
                if let Some(init) = local.init {
                    self.write_expr(module, init, func_ctx)?;
                } else {
                    self.write_default_init(module, local.ty)?;
                }
            }
        }
        writeln!(self.out, ";")?
    }

    if !func.local_variables.is_empty() {
        writeln!(self.out)?;
    }

    // Write the body statements, always at indentation level 1.
    for sta in func.body.iter() {
        self.write_stmt(module, sta, func_ctx, back::Level(1))?;
    }

    writeln!(self.out, "}}")?;

    // Named expressions are scoped to this function; reset them for the next one.
    self.named_expressions.clear();

    Ok(())
}
1620
1621 fn need_workgroup_variables_initialization(
1622 &mut self,
1623 func_ctx: &back::FunctionCtx,
1624 module: &Module,
1625 ) -> bool {
1626 self.options.zero_initialize_workgroup_memory
1627 && func_ctx.ty.is_compute_entry_point(module)
1628 && module.global_variables.iter().any(|(handle, var)| {
1629 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1630 })
1631 }
1632
1633 fn write_workgroup_variables_initialization(
1634 &mut self,
1635 func_ctx: &back::FunctionCtx,
1636 module: &Module,
1637 ) -> BackendResult {
1638 let level = back::Level(1);
1639
1640 writeln!(
1641 self.out,
1642 "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1643 )?;
1644
1645 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1646 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1647 });
1648
1649 for (handle, var) in vars {
1650 let name = &self.names[&NameKey::GlobalVariable(handle)];
1651 write!(self.out, "{}{} = ", level.next(), name)?;
1652 self.write_default_init(module, var.ty)?;
1653 writeln!(self.out, ";")?;
1654 }
1655
1656 writeln!(self.out, "{level}}}")?;
1657 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1658 }
1659
/// Writes a `Switch` statement at indentation `level`.
///
/// Single-body form: when every case except the last is an empty fall-through
/// case, only the last case's body is written, wrapped in
/// `do { ... } while(false);`. The selector expression is not written in this form.
///
/// General form: otherwise a real `switch(selector)` is written. For a non-empty
/// fall-through case, its body and the bodies of all following cases are inlined,
/// up to and including the first case that does not fall through. Each inlined
/// body sits in its own `{ }` scope.
///
/// `self.continue_ctx` is entered before and exited after the switch. If it
/// provides a flag variable, the flag is declared before the switch. After the
/// switch, `continue;` or `break;` is emitted when the flag is set.
fn write_switch(
    &mut self,
    module: &Module,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
    selector: Handle<crate::Expression>,
    cases: &[crate::SwitchCase],
) -> BackendResult {
    let indent_level_1 = level.next();
    let indent_level_2 = indent_level_1.next();

    // Declare the forwarding flag if `continue_ctx` requires one for this switch.
    if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
        writeln!(self.out, "{level}bool {variable} = false;",)?;
    };

    // True when all cases but the last are empty fall-throughs, i.e. only one body exists.
    let one_body = cases
        .iter()
        .rev()
        .skip(1)
        .all(|case| case.fall_through && case.body.is_empty());
    if one_body {
        // Use a one-iteration loop rather than a plain block, so that `break`
        // statements in the body exit it.
        writeln!(self.out, "{level}do {{")?;
        if let Some(case) = cases.last() {
            for sta in case.body.iter() {
                self.write_stmt(module, sta, func_ctx, indent_level_1)?;
            }
        }
        writeln!(self.out, "{level}}} while(false);")?;
    } else {
        // Open the switch.
        write!(self.out, "{level}")?;
        write!(self.out, "switch(")?;
        self.write_expr(module, selector, func_ctx)?;
        writeln!(self.out, ") {{")?;

        for (i, case) in cases.iter().enumerate() {
            // Case label; unsigned values get a `u` suffix.
            match case.value {
                crate::SwitchValue::I32(value) => {
                    write!(self.out, "{indent_level_1}case {value}:")?
                }
                crate::SwitchValue::U32(value) => {
                    write!(self.out, "{indent_level_1}case {value}u:")?
                }
                crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
            }

            // An empty fall-through case is a bare label.
            // Every other case gets its body wrapped in braces.
            let write_block_braces = !(case.fall_through && case.body.is_empty());
            if write_block_braces {
                writeln!(self.out, " {{")?;
            } else {
                writeln!(self.out)?;
            }

            if case.fall_through && !case.body.is_empty() {
                // Instead of relying on fall-through, inline this body and the bodies of
                // all following cases, up to and including the first non-fall-through case.
                let curr_len = i + 1;
                // NOTE(review): `unwrap` panics if no later case stops falling through;
                // presumably validation guarantees the last case does not — confirm.
                let end_case_idx = curr_len
                    + cases
                        .iter()
                        .skip(curr_len)
                        .position(|case| !case.fall_through)
                        .unwrap();
                let indent_level_3 = indent_level_2.next();
                for case in &cases[i..=end_case_idx] {
                    writeln!(self.out, "{indent_level_2}{{")?;
                    let prev_len = self.named_expressions.len();
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                    }
                    // Forget names introduced inside this scope. The same case bodies
                    // are written again when the outer loop reaches those cases.
                    self.named_expressions.truncate(prev_len);
                    writeln!(self.out, "{indent_level_2}}}")?;
                }

                let last_case = &cases[end_case_idx];
                if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            } else {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                }
                // Add an explicit `break` unless the case falls through or its body
                // already ends in a terminator.
                if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            }

            if write_block_braces {
                writeln!(self.out, "{indent_level_1}}}")?;
            }
        }

        writeln!(self.out, "{level}}}")?;
    }

    // Forward any `continue`/`break` that `continue_ctx` recorded inside the switch.
    use back::continue_forward::ExitControlFlow;
    let op = match self.continue_ctx.exit_switch() {
        ExitControlFlow::None => None,
        ExitControlFlow::Continue { variable } => Some(("continue", variable)),
        ExitControlFlow::Break { variable } => Some(("break", variable)),
    };
    if let Some((control_flow, variable)) = op {
        writeln!(self.out, "{level}if ({variable}) {{")?;
        writeln!(self.out, "{indent_level_1}{control_flow};")?;
        writeln!(self.out, "{level}}}")?;
    }

    Ok(())
}
1804
1805 fn write_index(
1806 &mut self,
1807 module: &Module,
1808 index: Index,
1809 func_ctx: &back::FunctionCtx<'_>,
1810 ) -> BackendResult {
1811 match index {
1812 Index::Static(index) => {
1813 write!(self.out, "{index}")?;
1814 }
1815 Index::Expression(index) => {
1816 self.write_expr(module, index, func_ctx)?;
1817 }
1818 }
1819 Ok(())
1820 }
1821
1822 fn write_stmt(
1827 &mut self,
1828 module: &Module,
1829 stmt: &crate::Statement,
1830 func_ctx: &back::FunctionCtx<'_>,
1831 level: back::Level,
1832 ) -> BackendResult {
1833 use crate::Statement;
1834
1835 match *stmt {
1836 Statement::Emit(ref range) => {
1837 for handle in range.clone() {
1838 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1839 let expr_name = if ptr_class.is_some() {
1840 None
1844 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1845 Some(self.namer.call(name))
1850 } else if self.need_bake_expressions.contains(&handle) {
1851 Some(Baked(handle).to_string())
1852 } else {
1853 None
1854 };
1855
1856 if let Some(name) = expr_name {
1857 write!(self.out, "{level}")?;
1858 self.write_named_expr(module, handle, name, handle, func_ctx)?;
1859 }
1860 }
1861 }
1862 Statement::Block(ref block) => {
1864 write!(self.out, "{level}")?;
1865 writeln!(self.out, "{{")?;
1866 for sta in block.iter() {
1867 self.write_stmt(module, sta, func_ctx, level.next())?
1869 }
1870 writeln!(self.out, "{level}}}")?
1871 }
1872 Statement::If {
1874 condition,
1875 ref accept,
1876 ref reject,
1877 } => {
1878 write!(self.out, "{level}")?;
1879 write!(self.out, "if (")?;
1880 self.write_expr(module, condition, func_ctx)?;
1881 writeln!(self.out, ") {{")?;
1882
1883 let l2 = level.next();
1884 for sta in accept {
1885 self.write_stmt(module, sta, func_ctx, l2)?;
1887 }
1888
1889 if !reject.is_empty() {
1892 writeln!(self.out, "{level}}} else {{")?;
1893
1894 for sta in reject {
1895 self.write_stmt(module, sta, func_ctx, l2)?;
1897 }
1898 }
1899
1900 writeln!(self.out, "{level}}}")?
1901 }
1902 Statement::Kill => writeln!(self.out, "{level}discard;")?,
1904 Statement::Return { value: None } => {
1905 writeln!(self.out, "{level}return;")?;
1906 }
1907 Statement::Return { value: Some(expr) } => {
1908 let base_ty_res = &func_ctx.info[expr].ty;
1909 let mut resolved = base_ty_res.inner_with(&module.types);
1910 if let TypeInner::Pointer { base, space: _ } = *resolved {
1911 resolved = &module.types[base].inner;
1912 }
1913
1914 if let TypeInner::Struct { .. } = *resolved {
1915 let ty = base_ty_res.handle().unwrap();
1917 let struct_name = &self.names[&NameKey::Type(ty)];
1918 let variable_name = self.namer.call(&struct_name.to_lowercase());
1919 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
1920 self.write_expr(module, expr, func_ctx)?;
1921 writeln!(self.out, ";")?;
1922
1923 let ep_output = match func_ctx.ty {
1925 back::FunctionType::Function(_) => None,
1926 back::FunctionType::EntryPoint(index) => self
1927 .entry_point_io
1928 .get(&(index as usize))
1929 .unwrap()
1930 .output
1931 .as_ref(),
1932 };
1933 let final_name = match ep_output {
1934 Some(ep_output) => {
1935 let final_name = self.namer.call(&variable_name);
1936 write!(
1937 self.out,
1938 "{}const {} {} = {{ ",
1939 level, ep_output.ty_name, final_name,
1940 )?;
1941 for (index, m) in ep_output.members.iter().enumerate() {
1942 if index != 0 {
1943 write!(self.out, ", ")?;
1944 }
1945 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
1946 write!(self.out, "{variable_name}.{member_name}")?;
1947 }
1948 writeln!(self.out, " }};")?;
1949 final_name
1950 }
1951 None => variable_name,
1952 };
1953 writeln!(self.out, "{level}return {final_name};")?;
1954 } else {
1955 write!(self.out, "{level}return ")?;
1956 self.write_expr(module, expr, func_ctx)?;
1957 writeln!(self.out, ";")?
1958 }
1959 }
1960 Statement::Store { pointer, value } => {
1961 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
1962 if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
1963 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
1964 self.write_storage_store(
1965 module,
1966 var_handle,
1967 StoreValue::Expression(value),
1968 func_ctx,
1969 level,
1970 None,
1971 )?;
1972 } else {
1973 enum MatrixAccess {
1979 Direct {
1980 base: Handle<crate::Expression>,
1981 index: u32,
1982 },
1983 Struct {
1984 columns: crate::VectorSize,
1985 base: Handle<crate::Expression>,
1986 },
1987 }
1988
1989 let get_members = |expr: Handle<crate::Expression>| {
1990 let resolved = func_ctx.resolve_type(expr, &module.types);
1991 match *resolved {
1992 TypeInner::Pointer { base, .. } => match module.types[base].inner {
1993 TypeInner::Struct { ref members, .. } => Some(members),
1994 _ => None,
1995 },
1996 _ => None,
1997 }
1998 };
1999
2000 write!(self.out, "{level}")?;
2001
2002 let matrix_access_on_lhs =
2003 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2004 |(matrix_expr, vector, scalar)| match (
2005 func_ctx.resolve_type(matrix_expr, &module.types),
2006 &func_ctx.expressions[matrix_expr],
2007 ) {
2008 (
2009 &TypeInner::Pointer { base: ty, .. },
2010 &crate::Expression::AccessIndex { base, index },
2011 ) if matches!(
2012 module.types[ty].inner,
2013 TypeInner::Matrix {
2014 rows: crate::VectorSize::Bi,
2015 ..
2016 }
2017 ) && get_members(base)
2018 .map(|members| members[index as usize].binding.is_none())
2019 == Some(true) =>
2020 {
2021 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2022 }
2023 _ => {
2024 if let Some(MatrixType {
2025 columns,
2026 rows: crate::VectorSize::Bi,
2027 width: 4,
2028 }) = get_inner_matrix_of_struct_array_member(
2029 module,
2030 matrix_expr,
2031 func_ctx,
2032 true,
2033 ) {
2034 Some((
2035 MatrixAccess::Struct {
2036 columns,
2037 base: matrix_expr,
2038 },
2039 vector,
2040 scalar,
2041 ))
2042 } else {
2043 None
2044 }
2045 }
2046 },
2047 );
2048
2049 match matrix_access_on_lhs {
2050 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2051 let base_ty_res = &func_ctx.info[base].ty;
2052 let resolved = base_ty_res.inner_with(&module.types);
2053 let ty = match *resolved {
2054 TypeInner::Pointer { base, .. } => base,
2055 _ => base_ty_res.handle().unwrap(),
2056 };
2057
2058 if let Some(Index::Static(vec_index)) = vector {
2059 self.write_expr(module, base, func_ctx)?;
2060 write!(
2061 self.out,
2062 ".{}_{}",
2063 &self.names[&NameKey::StructMember(ty, index)],
2064 vec_index
2065 )?;
2066
2067 if let Some(scalar_index) = scalar {
2068 write!(self.out, "[")?;
2069 self.write_index(module, scalar_index, func_ctx)?;
2070 write!(self.out, "]")?;
2071 }
2072
2073 write!(self.out, " = ")?;
2074 self.write_expr(module, value, func_ctx)?;
2075 writeln!(self.out, ";")?;
2076 } else {
2077 let access = WrappedStructMatrixAccess { ty, index };
2078 match (&vector, &scalar) {
2079 (&Some(_), &Some(_)) => {
2080 self.write_wrapped_struct_matrix_set_scalar_function_name(
2081 access,
2082 )?;
2083 }
2084 (&Some(_), &None) => {
2085 self.write_wrapped_struct_matrix_set_vec_function_name(
2086 access,
2087 )?;
2088 }
2089 (&None, _) => {
2090 self.write_wrapped_struct_matrix_set_function_name(access)?;
2091 }
2092 }
2093
2094 write!(self.out, "(")?;
2095 self.write_expr(module, base, func_ctx)?;
2096 write!(self.out, ", ")?;
2097 self.write_expr(module, value, func_ctx)?;
2098
2099 if let Some(Index::Expression(vec_index)) = vector {
2100 write!(self.out, ", ")?;
2101 self.write_expr(module, vec_index, func_ctx)?;
2102
2103 if let Some(scalar_index) = scalar {
2104 write!(self.out, ", ")?;
2105 self.write_index(module, scalar_index, func_ctx)?;
2106 }
2107 }
2108 writeln!(self.out, ");")?;
2109 }
2110 }
2111 Some((
2112 MatrixAccess::Struct { columns, base },
2113 Some(Index::Expression(vec_index)),
2114 scalar,
2115 )) => {
2116 if scalar.is_some() {
2120 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2121 } else {
2122 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2123 }
2124 write!(self.out, "(")?;
2125 self.write_expr(module, base, func_ctx)?;
2126 write!(self.out, ", ")?;
2127 self.write_expr(module, vec_index, func_ctx)?;
2128
2129 if let Some(scalar_index) = scalar {
2130 write!(self.out, ", ")?;
2131 self.write_index(module, scalar_index, func_ctx)?;
2132 }
2133
2134 write!(self.out, ", ")?;
2135 self.write_expr(module, value, func_ctx)?;
2136
2137 writeln!(self.out, ");")?;
2138 }
2139 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2140 | Some((MatrixAccess::Struct { .. }, None, _))
2141 | None => {
2142 self.write_expr(module, pointer, func_ctx)?;
2143 write!(self.out, " = ")?;
2144
2145 if let Some(MatrixType {
2150 columns,
2151 rows: crate::VectorSize::Bi,
2152 width: 4,
2153 }) = get_inner_matrix_of_struct_array_member(
2154 module, pointer, func_ctx, false,
2155 ) {
2156 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2157 if let TypeInner::Pointer { base, .. } = *resolved {
2158 resolved = &module.types[base].inner;
2159 }
2160
2161 write!(self.out, "(__mat{}x2", columns as u8)?;
2162 if let TypeInner::Array { base, size, .. } = *resolved {
2163 self.write_array_size(module, base, size)?;
2164 }
2165 write!(self.out, ")")?;
2166 }
2167
2168 self.write_expr(module, value, func_ctx)?;
2169 writeln!(self.out, ";")?
2170 }
2171 }
2172 }
2173 }
2174 Statement::Loop {
2175 ref body,
2176 ref continuing,
2177 break_if,
2178 } => {
2179 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2180 let gate_name = (!continuing.is_empty() || break_if.is_some())
2181 .then(|| self.namer.call("loop_init"));
2182
2183 if let Some((ref decl, _)) = force_loop_bound_statements {
2184 writeln!(self.out, "{decl}")?;
2185 }
2186 if let Some(ref gate_name) = gate_name {
2187 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2188 }
2189
2190 self.continue_ctx.enter_loop();
2191 writeln!(self.out, "{level}while(true) {{")?;
2192 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2193 writeln!(self.out, "{break_and_inc}")?;
2194 }
2195 let l2 = level.next();
2196 if let Some(gate_name) = gate_name {
2197 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2198 let l3 = l2.next();
2199 for sta in continuing.iter() {
2200 self.write_stmt(module, sta, func_ctx, l3)?;
2201 }
2202 if let Some(condition) = break_if {
2203 write!(self.out, "{l3}if (")?;
2204 self.write_expr(module, condition, func_ctx)?;
2205 writeln!(self.out, ") {{")?;
2206 writeln!(self.out, "{}break;", l3.next())?;
2207 writeln!(self.out, "{l3}}}")?;
2208 }
2209 writeln!(self.out, "{l2}}}")?;
2210 writeln!(self.out, "{l2}{gate_name} = false;")?;
2211 }
2212
2213 for sta in body.iter() {
2214 self.write_stmt(module, sta, func_ctx, l2)?;
2215 }
2216
2217 writeln!(self.out, "{level}}}")?;
2218 self.continue_ctx.exit_loop();
2219 }
2220 Statement::Break => writeln!(self.out, "{level}break;")?,
2221 Statement::Continue => {
2222 if let Some(variable) = self.continue_ctx.continue_encountered() {
2223 writeln!(self.out, "{level}{variable} = true;")?;
2224 writeln!(self.out, "{level}break;")?
2225 } else {
2226 writeln!(self.out, "{level}continue;")?
2227 }
2228 }
2229 Statement::ControlBarrier(barrier) => {
2230 self.write_control_barrier(barrier, level)?;
2231 }
2232 Statement::MemoryBarrier(barrier) => {
2233 self.write_memory_barrier(barrier, level)?;
2234 }
2235 Statement::ImageStore {
2236 image,
2237 coordinate,
2238 array_index,
2239 value,
2240 } => {
2241 write!(self.out, "{level}")?;
2242 self.write_expr(module, image, func_ctx)?;
2243
2244 write!(self.out, "[")?;
2245 if let Some(index) = array_index {
2246 write!(self.out, "int3(")?;
2248 self.write_expr(module, coordinate, func_ctx)?;
2249 write!(self.out, ", ")?;
2250 self.write_expr(module, index, func_ctx)?;
2251 write!(self.out, ")")?;
2252 } else {
2253 self.write_expr(module, coordinate, func_ctx)?;
2254 }
2255 write!(self.out, "]")?;
2256
2257 write!(self.out, " = ")?;
2258 self.write_expr(module, value, func_ctx)?;
2259 writeln!(self.out, ";")?;
2260 }
2261 Statement::Call {
2262 function,
2263 ref arguments,
2264 result,
2265 } => {
2266 write!(self.out, "{level}")?;
2267 if let Some(expr) = result {
2268 write!(self.out, "const ")?;
2269 let name = Baked(expr).to_string();
2270 let expr_ty = &func_ctx.info[expr].ty;
2271 let ty_inner = match *expr_ty {
2272 proc::TypeResolution::Handle(handle) => {
2273 self.write_type(module, handle)?;
2274 &module.types[handle].inner
2275 }
2276 proc::TypeResolution::Value(ref value) => {
2277 self.write_value_type(module, value)?;
2278 value
2279 }
2280 };
2281 write!(self.out, " {name}")?;
2282 if let TypeInner::Array { base, size, .. } = *ty_inner {
2283 self.write_array_size(module, base, size)?;
2284 }
2285 write!(self.out, " = ")?;
2286 self.named_expressions.insert(expr, name);
2287 }
2288 let func_name = &self.names[&NameKey::Function(function)];
2289 write!(self.out, "{func_name}(")?;
2290 for (index, argument) in arguments.iter().enumerate() {
2291 if index != 0 {
2292 write!(self.out, ", ")?;
2293 }
2294 self.write_expr(module, *argument, func_ctx)?;
2295 }
2296 writeln!(self.out, ");")?
2297 }
2298 Statement::Atomic {
2299 pointer,
2300 ref fun,
2301 value,
2302 result,
2303 } => {
2304 write!(self.out, "{level}")?;
2305 let res_var_info = if let Some(res_handle) = result {
2306 let name = Baked(res_handle).to_string();
2307 match func_ctx.info[res_handle].ty {
2308 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2309 proc::TypeResolution::Value(ref value) => {
2310 self.write_value_type(module, value)?
2311 }
2312 };
2313 write!(self.out, " {name}; ")?;
2314 self.named_expressions.insert(res_handle, name.clone());
2315 Some((res_handle, name))
2316 } else {
2317 None
2318 };
2319 let pointer_space = func_ctx
2320 .resolve_type(pointer, &module.types)
2321 .pointer_space()
2322 .unwrap();
2323 let fun_str = fun.to_hlsl_suffix();
2324 let compare_expr = match *fun {
2325 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2326 _ => None,
2327 };
2328 match pointer_space {
2329 crate::AddressSpace::WorkGroup => {
2330 write!(self.out, "Interlocked{fun_str}(")?;
2331 self.write_expr(module, pointer, func_ctx)?;
2332 self.emit_hlsl_atomic_tail(
2333 module,
2334 func_ctx,
2335 fun,
2336 compare_expr,
2337 value,
2338 &res_var_info,
2339 )?;
2340 }
2341 crate::AddressSpace::Storage { .. } => {
2342 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2343 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2344 let width = match func_ctx.resolve_type(value, &module.types) {
2345 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2346 _ => "",
2347 };
2348 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2349 let chain = mem::take(&mut self.temp_access_chain);
2350 self.write_storage_address(module, &chain, func_ctx)?;
2351 self.temp_access_chain = chain;
2352 self.emit_hlsl_atomic_tail(
2353 module,
2354 func_ctx,
2355 fun,
2356 compare_expr,
2357 value,
2358 &res_var_info,
2359 )?;
2360 }
2361 ref other => {
2362 return Err(Error::Custom(format!(
2363 "invalid address space {other:?} for atomic statement"
2364 )))
2365 }
2366 }
2367 if let Some(cmp) = compare_expr {
2368 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2369 write!(
2370 self.out,
2371 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2372 )?;
2373 self.write_expr(module, cmp, func_ctx)?;
2374 writeln!(self.out, ");")?;
2375 }
2376 }
2377 }
2378 Statement::ImageAtomic {
2379 image,
2380 coordinate,
2381 array_index,
2382 fun,
2383 value,
2384 } => {
2385 write!(self.out, "{level}")?;
2386
2387 let fun_str = fun.to_hlsl_suffix();
2388 write!(self.out, "Interlocked{fun_str}(")?;
2389 self.write_expr(module, image, func_ctx)?;
2390 write!(self.out, "[")?;
2391 self.write_texture_coordinates(
2392 "int",
2393 coordinate,
2394 array_index,
2395 None,
2396 module,
2397 func_ctx,
2398 )?;
2399 write!(self.out, "],")?;
2400
2401 self.write_expr(module, value, func_ctx)?;
2402 writeln!(self.out, ");")?;
2403 }
2404 Statement::WorkGroupUniformLoad { pointer, result } => {
2405 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2406 write!(self.out, "{level}")?;
2407 let name = Baked(result).to_string();
2408 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2409
2410 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2411 }
2412 Statement::Switch {
2413 selector,
2414 ref cases,
2415 } => {
2416 self.write_switch(module, func_ctx, level, selector, cases)?;
2417 }
2418 Statement::RayQuery { query, ref fun } => match *fun {
2419 RayQueryFunction::Initialize {
2420 acceleration_structure,
2421 descriptor,
2422 } => {
2423 write!(self.out, "{level}")?;
2424 self.write_expr(module, query, func_ctx)?;
2425 write!(self.out, ".TraceRayInline(")?;
2426 self.write_expr(module, acceleration_structure, func_ctx)?;
2427 write!(self.out, ", ")?;
2428 self.write_expr(module, descriptor, func_ctx)?;
2429 write!(self.out, ".flags, ")?;
2430 self.write_expr(module, descriptor, func_ctx)?;
2431 write!(self.out, ".cull_mask, ")?;
2432 write!(self.out, "RayDescFromRayDesc_(")?;
2433 self.write_expr(module, descriptor, func_ctx)?;
2434 writeln!(self.out, "));")?;
2435 }
2436 RayQueryFunction::Proceed { result } => {
2437 write!(self.out, "{level}")?;
2438 let name = Baked(result).to_string();
2439 write!(self.out, "const bool {name} = ")?;
2440 self.named_expressions.insert(result, name);
2441 self.write_expr(module, query, func_ctx)?;
2442 writeln!(self.out, ".Proceed();")?;
2443 }
2444 RayQueryFunction::GenerateIntersection { hit_t } => {
2445 write!(self.out, "{level}")?;
2446 self.write_expr(module, query, func_ctx)?;
2447 write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2448 self.write_expr(module, hit_t, func_ctx)?;
2449 writeln!(self.out, ");")?;
2450 }
2451 RayQueryFunction::ConfirmIntersection => {
2452 write!(self.out, "{level}")?;
2453 self.write_expr(module, query, func_ctx)?;
2454 writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2455 }
2456 RayQueryFunction::Terminate => {
2457 write!(self.out, "{level}")?;
2458 self.write_expr(module, query, func_ctx)?;
2459 writeln!(self.out, ".Abort();")?;
2460 }
2461 },
2462 Statement::SubgroupBallot { result, predicate } => {
2463 write!(self.out, "{level}")?;
2464 let name = Baked(result).to_string();
2465 write!(self.out, "const uint4 {name} = ")?;
2466 self.named_expressions.insert(result, name);
2467
2468 write!(self.out, "WaveActiveBallot(")?;
2469 match predicate {
2470 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2471 None => write!(self.out, "true")?,
2472 }
2473 writeln!(self.out, ");")?;
2474 }
2475 Statement::SubgroupCollectiveOperation {
2476 op,
2477 collective_op,
2478 argument,
2479 result,
2480 } => {
2481 write!(self.out, "{level}")?;
2482 write!(self.out, "const ")?;
2483 let name = Baked(result).to_string();
2484 match func_ctx.info[result].ty {
2485 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2486 proc::TypeResolution::Value(ref value) => {
2487 self.write_value_type(module, value)?
2488 }
2489 };
2490 write!(self.out, " {name} = ")?;
2491 self.named_expressions.insert(result, name);
2492
2493 match (collective_op, op) {
2494 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2495 write!(self.out, "WaveActiveAllTrue(")?
2496 }
2497 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2498 write!(self.out, "WaveActiveAnyTrue(")?
2499 }
2500 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2501 write!(self.out, "WaveActiveSum(")?
2502 }
2503 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2504 write!(self.out, "WaveActiveProduct(")?
2505 }
2506 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2507 write!(self.out, "WaveActiveMax(")?
2508 }
2509 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2510 write!(self.out, "WaveActiveMin(")?
2511 }
2512 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2513 write!(self.out, "WaveActiveBitAnd(")?
2514 }
2515 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2516 write!(self.out, "WaveActiveBitOr(")?
2517 }
2518 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2519 write!(self.out, "WaveActiveBitXor(")?
2520 }
2521 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2522 write!(self.out, "WavePrefixSum(")?
2523 }
2524 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2525 write!(self.out, "WavePrefixProduct(")?
2526 }
2527 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2528 self.write_expr(module, argument, func_ctx)?;
2529 write!(self.out, " + WavePrefixSum(")?;
2530 }
2531 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2532 self.write_expr(module, argument, func_ctx)?;
2533 write!(self.out, " * WavePrefixProduct(")?;
2534 }
2535 _ => unimplemented!(),
2536 }
2537 self.write_expr(module, argument, func_ctx)?;
2538 writeln!(self.out, ");")?;
2539 }
2540 Statement::SubgroupGather {
2541 mode,
2542 argument,
2543 result,
2544 } => {
2545 write!(self.out, "{level}")?;
2546 write!(self.out, "const ")?;
2547 let name = Baked(result).to_string();
2548 match func_ctx.info[result].ty {
2549 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2550 proc::TypeResolution::Value(ref value) => {
2551 self.write_value_type(module, value)?
2552 }
2553 };
2554 write!(self.out, " {name} = ")?;
2555 self.named_expressions.insert(result, name);
2556 match mode {
2557 crate::GatherMode::BroadcastFirst => {
2558 write!(self.out, "WaveReadLaneFirst(")?;
2559 self.write_expr(module, argument, func_ctx)?;
2560 }
2561 crate::GatherMode::QuadBroadcast(index) => {
2562 write!(self.out, "QuadReadLaneAt(")?;
2563 self.write_expr(module, argument, func_ctx)?;
2564 write!(self.out, ", ")?;
2565 self.write_expr(module, index, func_ctx)?;
2566 }
2567 crate::GatherMode::QuadSwap(direction) => {
2568 match direction {
2569 crate::Direction::X => {
2570 write!(self.out, "QuadReadAcrossX(")?;
2571 }
2572 crate::Direction::Y => {
2573 write!(self.out, "QuadReadAcrossY(")?;
2574 }
2575 crate::Direction::Diagonal => {
2576 write!(self.out, "QuadReadAcrossDiagonal(")?;
2577 }
2578 }
2579 self.write_expr(module, argument, func_ctx)?;
2580 }
2581 _ => {
2582 write!(self.out, "WaveReadLaneAt(")?;
2583 self.write_expr(module, argument, func_ctx)?;
2584 write!(self.out, ", ")?;
2585 match mode {
2586 crate::GatherMode::BroadcastFirst => unreachable!(),
2587 crate::GatherMode::Broadcast(index)
2588 | crate::GatherMode::Shuffle(index) => {
2589 self.write_expr(module, index, func_ctx)?;
2590 }
2591 crate::GatherMode::ShuffleDown(index) => {
2592 write!(self.out, "WaveGetLaneIndex() + ")?;
2593 self.write_expr(module, index, func_ctx)?;
2594 }
2595 crate::GatherMode::ShuffleUp(index) => {
2596 write!(self.out, "WaveGetLaneIndex() - ")?;
2597 self.write_expr(module, index, func_ctx)?;
2598 }
2599 crate::GatherMode::ShuffleXor(index) => {
2600 write!(self.out, "WaveGetLaneIndex() ^ ")?;
2601 self.write_expr(module, index, func_ctx)?;
2602 }
2603 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2604 crate::GatherMode::QuadSwap(_) => unreachable!(),
2605 }
2606 }
2607 }
2608 writeln!(self.out, ");")?;
2609 }
2610 }
2611
2612 Ok(())
2613 }
2614
2615 fn write_const_expression(
2616 &mut self,
2617 module: &Module,
2618 expr: Handle<crate::Expression>,
2619 arena: &crate::Arena<crate::Expression>,
2620 ) -> BackendResult {
2621 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2622 writer.write_const_expression(module, expr, arena)
2623 })
2624 }
2625
2626 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2627 match literal {
2628 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2629 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2630 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2631 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2632 crate::Literal::I32(value) if value == i32::MIN => {
2638 write!(self.out, "int({} - 1)", value + 1)?
2639 }
2640 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2644 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2645 crate::Literal::I64(value) if value == i64::MIN => {
2647 write!(self.out, "({}L - 1L)", value + 1)?;
2648 }
2649 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2650 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2651 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2652 return Err(Error::Custom(
2653 "Abstract types should not appear in IR presented to backends".into(),
2654 ));
2655 }
2656 }
2657 Ok(())
2658 }
2659
2660 fn write_possibly_const_expression<E>(
2661 &mut self,
2662 module: &Module,
2663 expr: Handle<crate::Expression>,
2664 expressions: &crate::Arena<crate::Expression>,
2665 write_expression: E,
2666 ) -> BackendResult
2667 where
2668 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2669 {
2670 use crate::Expression;
2671
2672 match expressions[expr] {
2673 Expression::Literal(literal) => {
2674 self.write_literal(literal)?;
2675 }
2676 Expression::Constant(handle) => {
2677 let constant = &module.constants[handle];
2678 if constant.name.is_some() {
2679 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2680 } else {
2681 self.write_const_expression(module, constant.init, &module.global_expressions)?;
2682 }
2683 }
2684 Expression::ZeroValue(ty) => {
2685 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2686 write!(self.out, "()")?;
2687 }
2688 Expression::Compose { ty, ref components } => {
2689 match module.types[ty].inner {
2690 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2691 self.write_wrapped_constructor_function_name(
2692 module,
2693 WrappedConstructor { ty },
2694 )?;
2695 }
2696 _ => {
2697 self.write_type(module, ty)?;
2698 }
2699 };
2700 write!(self.out, "(")?;
2701 for (index, component) in components.iter().enumerate() {
2702 if index != 0 {
2703 write!(self.out, ", ")?;
2704 }
2705 write_expression(self, *component)?;
2706 }
2707 write!(self.out, ")")?;
2708 }
2709 Expression::Splat { size, value } => {
2710 let number_of_components = match size {
2714 crate::VectorSize::Bi => "xx",
2715 crate::VectorSize::Tri => "xxx",
2716 crate::VectorSize::Quad => "xxxx",
2717 };
2718 write!(self.out, "(")?;
2719 write_expression(self, value)?;
2720 write!(self.out, ").{number_of_components}")?
2721 }
2722 _ => {
2723 return Err(Error::Override);
2724 }
2725 }
2726
2727 Ok(())
2728 }
2729
2730 pub(super) fn write_expr(
2735 &mut self,
2736 module: &Module,
2737 expr: Handle<crate::Expression>,
2738 func_ctx: &back::FunctionCtx<'_>,
2739 ) -> BackendResult {
2740 use crate::Expression;
2741
2742 let ff_input = if self.options.special_constants_binding.is_some() {
2744 func_ctx.is_fixed_function_input(expr, module)
2745 } else {
2746 None
2747 };
2748 let closing_bracket = match ff_input {
2749 Some(crate::BuiltIn::VertexIndex) => {
2750 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2751 ")"
2752 }
2753 Some(crate::BuiltIn::InstanceIndex) => {
2754 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2755 ")"
2756 }
2757 Some(crate::BuiltIn::NumWorkGroups) => {
2758 write!(
2762 self.out,
2763 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2764 )?;
2765 return Ok(());
2766 }
2767 _ => "",
2768 };
2769
2770 if let Some(name) = self.named_expressions.get(&expr) {
2771 write!(self.out, "{name}{closing_bracket}")?;
2772 return Ok(());
2773 }
2774
2775 let expression = &func_ctx.expressions[expr];
2776
2777 match *expression {
2778 Expression::Literal(_)
2779 | Expression::Constant(_)
2780 | Expression::ZeroValue(_)
2781 | Expression::Compose { .. }
2782 | Expression::Splat { .. } => {
2783 self.write_possibly_const_expression(
2784 module,
2785 expr,
2786 func_ctx.expressions,
2787 |writer, expr| writer.write_expr(module, expr, func_ctx),
2788 )?;
2789 }
2790 Expression::Override(_) => return Err(Error::Override),
2791 Expression::Binary {
2798 op:
2799 op @ crate::BinaryOperator::Add
2800 | op @ crate::BinaryOperator::Subtract
2801 | op @ crate::BinaryOperator::Multiply,
2802 left,
2803 right,
2804 } if matches!(
2805 func_ctx.resolve_type(expr, &module.types).scalar(),
2806 Some(Scalar::I32)
2807 ) =>
2808 {
2809 write!(self.out, "asint(asuint(",)?;
2810 self.write_expr(module, left, func_ctx)?;
2811 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2812 self.write_expr(module, right, func_ctx)?;
2813 write!(self.out, "))")?;
2814 }
2815 Expression::Binary {
2818 op: crate::BinaryOperator::Multiply,
2819 left,
2820 right,
2821 } if func_ctx.resolve_type(left, &module.types).is_matrix()
2822 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2823 {
2824 write!(self.out, "mul(")?;
2826 self.write_expr(module, right, func_ctx)?;
2827 write!(self.out, ", ")?;
2828 self.write_expr(module, left, func_ctx)?;
2829 write!(self.out, ")")?;
2830 }
2831
2832 Expression::Binary {
2844 op: crate::BinaryOperator::Divide,
2845 left,
2846 right,
2847 } if matches!(
2848 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2849 Some(ScalarKind::Sint | ScalarKind::Uint)
2850 ) =>
2851 {
2852 write!(self.out, "{DIV_FUNCTION}(")?;
2853 self.write_expr(module, left, func_ctx)?;
2854 write!(self.out, ", ")?;
2855 self.write_expr(module, right, func_ctx)?;
2856 write!(self.out, ")")?;
2857 }
2858
2859 Expression::Binary {
2860 op: crate::BinaryOperator::Modulo,
2861 left,
2862 right,
2863 } if matches!(
2864 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2865 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
2866 ) =>
2867 {
2868 write!(self.out, "{MOD_FUNCTION}(")?;
2869 self.write_expr(module, left, func_ctx)?;
2870 write!(self.out, ", ")?;
2871 self.write_expr(module, right, func_ctx)?;
2872 write!(self.out, ")")?;
2873 }
2874
2875 Expression::Binary { op, left, right } => {
2876 write!(self.out, "(")?;
2877 self.write_expr(module, left, func_ctx)?;
2878 write!(self.out, " {} ", back::binary_operation_str(op))?;
2879 self.write_expr(module, right, func_ctx)?;
2880 write!(self.out, ")")?;
2881 }
2882 Expression::Access { base, index } => {
2883 if let Some(crate::AddressSpace::Storage { .. }) =
2884 func_ctx.resolve_type(expr, &module.types).pointer_space()
2885 {
2886 } else {
2888 if let Some(MatrixType {
2895 columns,
2896 rows: crate::VectorSize::Bi,
2897 width: 4,
2898 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2899 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
2900 {
2901 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
2902 self.write_expr(module, base, func_ctx)?;
2903 write!(self.out, ", ")?;
2904 self.write_expr(module, index, func_ctx)?;
2905 write!(self.out, ")")?;
2906 return Ok(());
2907 }
2908
2909 let resolved = func_ctx.resolve_type(base, &module.types);
2910
2911 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
2912 TypeInner::BindingArray { .. } => {
2913 let uniformity = &func_ctx.info[index].uniformity;
2914
2915 (true, uniformity.non_uniform_result.is_some())
2916 }
2917 _ => (false, false),
2918 };
2919
2920 self.write_expr(module, base, func_ctx)?;
2921
2922 let array_sampler_info = self.sampler_binding_array_info_from_expression(
2923 module, func_ctx, base, resolved,
2924 );
2925
2926 if let Some(ref info) = array_sampler_info {
2927 write!(self.out, "{}[", info.sampler_heap_name)?;
2928 } else {
2929 write!(self.out, "[")?;
2930 }
2931
2932 let needs_bound_check = self.options.restrict_indexing
2933 && !indexing_binding_array
2934 && match resolved.pointer_space() {
2935 Some(
2936 crate::AddressSpace::Function
2937 | crate::AddressSpace::Private
2938 | crate::AddressSpace::WorkGroup
2939 | crate::AddressSpace::PushConstant,
2940 )
2941 | None => true,
2942 Some(crate::AddressSpace::Uniform) => {
2943 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
2945 let bind_target = self
2946 .options
2947 .resolve_resource_binding(
2948 module.global_variables[var_handle]
2949 .binding
2950 .as_ref()
2951 .unwrap(),
2952 )
2953 .unwrap();
2954 bind_target.restrict_indexing
2955 }
2956 Some(
2957 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
2958 ) => unreachable!(),
2959 };
2960 let restriction_needed = if needs_bound_check {
2962 index::access_needs_check(
2963 base,
2964 index::GuardedIndex::Expression(index),
2965 module,
2966 func_ctx.expressions,
2967 func_ctx.info,
2968 )
2969 } else {
2970 None
2971 };
2972 if let Some(limit) = restriction_needed {
2973 write!(self.out, "min(uint(")?;
2974 self.write_expr(module, index, func_ctx)?;
2975 write!(self.out, "), ")?;
2976 match limit {
2977 index::IndexableLength::Known(limit) => {
2978 write!(self.out, "{}u", limit - 1)?;
2979 }
2980 index::IndexableLength::Dynamic => unreachable!(),
2981 }
2982 write!(self.out, ")")?;
2983 } else {
2984 if non_uniform_qualifier {
2985 write!(self.out, "NonUniformResourceIndex(")?;
2986 }
2987 if let Some(ref info) = array_sampler_info {
2988 write!(
2989 self.out,
2990 "{}[{} + ",
2991 info.sampler_index_buffer_name, info.binding_array_base_index_name,
2992 )?;
2993 }
2994 self.write_expr(module, index, func_ctx)?;
2995 if array_sampler_info.is_some() {
2996 write!(self.out, "]")?;
2997 }
2998 if non_uniform_qualifier {
2999 write!(self.out, ")")?;
3000 }
3001 }
3002
3003 write!(self.out, "]")?;
3004 }
3005 }
3006 Expression::AccessIndex { base, index } => {
3007 if let Some(crate::AddressSpace::Storage { .. }) =
3008 func_ctx.resolve_type(expr, &module.types).pointer_space()
3009 {
3010 } else {
3012 if let Some(MatrixType {
3016 rows: crate::VectorSize::Bi,
3017 width: 4,
3018 ..
3019 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3020 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3021 {
3022 self.write_expr(module, base, func_ctx)?;
3023 write!(self.out, "._{index}")?;
3024 return Ok(());
3025 }
3026
3027 let base_ty_res = &func_ctx.info[base].ty;
3028 let mut resolved = base_ty_res.inner_with(&module.types);
3029 let base_ty_handle = match *resolved {
3030 TypeInner::Pointer { base, .. } => {
3031 resolved = &module.types[base].inner;
3032 Some(base)
3033 }
3034 _ => base_ty_res.handle(),
3035 };
3036
3037 if let TypeInner::Struct { ref members, .. } = *resolved {
3043 let member = &members[index as usize];
3044
3045 match module.types[member.ty].inner {
3046 TypeInner::Matrix {
3047 rows: crate::VectorSize::Bi,
3048 ..
3049 } if member.binding.is_none() => {
3050 let ty = base_ty_handle.unwrap();
3051 self.write_wrapped_struct_matrix_get_function_name(
3052 WrappedStructMatrixAccess { ty, index },
3053 )?;
3054 write!(self.out, "(")?;
3055 self.write_expr(module, base, func_ctx)?;
3056 write!(self.out, ")")?;
3057 return Ok(());
3058 }
3059 _ => {}
3060 }
3061 }
3062
3063 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3064 module, func_ctx, base, resolved,
3065 );
3066
3067 if let Some(ref info) = array_sampler_info {
3068 write!(
3069 self.out,
3070 "{}[{}",
3071 info.sampler_heap_name, info.sampler_index_buffer_name
3072 )?;
3073 }
3074
3075 self.write_expr(module, base, func_ctx)?;
3076
3077 match *resolved {
3078 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3084 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3086 }
3087 TypeInner::Matrix { .. }
3088 | TypeInner::Array { .. }
3089 | TypeInner::BindingArray { .. } => {
3090 if let Some(ref info) = array_sampler_info {
3091 write!(
3092 self.out,
3093 "[{} + {index}]",
3094 info.binding_array_base_index_name
3095 )?;
3096 } else {
3097 write!(self.out, "[{index}]")?;
3098 }
3099 }
3100 TypeInner::Struct { .. } => {
3101 let ty = base_ty_handle.unwrap();
3104
3105 write!(
3106 self.out,
3107 ".{}",
3108 &self.names[&NameKey::StructMember(ty, index)]
3109 )?
3110 }
3111 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3112 }
3113
3114 if array_sampler_info.is_some() {
3115 write!(self.out, "]")?;
3116 }
3117 }
3118 }
3119 Expression::FunctionArgument(pos) => {
3120 let key = func_ctx.argument_key(pos);
3121 let name = &self.names[&key];
3122 write!(self.out, "{name}")?;
3123 }
3124 Expression::ImageSample {
3125 coordinate,
3126 image,
3127 sampler,
3128 clamp_to_edge: true,
3129 gather: None,
3130 array_index: None,
3131 offset: None,
3132 level: crate::SampleLevel::Zero,
3133 depth_ref: None,
3134 } => {
3135 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3136 self.write_expr(module, image, func_ctx)?;
3137 write!(self.out, ", ")?;
3138 self.write_expr(module, sampler, func_ctx)?;
3139 write!(self.out, ", ")?;
3140 self.write_expr(module, coordinate, func_ctx)?;
3141 write!(self.out, ")")?;
3142 }
3143 Expression::ImageSample {
3144 image,
3145 sampler,
3146 gather,
3147 coordinate,
3148 array_index,
3149 offset,
3150 level,
3151 depth_ref,
3152 clamp_to_edge,
3153 } => {
3154 if clamp_to_edge {
3155 return Err(Error::Custom(
3156 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3157 ));
3158 }
3159
3160 use crate::SampleLevel as Sl;
3161 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3162
3163 let (base_str, component_str) = match gather {
3164 Some(component) => ("Gather", COMPONENTS[component as usize]),
3165 None => ("Sample", ""),
3166 };
3167 let cmp_str = match depth_ref {
3168 Some(_) => "Cmp",
3169 None => "",
3170 };
3171 let level_str = match level {
3172 Sl::Zero if gather.is_none() => "LevelZero",
3173 Sl::Auto | Sl::Zero => "",
3174 Sl::Exact(_) => "Level",
3175 Sl::Bias(_) => "Bias",
3176 Sl::Gradient { .. } => "Grad",
3177 };
3178
3179 self.write_expr(module, image, func_ctx)?;
3180 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3181 self.write_expr(module, sampler, func_ctx)?;
3182 write!(self.out, ", ")?;
3183 self.write_texture_coordinates(
3184 "float",
3185 coordinate,
3186 array_index,
3187 None,
3188 module,
3189 func_ctx,
3190 )?;
3191
3192 if let Some(depth_ref) = depth_ref {
3193 write!(self.out, ", ")?;
3194 self.write_expr(module, depth_ref, func_ctx)?;
3195 }
3196
3197 match level {
3198 Sl::Auto | Sl::Zero => {}
3199 Sl::Exact(expr) => {
3200 write!(self.out, ", ")?;
3201 self.write_expr(module, expr, func_ctx)?;
3202 }
3203 Sl::Bias(expr) => {
3204 write!(self.out, ", ")?;
3205 self.write_expr(module, expr, func_ctx)?;
3206 }
3207 Sl::Gradient { x, y } => {
3208 write!(self.out, ", ")?;
3209 self.write_expr(module, x, func_ctx)?;
3210 write!(self.out, ", ")?;
3211 self.write_expr(module, y, func_ctx)?;
3212 }
3213 }
3214
3215 if let Some(offset) = offset {
3216 write!(self.out, ", ")?;
3217 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3219 write!(self.out, ")")?;
3220 }
3221
3222 write!(self.out, ")")?;
3223 }
3224 Expression::ImageQuery { image, query } => {
3225 if let TypeInner::Image {
3227 dim,
3228 arrayed,
3229 class,
3230 } = *func_ctx.resolve_type(image, &module.types)
3231 {
3232 let wrapped_image_query = WrappedImageQuery {
3233 dim,
3234 arrayed,
3235 class,
3236 query: query.into(),
3237 };
3238
3239 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3240 write!(self.out, "(")?;
3241 self.write_expr(module, image, func_ctx)?;
3243 if let crate::ImageQuery::Size { level: Some(level) } = query {
3244 write!(self.out, ", ")?;
3245 self.write_expr(module, level, func_ctx)?;
3246 }
3247 write!(self.out, ")")?;
3248 }
3249 }
3250 Expression::ImageLoad {
3251 image,
3252 coordinate,
3253 array_index,
3254 sample,
3255 level,
3256 } => self.write_image_load(
3257 &module,
3258 expr,
3259 func_ctx,
3260 image,
3261 coordinate,
3262 array_index,
3263 sample,
3264 level,
3265 )?,
3266 Expression::GlobalVariable(handle) => {
3267 let global_variable = &module.global_variables[handle];
3268 let ty = &module.types[global_variable.ty].inner;
3269
3270 let is_binding_array_of_samplers = match *ty {
3275 TypeInner::BindingArray { base, .. } => {
3276 let base_ty = &module.types[base].inner;
3277 matches!(*base_ty, TypeInner::Sampler { .. })
3278 }
3279 _ => false,
3280 };
3281
3282 let is_storage_space =
3283 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3284
3285 if !is_binding_array_of_samplers && !is_storage_space {
3286 let name = &self.names[&NameKey::GlobalVariable(handle)];
3287 write!(self.out, "{name}")?;
3288 }
3289 }
3290 Expression::LocalVariable(handle) => {
3291 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3292 }
3293 Expression::Load { pointer } => {
3294 match func_ctx
3295 .resolve_type(pointer, &module.types)
3296 .pointer_space()
3297 {
3298 Some(crate::AddressSpace::Storage { .. }) => {
3299 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3300 let result_ty = func_ctx.info[expr].ty.clone();
3301 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3302 }
3303 _ => {
3304 let mut close_paren = false;
3305
3306 if let Some(MatrixType {
3311 rows: crate::VectorSize::Bi,
3312 width: 4,
3313 ..
3314 }) = get_inner_matrix_of_struct_array_member(
3315 module, pointer, func_ctx, false,
3316 )
3317 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3318 {
3319 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3320 let ptr_tr = resolved.pointer_base_type();
3321 if let Some(ptr_ty) =
3322 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3323 {
3324 resolved = ptr_ty;
3325 }
3326
3327 write!(self.out, "((")?;
3328 if let TypeInner::Array { base, size, .. } = *resolved {
3329 self.write_type(module, base)?;
3330 self.write_array_size(module, base, size)?;
3331 } else {
3332 self.write_value_type(module, resolved)?;
3333 }
3334 write!(self.out, ")")?;
3335 close_paren = true;
3336 }
3337
3338 self.write_expr(module, pointer, func_ctx)?;
3339
3340 if close_paren {
3341 write!(self.out, ")")?;
3342 }
3343 }
3344 }
3345 }
3346 Expression::Unary { op, expr } => {
3347 let op_str = match op {
3349 crate::UnaryOperator::Negate => {
3350 match func_ctx.resolve_type(expr, &module.types).scalar() {
3351 Some(Scalar::I32) => NEG_FUNCTION,
3352 _ => "-",
3353 }
3354 }
3355 crate::UnaryOperator::LogicalNot => "!",
3356 crate::UnaryOperator::BitwiseNot => "~",
3357 };
3358 write!(self.out, "{op_str}(")?;
3359 self.write_expr(module, expr, func_ctx)?;
3360 write!(self.out, ")")?;
3361 }
3362 Expression::As {
3363 expr,
3364 kind,
3365 convert,
3366 } => {
3367 let inner = func_ctx.resolve_type(expr, &module.types);
3368 if inner.scalar_kind() == Some(ScalarKind::Float)
3369 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3370 && convert.is_some()
3371 {
3372 let fun_name = match (kind, convert) {
3376 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3377 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3378 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3379 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3380 _ => unreachable!(),
3381 };
3382 write!(self.out, "{fun_name}(")?;
3383 self.write_expr(module, expr, func_ctx)?;
3384 write!(self.out, ")")?;
3385 } else {
3386 let close_paren = match convert {
3387 Some(dst_width) => {
3388 let scalar = Scalar {
3389 kind,
3390 width: dst_width,
3391 };
3392 match *inner {
3393 TypeInner::Vector { size, .. } => {
3394 write!(
3395 self.out,
3396 "{}{}(",
3397 scalar.to_hlsl_str()?,
3398 common::vector_size_str(size)
3399 )?;
3400 }
3401 TypeInner::Scalar(_) => {
3402 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3403 }
3404 TypeInner::Matrix { columns, rows, .. } => {
3405 write!(
3406 self.out,
3407 "{}{}x{}(",
3408 scalar.to_hlsl_str()?,
3409 common::vector_size_str(columns),
3410 common::vector_size_str(rows)
3411 )?;
3412 }
3413 _ => {
3414 return Err(Error::Unimplemented(format!(
3415 "write_expr expression::as {inner:?}"
3416 )));
3417 }
3418 };
3419 true
3420 }
3421 None => {
3422 if inner.scalar_width() == Some(8) {
3423 false
3424 } else {
3425 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3426 true
3427 }
3428 }
3429 };
3430 self.write_expr(module, expr, func_ctx)?;
3431 if close_paren {
3432 write!(self.out, ")")?;
3433 }
3434 }
3435 }
3436 Expression::Math {
3437 fun,
3438 arg,
3439 arg1,
3440 arg2,
3441 arg3,
3442 } => {
3443 use crate::MathFunction as Mf;
3444
3445 enum Function {
3446 Asincosh { is_sin: bool },
3447 Atanh,
3448 Pack2x16float,
3449 Pack2x16snorm,
3450 Pack2x16unorm,
3451 Pack4x8snorm,
3452 Pack4x8unorm,
3453 Pack4xI8,
3454 Pack4xU8,
3455 Pack4xI8Clamp,
3456 Pack4xU8Clamp,
3457 Unpack2x16float,
3458 Unpack2x16snorm,
3459 Unpack2x16unorm,
3460 Unpack4x8snorm,
3461 Unpack4x8unorm,
3462 Unpack4xI8,
3463 Unpack4xU8,
3464 Dot4I8Packed,
3465 Dot4U8Packed,
3466 QuantizeToF16,
3467 Regular(&'static str),
3468 MissingIntOverload(&'static str),
3469 MissingIntReturnType(&'static str),
3470 CountTrailingZeros,
3471 CountLeadingZeros,
3472 }
3473
3474 let fun = match fun {
3475 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3477 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3478 _ => Function::Regular("abs"),
3479 },
3480 Mf::Min => Function::Regular("min"),
3481 Mf::Max => Function::Regular("max"),
3482 Mf::Clamp => Function::Regular("clamp"),
3483 Mf::Saturate => Function::Regular("saturate"),
3484 Mf::Cos => Function::Regular("cos"),
3486 Mf::Cosh => Function::Regular("cosh"),
3487 Mf::Sin => Function::Regular("sin"),
3488 Mf::Sinh => Function::Regular("sinh"),
3489 Mf::Tan => Function::Regular("tan"),
3490 Mf::Tanh => Function::Regular("tanh"),
3491 Mf::Acos => Function::Regular("acos"),
3492 Mf::Asin => Function::Regular("asin"),
3493 Mf::Atan => Function::Regular("atan"),
3494 Mf::Atan2 => Function::Regular("atan2"),
3495 Mf::Asinh => Function::Asincosh { is_sin: true },
3496 Mf::Acosh => Function::Asincosh { is_sin: false },
3497 Mf::Atanh => Function::Atanh,
3498 Mf::Radians => Function::Regular("radians"),
3499 Mf::Degrees => Function::Regular("degrees"),
3500 Mf::Ceil => Function::Regular("ceil"),
3502 Mf::Floor => Function::Regular("floor"),
3503 Mf::Round => Function::Regular("round"),
3504 Mf::Fract => Function::Regular("frac"),
3505 Mf::Trunc => Function::Regular("trunc"),
3506 Mf::Modf => Function::Regular(MODF_FUNCTION),
3507 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3508 Mf::Ldexp => Function::Regular("ldexp"),
3509 Mf::Exp => Function::Regular("exp"),
3511 Mf::Exp2 => Function::Regular("exp2"),
3512 Mf::Log => Function::Regular("log"),
3513 Mf::Log2 => Function::Regular("log2"),
3514 Mf::Pow => Function::Regular("pow"),
3515 Mf::Dot => Function::Regular("dot"),
3517 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3518 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3519 Mf::Cross => Function::Regular("cross"),
3521 Mf::Distance => Function::Regular("distance"),
3522 Mf::Length => Function::Regular("length"),
3523 Mf::Normalize => Function::Regular("normalize"),
3524 Mf::FaceForward => Function::Regular("faceforward"),
3525 Mf::Reflect => Function::Regular("reflect"),
3526 Mf::Refract => Function::Regular("refract"),
3527 Mf::Sign => Function::Regular("sign"),
3529 Mf::Fma => Function::Regular("mad"),
3530 Mf::Mix => Function::Regular("lerp"),
3531 Mf::Step => Function::Regular("step"),
3532 Mf::SmoothStep => Function::Regular("smoothstep"),
3533 Mf::Sqrt => Function::Regular("sqrt"),
3534 Mf::InverseSqrt => Function::Regular("rsqrt"),
3535 Mf::Transpose => Function::Regular("transpose"),
3537 Mf::Determinant => Function::Regular("determinant"),
3538 Mf::QuantizeToF16 => Function::QuantizeToF16,
3539 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3541 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3542 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3543 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3544 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3545 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3546 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3547 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3548 Mf::Pack2x16float => Function::Pack2x16float,
3550 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3551 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3552 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3553 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3554 Mf::Pack4xI8 => Function::Pack4xI8,
3555 Mf::Pack4xU8 => Function::Pack4xU8,
3556 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3557 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3558 Mf::Unpack2x16float => Function::Unpack2x16float,
3560 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3561 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3562 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3563 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3564 Mf::Unpack4xI8 => Function::Unpack4xI8,
3565 Mf::Unpack4xU8 => Function::Unpack4xU8,
3566 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3567 };
3568
3569 match fun {
3570 Function::Asincosh { is_sin } => {
3571 write!(self.out, "log(")?;
3572 self.write_expr(module, arg, func_ctx)?;
3573 write!(self.out, " + sqrt(")?;
3574 self.write_expr(module, arg, func_ctx)?;
3575 write!(self.out, " * ")?;
3576 self.write_expr(module, arg, func_ctx)?;
3577 match is_sin {
3578 true => write!(self.out, " + 1.0))")?,
3579 false => write!(self.out, " - 1.0))")?,
3580 }
3581 }
3582 Function::Atanh => {
3583 write!(self.out, "0.5 * log((1.0 + ")?;
3584 self.write_expr(module, arg, func_ctx)?;
3585 write!(self.out, ") / (1.0 - ")?;
3586 self.write_expr(module, arg, func_ctx)?;
3587 write!(self.out, "))")?;
3588 }
3589 Function::Pack2x16float => {
3590 write!(self.out, "(f32tof16(")?;
3591 self.write_expr(module, arg, func_ctx)?;
3592 write!(self.out, "[0]) | f32tof16(")?;
3593 self.write_expr(module, arg, func_ctx)?;
3594 write!(self.out, "[1]) << 16)")?;
3595 }
3596 Function::Pack2x16snorm => {
3597 let scale = 32767;
3598
3599 write!(self.out, "uint((int(round(clamp(")?;
3600 self.write_expr(module, arg, func_ctx)?;
3601 write!(
3602 self.out,
3603 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3604 )?;
3605 self.write_expr(module, arg, func_ctx)?;
3606 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3607 }
3608 Function::Pack2x16unorm => {
3609 let scale = 65535;
3610
3611 write!(self.out, "(uint(round(clamp(")?;
3612 self.write_expr(module, arg, func_ctx)?;
3613 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3614 self.write_expr(module, arg, func_ctx)?;
3615 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3616 }
3617 Function::Pack4x8snorm => {
3618 let scale = 127;
3619
3620 write!(self.out, "uint((int(round(clamp(")?;
3621 self.write_expr(module, arg, func_ctx)?;
3622 write!(
3623 self.out,
3624 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3625 )?;
3626 self.write_expr(module, arg, func_ctx)?;
3627 write!(
3628 self.out,
3629 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3630 )?;
3631 self.write_expr(module, arg, func_ctx)?;
3632 write!(
3633 self.out,
3634 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3635 )?;
3636 self.write_expr(module, arg, func_ctx)?;
3637 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3638 }
3639 Function::Pack4x8unorm => {
3640 let scale = 255;
3641
3642 write!(self.out, "(uint(round(clamp(")?;
3643 self.write_expr(module, arg, func_ctx)?;
3644 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3645 self.write_expr(module, arg, func_ctx)?;
3646 write!(
3647 self.out,
3648 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3649 )?;
3650 self.write_expr(module, arg, func_ctx)?;
3651 write!(
3652 self.out,
3653 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3654 )?;
3655 self.write_expr(module, arg, func_ctx)?;
3656 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3657 }
3658 fun @ (Function::Pack4xI8
3659 | Function::Pack4xU8
3660 | Function::Pack4xI8Clamp
3661 | Function::Pack4xU8Clamp) => {
3662 let was_signed =
3663 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3664 let clamp_bounds = match fun {
3665 Function::Pack4xI8Clamp => Some(("-128", "127")),
3666 Function::Pack4xU8Clamp => Some(("0", "255")),
3667 _ => None,
3668 };
3669 if was_signed {
3670 write!(self.out, "uint(")?;
3671 }
3672 let write_arg = |this: &mut Self| -> BackendResult {
3673 if let Some((min, max)) = clamp_bounds {
3674 write!(this.out, "clamp(")?;
3675 this.write_expr(module, arg, func_ctx)?;
3676 write!(this.out, ", {min}, {max})")?;
3677 } else {
3678 this.write_expr(module, arg, func_ctx)?;
3679 }
3680 Ok(())
3681 };
3682 write!(self.out, "(")?;
3683 write_arg(self)?;
3684 write!(self.out, "[0] & 0xFF) | ((")?;
3685 write_arg(self)?;
3686 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3687 write_arg(self)?;
3688 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3689 write_arg(self)?;
3690 write!(self.out, "[3] & 0xFF) << 24)")?;
3691 if was_signed {
3692 write!(self.out, ")")?;
3693 }
3694 }
3695
3696 Function::Unpack2x16float => {
3697 write!(self.out, "float2(f16tof32(")?;
3698 self.write_expr(module, arg, func_ctx)?;
3699 write!(self.out, "), f16tof32((")?;
3700 self.write_expr(module, arg, func_ctx)?;
3701 write!(self.out, ") >> 16))")?;
3702 }
3703 Function::Unpack2x16snorm => {
3704 let scale = 32767;
3705
3706 write!(self.out, "(float2(int2(")?;
3707 self.write_expr(module, arg, func_ctx)?;
3708 write!(self.out, " << 16, ")?;
3709 self.write_expr(module, arg, func_ctx)?;
3710 write!(self.out, ") >> 16) / {scale}.0)")?;
3711 }
3712 Function::Unpack2x16unorm => {
3713 let scale = 65535;
3714
3715 write!(self.out, "(float2(")?;
3716 self.write_expr(module, arg, func_ctx)?;
3717 write!(self.out, " & 0xFFFF, ")?;
3718 self.write_expr(module, arg, func_ctx)?;
3719 write!(self.out, " >> 16) / {scale}.0)")?;
3720 }
3721 Function::Unpack4x8snorm => {
3722 let scale = 127;
3723
3724 write!(self.out, "(float4(int4(")?;
3725 self.write_expr(module, arg, func_ctx)?;
3726 write!(self.out, " << 24, ")?;
3727 self.write_expr(module, arg, func_ctx)?;
3728 write!(self.out, " << 16, ")?;
3729 self.write_expr(module, arg, func_ctx)?;
3730 write!(self.out, " << 8, ")?;
3731 self.write_expr(module, arg, func_ctx)?;
3732 write!(self.out, ") >> 24) / {scale}.0)")?;
3733 }
3734 Function::Unpack4x8unorm => {
3735 let scale = 255;
3736
3737 write!(self.out, "(float4(")?;
3738 self.write_expr(module, arg, func_ctx)?;
3739 write!(self.out, " & 0xFF, ")?;
3740 self.write_expr(module, arg, func_ctx)?;
3741 write!(self.out, " >> 8 & 0xFF, ")?;
3742 self.write_expr(module, arg, func_ctx)?;
3743 write!(self.out, " >> 16 & 0xFF, ")?;
3744 self.write_expr(module, arg, func_ctx)?;
3745 write!(self.out, " >> 24) / {scale}.0)")?;
3746 }
3747 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3748 write!(self.out, "(")?;
3749 if matches!(fun, Function::Unpack4xU8) {
3750 write!(self.out, "u")?;
3751 }
3752 write!(self.out, "int4(")?;
3753 self.write_expr(module, arg, func_ctx)?;
3754 write!(self.out, ", ")?;
3755 self.write_expr(module, arg, func_ctx)?;
3756 write!(self.out, " >> 8, ")?;
3757 self.write_expr(module, arg, func_ctx)?;
3758 write!(self.out, " >> 16, ")?;
3759 self.write_expr(module, arg, func_ctx)?;
3760 write!(self.out, " >> 24) << 24 >> 24)")?;
3761 }
3762 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3763 let arg1 = arg1.unwrap();
3764
3765 if self.options.shader_model >= ShaderModel::V6_4 {
3766 let function_name = match fun {
3768 Function::Dot4I8Packed => "dot4add_i8packed",
3769 Function::Dot4U8Packed => "dot4add_u8packed",
3770 _ => unreachable!(),
3771 };
3772 write!(self.out, "{function_name}(")?;
3773 self.write_expr(module, arg, func_ctx)?;
3774 write!(self.out, ", ")?;
3775 self.write_expr(module, arg1, func_ctx)?;
3776 write!(self.out, ", 0)")?;
3777 } else {
3778 write!(self.out, "dot(")?;
3780
3781 if matches!(fun, Function::Dot4U8Packed) {
3782 write!(self.out, "u")?;
3783 }
3784 write!(self.out, "int4(")?;
3785 self.write_expr(module, arg, func_ctx)?;
3786 write!(self.out, ", ")?;
3787 self.write_expr(module, arg, func_ctx)?;
3788 write!(self.out, " >> 8, ")?;
3789 self.write_expr(module, arg, func_ctx)?;
3790 write!(self.out, " >> 16, ")?;
3791 self.write_expr(module, arg, func_ctx)?;
3792 write!(self.out, " >> 24) << 24 >> 24, ")?;
3793
3794 if matches!(fun, Function::Dot4U8Packed) {
3795 write!(self.out, "u")?;
3796 }
3797 write!(self.out, "int4(")?;
3798 self.write_expr(module, arg1, func_ctx)?;
3799 write!(self.out, ", ")?;
3800 self.write_expr(module, arg1, func_ctx)?;
3801 write!(self.out, " >> 8, ")?;
3802 self.write_expr(module, arg1, func_ctx)?;
3803 write!(self.out, " >> 16, ")?;
3804 self.write_expr(module, arg1, func_ctx)?;
3805 write!(self.out, " >> 24) << 24 >> 24)")?;
3806 }
3807 }
3808 Function::QuantizeToF16 => {
3809 write!(self.out, "f16tof32(f32tof16(")?;
3810 self.write_expr(module, arg, func_ctx)?;
3811 write!(self.out, "))")?;
3812 }
3813 Function::Regular(fun_name) => {
3814 write!(self.out, "{fun_name}(")?;
3815 self.write_expr(module, arg, func_ctx)?;
3816 if let Some(arg) = arg1 {
3817 write!(self.out, ", ")?;
3818 self.write_expr(module, arg, func_ctx)?;
3819 }
3820 if let Some(arg) = arg2 {
3821 write!(self.out, ", ")?;
3822 self.write_expr(module, arg, func_ctx)?;
3823 }
3824 if let Some(arg) = arg3 {
3825 write!(self.out, ", ")?;
3826 self.write_expr(module, arg, func_ctx)?;
3827 }
3828 write!(self.out, ")")?
3829 }
3830 Function::MissingIntOverload(fun_name) => {
3833 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3834 if let Some(Scalar::I32) = scalar_kind {
3835 write!(self.out, "asint({fun_name}(asuint(")?;
3836 self.write_expr(module, arg, func_ctx)?;
3837 write!(self.out, ")))")?;
3838 } else {
3839 write!(self.out, "{fun_name}(")?;
3840 self.write_expr(module, arg, func_ctx)?;
3841 write!(self.out, ")")?;
3842 }
3843 }
3844 Function::MissingIntReturnType(fun_name) => {
3847 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3848 if let Some(Scalar::I32) = scalar_kind {
3849 write!(self.out, "asint({fun_name}(")?;
3850 self.write_expr(module, arg, func_ctx)?;
3851 write!(self.out, "))")?;
3852 } else {
3853 write!(self.out, "{fun_name}(")?;
3854 self.write_expr(module, arg, func_ctx)?;
3855 write!(self.out, ")")?;
3856 }
3857 }
3858 Function::CountTrailingZeros => {
3859 match *func_ctx.resolve_type(arg, &module.types) {
3860 TypeInner::Vector { size, scalar } => {
3861 let s = match size {
3862 crate::VectorSize::Bi => ".xx",
3863 crate::VectorSize::Tri => ".xxx",
3864 crate::VectorSize::Quad => ".xxxx",
3865 };
3866
3867 let scalar_width_bits = scalar.width * 8;
3868
3869 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3870 write!(
3871 self.out,
3872 "min(({scalar_width_bits}u){s}, firstbitlow("
3873 )?;
3874 self.write_expr(module, arg, func_ctx)?;
3875 write!(self.out, "))")?;
3876 } else {
3877 write!(
3879 self.out,
3880 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3881 )?;
3882 self.write_expr(module, arg, func_ctx)?;
3883 write!(self.out, ")))")?;
3884 }
3885 }
3886 TypeInner::Scalar(scalar) => {
3887 let scalar_width_bits = scalar.width * 8;
3888
3889 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3890 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3891 self.write_expr(module, arg, func_ctx)?;
3892 write!(self.out, "))")?;
3893 } else {
3894 write!(
3896 self.out,
3897 "asint(min({scalar_width_bits}u, firstbitlow("
3898 )?;
3899 self.write_expr(module, arg, func_ctx)?;
3900 write!(self.out, ")))")?;
3901 }
3902 }
3903 _ => unreachable!(),
3904 }
3905
3906 return Ok(());
3907 }
3908 Function::CountLeadingZeros => {
3909 match *func_ctx.resolve_type(arg, &module.types) {
3910 TypeInner::Vector { size, scalar } => {
3911 let s = match size {
3912 crate::VectorSize::Bi => ".xx",
3913 crate::VectorSize::Tri => ".xxx",
3914 crate::VectorSize::Quad => ".xxxx",
3915 };
3916
3917 let constant = scalar.width * 8 - 1;
3919
3920 if scalar.kind == ScalarKind::Uint {
3921 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3922 self.write_expr(module, arg, func_ctx)?;
3923 write!(self.out, "))")?;
3924 } else {
3925 let conversion_func = match scalar.width {
3926 4 => "asint",
3927 _ => "",
3928 };
3929 write!(self.out, "(")?;
3930 self.write_expr(module, arg, func_ctx)?;
3931 write!(
3932 self.out,
3933 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3934 )?;
3935 self.write_expr(module, arg, func_ctx)?;
3936 write!(self.out, ")))")?;
3937 }
3938 }
3939 TypeInner::Scalar(scalar) => {
3940 let constant = scalar.width * 8 - 1;
3942
3943 if let ScalarKind::Uint = scalar.kind {
3944 write!(self.out, "({constant}u - firstbithigh(")?;
3945 self.write_expr(module, arg, func_ctx)?;
3946 write!(self.out, "))")?;
3947 } else {
3948 let conversion_func = match scalar.width {
3949 4 => "asint",
3950 _ => "",
3951 };
3952 write!(self.out, "(")?;
3953 self.write_expr(module, arg, func_ctx)?;
3954 write!(
3955 self.out,
3956 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
3957 )?;
3958 self.write_expr(module, arg, func_ctx)?;
3959 write!(self.out, ")))")?;
3960 }
3961 }
3962 _ => unreachable!(),
3963 }
3964
3965 return Ok(());
3966 }
3967 }
3968 }
3969 Expression::Swizzle {
3970 size,
3971 vector,
3972 pattern,
3973 } => {
3974 self.write_expr(module, vector, func_ctx)?;
3975 write!(self.out, ".")?;
3976 for &sc in pattern[..size as usize].iter() {
3977 self.out.write_char(back::COMPONENTS[sc as usize])?;
3978 }
3979 }
3980 Expression::ArrayLength(expr) => {
3981 let var_handle = match func_ctx.expressions[expr] {
3982 Expression::AccessIndex { base, index: _ } => {
3983 match func_ctx.expressions[base] {
3984 Expression::GlobalVariable(handle) => handle,
3985 _ => unreachable!(),
3986 }
3987 }
3988 Expression::GlobalVariable(handle) => handle,
3989 _ => unreachable!(),
3990 };
3991
3992 let var = &module.global_variables[var_handle];
3993 let (offset, stride) = match module.types[var.ty].inner {
3994 TypeInner::Array { stride, .. } => (0, stride),
3995 TypeInner::Struct { ref members, .. } => {
3996 let last = members.last().unwrap();
3997 let stride = match module.types[last.ty].inner {
3998 TypeInner::Array { stride, .. } => stride,
3999 _ => unreachable!(),
4000 };
4001 (last.offset, stride)
4002 }
4003 _ => unreachable!(),
4004 };
4005
4006 let storage_access = match var.space {
4007 crate::AddressSpace::Storage { access } => access,
4008 _ => crate::StorageAccess::default(),
4009 };
4010 let wrapped_array_length = WrappedArrayLength {
4011 writable: storage_access.contains(crate::StorageAccess::STORE),
4012 };
4013
4014 write!(self.out, "((")?;
4015 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4016 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4017 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4018 }
4019 Expression::Derivative { axis, ctrl, expr } => {
4020 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4021 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4022 let tail = match ctrl {
4023 Ctrl::Coarse => "coarse",
4024 Ctrl::Fine => "fine",
4025 Ctrl::None => unreachable!(),
4026 };
4027 write!(self.out, "abs(ddx_{tail}(")?;
4028 self.write_expr(module, expr, func_ctx)?;
4029 write!(self.out, ")) + abs(ddy_{tail}(")?;
4030 self.write_expr(module, expr, func_ctx)?;
4031 write!(self.out, "))")?
4032 } else {
4033 let fun_str = match (axis, ctrl) {
4034 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4035 (Axis::X, Ctrl::Fine) => "ddx_fine",
4036 (Axis::X, Ctrl::None) => "ddx",
4037 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4038 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4039 (Axis::Y, Ctrl::None) => "ddy",
4040 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4041 (Axis::Width, Ctrl::None) => "fwidth",
4042 };
4043 write!(self.out, "{fun_str}(")?;
4044 self.write_expr(module, expr, func_ctx)?;
4045 write!(self.out, ")")?
4046 }
4047 }
4048 Expression::Relational { fun, argument } => {
4049 use crate::RelationalFunction as Rf;
4050
4051 let fun_str = match fun {
4052 Rf::All => "all",
4053 Rf::Any => "any",
4054 Rf::IsNan => "isnan",
4055 Rf::IsInf => "isinf",
4056 };
4057 write!(self.out, "{fun_str}(")?;
4058 self.write_expr(module, argument, func_ctx)?;
4059 write!(self.out, ")")?
4060 }
4061 Expression::Select {
4062 condition,
4063 accept,
4064 reject,
4065 } => {
4066 write!(self.out, "(")?;
4067 self.write_expr(module, condition, func_ctx)?;
4068 write!(self.out, " ? ")?;
4069 self.write_expr(module, accept, func_ctx)?;
4070 write!(self.out, " : ")?;
4071 self.write_expr(module, reject, func_ctx)?;
4072 write!(self.out, ")")?
4073 }
4074 Expression::RayQueryGetIntersection { query, committed } => {
4075 if committed {
4076 write!(self.out, "GetCommittedIntersection(")?;
4077 self.write_expr(module, query, func_ctx)?;
4078 write!(self.out, ")")?;
4079 } else {
4080 write!(self.out, "GetCandidateIntersection(")?;
4081 self.write_expr(module, query, func_ctx)?;
4082 write!(self.out, ")")?;
4083 }
4084 }
4085 Expression::RayQueryVertexPositions { .. } => unreachable!(),
4087 Expression::CallResult(_)
4089 | Expression::AtomicResult { .. }
4090 | Expression::WorkGroupUniformLoadResult { .. }
4091 | Expression::RayQueryProceedResult
4092 | Expression::SubgroupBallotResult
4093 | Expression::SubgroupOperationResult { .. } => {}
4094 }
4095
4096 if !closing_bracket.is_empty() {
4097 write!(self.out, "{closing_bracket}")?;
4098 }
4099 Ok(())
4100 }
4101
4102 #[allow(clippy::too_many_arguments)]
4103 fn write_image_load(
4104 &mut self,
4105 module: &&Module,
4106 expr: Handle<crate::Expression>,
4107 func_ctx: &back::FunctionCtx,
4108 image: Handle<crate::Expression>,
4109 coordinate: Handle<crate::Expression>,
4110 array_index: Option<Handle<crate::Expression>>,
4111 sample: Option<Handle<crate::Expression>>,
4112 level: Option<Handle<crate::Expression>>,
4113 ) -> Result<(), Error> {
4114 let mut wrapping_type = None;
4115 match *func_ctx.resolve_type(image, &module.types) {
4116 TypeInner::Image {
4117 class: crate::ImageClass::Storage { format, .. },
4118 ..
4119 } => {
4120 if format.single_component() {
4121 wrapping_type = Some(Scalar::from(format));
4122 }
4123 }
4124 _ => {}
4125 }
4126 if let Some(scalar) = wrapping_type {
4127 write!(
4128 self.out,
4129 "{}{}(",
4130 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4131 scalar.to_hlsl_str()?
4132 )?;
4133 }
4134 self.write_expr(module, image, func_ctx)?;
4136 write!(self.out, ".Load(")?;
4137
4138 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4139
4140 if let Some(sample) = sample {
4141 write!(self.out, ", ")?;
4142 self.write_expr(module, sample, func_ctx)?;
4143 }
4144
4145 write!(self.out, ")")?;
4147
4148 if wrapping_type.is_some() {
4149 write!(self.out, ")")?;
4150 }
4151
4152 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4154 write!(self.out, ".x")?;
4155 }
4156 Ok(())
4157 }
4158
4159 fn sampler_binding_array_info_from_expression(
4162 &mut self,
4163 module: &Module,
4164 func_ctx: &back::FunctionCtx<'_>,
4165 base: Handle<crate::Expression>,
4166 resolved: &TypeInner,
4167 ) -> Option<BindingArraySamplerInfo> {
4168 if let TypeInner::BindingArray {
4169 base: base_ty_handle,
4170 ..
4171 } = *resolved
4172 {
4173 let base_ty = &module.types[base_ty_handle].inner;
4174 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4175 let base = &func_ctx.expressions[base];
4176
4177 if let crate::Expression::GlobalVariable(handle) = *base {
4178 let variable = &module.global_variables[handle];
4179
4180 let sampler_heap_name = match comparison {
4181 true => COMPARISON_SAMPLER_HEAP_VAR,
4182 false => SAMPLER_HEAP_VAR,
4183 };
4184
4185 return Some(BindingArraySamplerInfo {
4186 sampler_heap_name,
4187 sampler_index_buffer_name: self
4188 .wrapped
4189 .sampler_index_buffers
4190 .get(&super::SamplerIndexBufferKey {
4191 group: variable.binding.unwrap().group,
4192 })
4193 .unwrap()
4194 .clone(),
4195 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4196 .clone(),
4197 });
4198 }
4199 }
4200 }
4201
4202 None
4203 }
4204
4205 fn write_named_expr(
4206 &mut self,
4207 module: &Module,
4208 handle: Handle<crate::Expression>,
4209 name: String,
4210 named: Handle<crate::Expression>,
4213 ctx: &back::FunctionCtx,
4214 ) -> BackendResult {
4215 match ctx.info[named].ty {
4216 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4217 TypeInner::Struct { .. } => {
4218 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4219 write!(self.out, "{ty_name}")?;
4220 }
4221 _ => {
4222 self.write_type(module, ty_handle)?;
4223 }
4224 },
4225 proc::TypeResolution::Value(ref inner) => {
4226 self.write_value_type(module, inner)?;
4227 }
4228 }
4229
4230 let resolved = ctx.resolve_type(named, &module.types);
4231
4232 write!(self.out, " {name}")?;
4233 if let TypeInner::Array { base, size, .. } = *resolved {
4235 self.write_array_size(module, base, size)?;
4236 }
4237 write!(self.out, " = ")?;
4238 self.write_expr(module, handle, ctx)?;
4239 writeln!(self.out, ";")?;
4240 self.named_expressions.insert(named, name);
4241
4242 Ok(())
4243 }
4244
4245 pub(super) fn write_default_init(
4247 &mut self,
4248 module: &Module,
4249 ty: Handle<crate::Type>,
4250 ) -> BackendResult {
4251 write!(self.out, "(")?;
4252 self.write_type(module, ty)?;
4253 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4254 self.write_array_size(module, base, size)?;
4255 }
4256 write!(self.out, ")0")?;
4257 Ok(())
4258 }
4259
4260 fn write_control_barrier(
4261 &mut self,
4262 barrier: crate::Barrier,
4263 level: back::Level,
4264 ) -> BackendResult {
4265 if barrier.contains(crate::Barrier::STORAGE) {
4266 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4267 }
4268 if barrier.contains(crate::Barrier::WORK_GROUP) {
4269 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4270 }
4271 if barrier.contains(crate::Barrier::SUB_GROUP) {
4272 }
4274 if barrier.contains(crate::Barrier::TEXTURE) {
4275 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4276 }
4277 Ok(())
4278 }
4279
4280 fn write_memory_barrier(
4281 &mut self,
4282 barrier: crate::Barrier,
4283 level: back::Level,
4284 ) -> BackendResult {
4285 if barrier.contains(crate::Barrier::STORAGE) {
4286 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4287 }
4288 if barrier.contains(crate::Barrier::WORK_GROUP) {
4289 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4290 }
4291 if barrier.contains(crate::Barrier::SUB_GROUP) {
4292 }
4294 if barrier.contains(crate::Barrier::TEXTURE) {
4295 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4296 }
4297 Ok(())
4298 }
4299
4300 fn emit_hlsl_atomic_tail(
4302 &mut self,
4303 module: &Module,
4304 func_ctx: &back::FunctionCtx<'_>,
4305 fun: &crate::AtomicFunction,
4306 compare_expr: Option<Handle<crate::Expression>>,
4307 value: Handle<crate::Expression>,
4308 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4309 ) -> BackendResult {
4310 if let Some(cmp) = compare_expr {
4311 write!(self.out, ", ")?;
4312 self.write_expr(module, cmp, func_ctx)?;
4313 }
4314 write!(self.out, ", ")?;
4315 if let crate::AtomicFunction::Subtract = *fun {
4316 write!(self.out, "-")?;
4318 }
4319 self.write_expr(module, value, func_ctx)?;
4320 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4321 write!(self.out, ", ")?;
4322 if compare_expr.is_some() {
4323 write!(self.out, "{res_name}.old_value")?;
4324 } else {
4325 write!(self.out, "{res_name}")?;
4326 }
4327 }
4328 writeln!(self.out, ");")?;
4329 Ok(())
4330 }
4331}
4332
/// Shape and component width of a matrix type, as extracted from a
/// `TypeInner::Matrix` by the lookup helpers in this module.
pub(super) struct MatrixType {
    /// Number of columns of the matrix.
    pub(super) columns: crate::VectorSize,
    /// Number of rows of the matrix.
    pub(super) rows: crate::VectorSize,
    /// Width in bytes of each scalar component (copied from the matrix's `scalar.width`).
    pub(super) width: crate::Bytes,
}
4338
4339pub(super) fn get_inner_matrix_data(
4340 module: &Module,
4341 handle: Handle<crate::Type>,
4342) -> Option<MatrixType> {
4343 match module.types[handle].inner {
4344 TypeInner::Matrix {
4345 columns,
4346 rows,
4347 scalar,
4348 } => Some(MatrixType {
4349 columns,
4350 rows,
4351 width: scalar.width,
4352 }),
4353 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4354 _ => None,
4355 }
4356}
4357
4358fn find_matrix_in_access_chain(
4362 module: &Module,
4363 base: Handle<crate::Expression>,
4364 func_ctx: &back::FunctionCtx<'_>,
4365) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4366 let mut current_base = base;
4367 let mut vector = None;
4368 let mut scalar = None;
4369 loop {
4370 let resolved_tr = func_ctx
4371 .resolve_type(current_base, &module.types)
4372 .pointer_base_type();
4373 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4374
4375 match *resolved {
4376 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4377 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4378 _ => return None,
4379 }
4380
4381 let index;
4382 (current_base, index) = match func_ctx.expressions[current_base] {
4383 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4384 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4385 _ => return None,
4386 };
4387
4388 match *resolved {
4389 TypeInner::Scalar(_) => scalar = Some(index),
4390 TypeInner::Vector { .. } => vector = Some(index),
4391 _ => unreachable!(),
4392 }
4393 }
4394}
4395
4396pub(super) fn get_inner_matrix_of_struct_array_member(
4401 module: &Module,
4402 base: Handle<crate::Expression>,
4403 func_ctx: &back::FunctionCtx<'_>,
4404 direct: bool,
4405) -> Option<MatrixType> {
4406 let mut mat_data = None;
4407 let mut array_base = None;
4408
4409 let mut current_base = base;
4410 loop {
4411 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4412 if let TypeInner::Pointer { base, .. } = *resolved {
4413 resolved = &module.types[base].inner;
4414 };
4415
4416 match *resolved {
4417 TypeInner::Matrix {
4418 columns,
4419 rows,
4420 scalar,
4421 } => {
4422 mat_data = Some(MatrixType {
4423 columns,
4424 rows,
4425 width: scalar.width,
4426 })
4427 }
4428 TypeInner::Array { base, .. } => {
4429 array_base = Some(base);
4430 }
4431 TypeInner::Struct { .. } => {
4432 if let Some(array_base) = array_base {
4433 if direct {
4434 return mat_data;
4435 } else {
4436 return get_inner_matrix_data(module, array_base);
4437 }
4438 }
4439
4440 break;
4441 }
4442 _ => break,
4443 }
4444
4445 current_base = match func_ctx.expressions[current_base] {
4446 crate::Expression::Access { base, .. } => base,
4447 crate::Expression::AccessIndex { base, .. } => base,
4448 _ => break,
4449 };
4450 }
4451 None
4452}
4453
4454fn get_global_uniform_matrix(
4457 module: &Module,
4458 base: Handle<crate::Expression>,
4459 func_ctx: &back::FunctionCtx<'_>,
4460) -> Option<MatrixType> {
4461 let base_tr = func_ctx
4462 .resolve_type(base, &module.types)
4463 .pointer_base_type();
4464 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4465 match (&func_ctx.expressions[base], base_ty) {
4466 (
4467 &crate::Expression::GlobalVariable(handle),
4468 Some(&TypeInner::Matrix {
4469 columns,
4470 rows,
4471 scalar,
4472 }),
4473 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4474 Some(MatrixType {
4475 columns,
4476 rows,
4477 width: scalar.width,
4478 })
4479 }
4480 _ => None,
4481 }
4482}
4483
4484fn get_inner_matrix_of_global_uniform(
4489 module: &Module,
4490 base: Handle<crate::Expression>,
4491 func_ctx: &back::FunctionCtx<'_>,
4492) -> Option<MatrixType> {
4493 let mut mat_data = None;
4494 let mut array_base = None;
4495
4496 let mut current_base = base;
4497 loop {
4498 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4499 if let TypeInner::Pointer { base, .. } = *resolved {
4500 resolved = &module.types[base].inner;
4501 };
4502
4503 match *resolved {
4504 TypeInner::Matrix {
4505 columns,
4506 rows,
4507 scalar,
4508 } => {
4509 mat_data = Some(MatrixType {
4510 columns,
4511 rows,
4512 width: scalar.width,
4513 })
4514 }
4515 TypeInner::Array { base, .. } => {
4516 array_base = Some(base);
4517 }
4518 _ => break,
4519 }
4520
4521 current_base = match func_ctx.expressions[current_base] {
4522 crate::Expression::Access { base, .. } => base,
4523 crate::Expression::AccessIndex { base, .. } => base,
4524 crate::Expression::GlobalVariable(handle)
4525 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4526 {
4527 return mat_data.or_else(|| {
4528 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4529 })
4530 }
4531 _ => break,
4532 };
4533 }
4534 None
4535}