1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Write as _},
8 mem,
9};
10
11use super::{
12 help,
13 help::{
14 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15 WrappedZeroValue,
16 },
17 storage::StoreValue,
18 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21 back::{self, get_entry_points, Baked},
22 common,
23 proc::{self, index, ExternalTextureNameKey, NameKey},
24 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Prefix of the HLSL semantic given to user `@location(N)` interface members
/// that are not fragment outputs; emitted as `LOC{N}`.
const LOCATION_SEMANTIC: &str = "LOC";
/// Struct type name of the special constant buffer written when
/// `Options::special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of that special constant buffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` member of the special constant buffer.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` member of the special constant buffer.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` member of the special constant buffer.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and variables that may be emitted into the HLSL
// output. NOTE(review): most use sites are outside this section; what each
// helper does is inferred from its name only.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
/// Heap array that non-comparison sampler globals are looked up in
/// (see `write_global_sampler`).
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Heap array that comparison sampler globals are looked up in.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Prefix for per-ray-query tracking variables (presumably recording whether a
/// query was initialized — TODO confirm at the use site).
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Common prefix shared by many of the internal helper names above.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index into an array-like value: either a runtime expression or a fixed
/// `u32` value. NOTE(review): use sites are outside this section.
enum Index {
    /// Index computed by a Naga expression.
    Expression(Handle<crate::Expression>),
    /// Index known as a plain number.
    Static(u32),
}
60
/// One member of an entry point's synthesized HLSL interface struct.
pub(super) struct EpStructMember {
    /// Member name inside the interface struct (generated through the namer).
    pub(super) name: String,
    /// Naga type of the member.
    pub(super) ty: Handle<crate::Type>,
    /// Interface binding; `write_interface_struct` `debug_assert!`s that this
    /// is `Some`.
    pub(super) binding: Option<crate::Binding>,
    /// Position of the member in the original (flattened) argument or
    /// result-member order. Input structs are re-sorted by this after being
    /// written, so their members can be consumed positionally.
    pub(super) index: u32,
}
69
/// Describes an interface struct written by `write_interface_struct`.
pub(super) struct EntryPointBinding {
    /// Name of the struct-typed value; members are accessed as
    /// `{arg_name}.{member}`.
    pub(super) arg_name: String,
    /// Name of the emitted HLSL struct type.
    pub(super) ty_name: String,
    /// The struct's members. For `Io::Input` they are in original order;
    /// otherwise they stay in the binding order used when writing the struct.
    pub(super) members: Vec<EpStructMember>,
    /// Name of the member holding the local invocation index, if any.
    /// This is either a `LocalInvocationIndex` member or an extra
    /// `SV_GroupIndex` member added because `SubgroupId` needs it.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// Interface structs written for one entry point by `write_ep_interface`.
pub(super) struct EntryPointInterface {
    /// Input struct, or `None` when no input struct was needed.
    /// `write_ep_arguments_initialization` moves it out with `take()`.
    pub(crate) input: Option<EntryPointBinding>,
    /// Output struct; only written for vertex entry points whose result has
    /// no binding of its own.
    pub(crate) output: Option<EntryPointBinding>,
    /// Mesh output descriptions; `Some` only for entry points with mesh info.
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100 Location(u32),
101 BuiltIn(crate::BuiltIn),
102 Other,
103}
104
105impl InterfaceKey {
106 const fn new(binding: Option<&crate::Binding>) -> Self {
107 match binding {
108 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110 None => Self::Other,
111 }
112 }
113}
114
/// Which side of an entry-point interface a struct describes. Used to choose
/// semantics (e.g. fragment outputs get `SV_Target`) and member ordering.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    /// Entry point inputs.
    Input,
    /// Entry point results.
    Output,
    /// Mesh shader vertex outputs (presumably per-vertex data — TODO confirm).
    MeshVertices,
    /// Mesh shader primitive outputs (presumably per-primitive data — TODO confirm).
    MeshPrimitives,
}
122
/// Names of the values forwarded to a nested entry-point function.
/// NOTE(review): use sites are outside this section; field descriptions are
/// inferred from the field names.
pub(super) struct NestedEntryPointArgs {
    /// Names of the user-declared arguments.
    pub user_args: Vec<String>,
    /// Name of the task payload variable, if there is one.
    pub task_payload: Option<String>,
    /// Name of the local invocation index value.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133 return false;
134 };
135 matches!(
136 builtin,
137 crate::BuiltIn::SubgroupSize
138 | crate::BuiltIn::SubgroupInvocationId
139 | crate::BuiltIn::NumSubgroups
140 | crate::BuiltIn::SubgroupId
141 )
142}
143
/// Names needed to index into a binding array of samplers.
/// NOTE(review): use sites are outside this section; descriptions are
/// inferred from the field names.
struct BindingArraySamplerInfo {
    /// Heap variable to index; presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`.
    sampler_heap_name: &'static str,
    /// Name of the sampler index buffer for the array's bind group.
    sampler_index_buffer_name: String,
    /// Name of the constant holding the array's base index. Sampler binding
    /// arrays are declared as `static const uint {name} = {register};`
    /// by `write_global_sampler`.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156 Self {
157 out,
158 names: crate::FastHashMap::default(),
159 namer: proc::Namer::default(),
160 options,
161 pipeline_options,
162 entry_point_io: crate::FastHashMap::default(),
163 named_expressions: crate::NamedExpressions::default(),
164 wrapped: super::Wrapped::default(),
165 written_committed_intersection: false,
166 written_candidate_intersection: false,
167 continue_ctx: back::continue_forward::ContinueCtx::default(),
168 temp_access_chain: Vec::new(),
169 need_bake_expressions: Default::default(),
170 function_task_payload_var: Default::default(),
171 }
172 }
173
174 fn reset(&mut self, module: &Module) {
175 self.names.clear();
176 self.namer.reset(
177 module,
178 &super::keywords::RESERVED_SET,
179 proc::KeywordSet::empty(),
180 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181 super::keywords::RESERVED_PREFIXES,
182 &mut self.names,
183 );
184 self.entry_point_io.clear();
185 self.named_expressions.clear();
186 self.wrapped.clear();
187 self.written_committed_intersection = false;
188 self.written_candidate_intersection = false;
189 self.continue_ctx.clear();
190 self.need_bake_expressions.clear();
191 self.function_task_payload_var.clear();
192 }
193
194 fn gen_force_bounded_loop_statements(
202 &mut self,
203 level: back::Level,
204 ) -> Option<(String, String)> {
205 if !self.options.force_loop_bounding {
206 return None;
207 }
208
209 let loop_bound_name = self.namer.call("loop_bound");
210 let max = u32::MAX;
211 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214 let level = level.next();
215 let break_and_inc = format!(
216 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218 );
219
220 Some((decl, break_and_inc))
221 }
222
223 fn update_expressions_to_bake(
228 &mut self,
229 module: &Module,
230 func: &crate::Function,
231 info: &valid::FunctionInfo,
232 ) {
233 use crate::Expression;
234 self.need_bake_expressions.clear();
235 for (exp_handle, expr) in func.expressions.iter() {
236 let expr_info = &info[exp_handle];
237 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238 if min_ref_count <= expr_info.ref_count {
239 self.need_bake_expressions.insert(exp_handle);
240 }
241 if let Expression::Load { pointer } = *expr {
242 if info[pointer]
243 .ty
244 .inner_with(&module.types)
245 .is_atomic_pointer(&module.types)
246 {
247 self.need_bake_expressions.insert(exp_handle);
248 }
249 }
250
251 if let Expression::Math { fun, arg, arg1, .. } = *expr {
252 match fun {
253 crate::MathFunction::Asinh
254 | crate::MathFunction::Acosh
255 | crate::MathFunction::Atanh
256 | crate::MathFunction::Unpack2x16float
257 | crate::MathFunction::Unpack2x16snorm
258 | crate::MathFunction::Unpack2x16unorm
259 | crate::MathFunction::Unpack4x8snorm
260 | crate::MathFunction::Unpack4x8unorm
261 | crate::MathFunction::Unpack4xI8
262 | crate::MathFunction::Unpack4xU8
263 | crate::MathFunction::Pack2x16float
264 | crate::MathFunction::Pack2x16snorm
265 | crate::MathFunction::Pack2x16unorm
266 | crate::MathFunction::Pack4x8snorm
267 | crate::MathFunction::Pack4x8unorm
268 | crate::MathFunction::Pack4xI8
269 | crate::MathFunction::Pack4xU8
270 | crate::MathFunction::Pack4xI8Clamp
271 | crate::MathFunction::Pack4xU8Clamp => {
272 self.need_bake_expressions.insert(arg);
273 }
274 crate::MathFunction::CountLeadingZeros => {
275 let inner = info[exp_handle].ty.inner_with(&module.types);
276 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277 self.need_bake_expressions.insert(arg);
278 }
279 }
280 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281 self.need_bake_expressions.insert(arg);
282 self.need_bake_expressions.insert(arg1.unwrap());
283 }
284 _ => {}
285 }
286 }
287
288 if let Expression::Derivative { axis, ctrl, expr } = *expr {
289 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291 self.need_bake_expressions.insert(expr);
292 }
293 }
294
295 if let Expression::GlobalVariable(_) = *expr {
296 let inner = info[exp_handle].ty.inner_with(&module.types);
297
298 if let TypeInner::Sampler { .. } = *inner {
299 self.need_bake_expressions.insert(exp_handle);
300 }
301 }
302 }
303 for statement in func.body.iter() {
304 match *statement {
305 crate::Statement::SubgroupCollectiveOperation {
306 op: _,
307 collective_op: crate::CollectiveOperation::InclusiveScan,
308 argument,
309 result: _,
310 } => {
311 self.need_bake_expressions.insert(argument);
312 }
313 crate::Statement::Atomic {
314 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315 ..
316 } => {
317 self.need_bake_expressions.insert(cmp);
318 }
319 _ => {}
320 }
321 }
322 }
323
    /// Writes the complete HLSL translation of `module` into `self.out`.
    ///
    /// Output order: the optional special-constants buffer, the dynamic
    /// storage-buffer offset buffers, mat Cx2 helpers, structs, special and
    /// wrapped helper functions, named constants, globals, entry-point
    /// interface structs, regular functions, and finally the entry points
    /// selected by `pipeline_options.entry_point`.
    ///
    /// Unless `Options::fake_missing_bindings` is set, functions that use a
    /// global whose binding cannot be resolved are skipped. Entry points in the
    /// same situation get an `Err` in the returned `entry_point_names` instead
    /// of being written.
    ///
    /// # Errors
    /// - `Error::ShaderModelTooLow` if the module uses mesh shaders below SM 6.5.
    /// - `Error::EntryPointNotFound` if the selected entry point does not exist.
    /// - Any error from the helper writers.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Special constant buffer (first_vertex / first_instance / other).
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // `space0` is the default and is omitted.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Blank line for readability.
            writeln!(self.out)?;
        }

        // One constant buffer of `uint` offsets per configured bind group.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Blank line for readability.
            writeln!(self.out)?;
        }

        // Record each entry point's stage and result, so structs used as
        // entry-point results can be written with output semantics.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write every struct type.
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // Structs ending in a dynamically-sized member are skipped.
                // Storage buffers are declared as `ByteAddressBuffer` (see
                // `write_global`), which is presumably the only place such
                // structs occur.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    continue;
                }

                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write named constants only.
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Blank line after the last constant.
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write every global variable.
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Blank line for readability.
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write the interface structs of the selected entry points and
        // remember them for when the entry points themselves are written.
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write regular functions.
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Skip functions that use a bound global whose binding resolves
            // neither as a regular resource nor as an external texture.
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write the selected entry points.
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Same accessibility check as for regular functions. An
            // inaccessible entry point is reported in the reflection info
            // instead of being written.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Attributes placed before the entry point's signature.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                let topology_str = match info.topology {
                    // NOTE(review): assumes point topology is rejected before
                    // reaching this backend — TODO confirm.
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // NOTE(review): this compares against the module's last entry
            // point, not the last one selected. When a single non-last entry
            // point is selected, a trailing blank line is still written.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
578
579 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580 match *binding {
581 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582 write!(self.out, "precise ")?;
583 }
584 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585 write!(self.out, "noperspective ")?;
586 }
587 crate::Binding::Location {
588 interpolation,
589 sampling,
590 ..
591 } => {
592 if let Some(interpolation) = interpolation {
593 if let Some(string) = interpolation.to_hlsl_str() {
594 write!(self.out, "{string} ")?
595 }
596 }
597
598 if let Some(sampling) = sampling {
599 if let Some(string) = sampling.to_hlsl_str() {
600 write!(self.out, "{string} ")?
601 }
602 }
603 }
604 crate::Binding::BuiltIn(_) => {}
605 }
606
607 Ok(())
608 }
609
610 pub(super) fn write_semantic(
613 &mut self,
614 binding: &Option<crate::Binding>,
615 stage: Option<(ShaderStage, Io)>,
616 ) -> BackendResult {
617 let is_per_primitive = match *binding {
618 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619 if builtin == crate::BuiltIn::ViewIndex
620 && self.options.shader_model < ShaderModel::V6_1
621 {
622 return Err(Error::ShaderModelTooLow(
623 "used @builtin(view_index) or SV_ViewID".to_string(),
624 ShaderModel::V6_1,
625 ));
626 }
627 if let Some(builtin_str) = builtin.to_hlsl_str()? {
628 write!(self.out, " : {builtin_str}")?;
629 }
630 false
631 }
632 Some(crate::Binding::Location {
633 blend_src: Some(1),
634 per_primitive,
635 ..
636 }) => {
637 write!(self.out, " : SV_Target1")?;
638 per_primitive
639 }
640 Some(crate::Binding::Location {
641 location,
642 per_primitive,
643 ..
644 }) => {
645 if stage == Some((ShaderStage::Fragment, Io::Output)) {
646 write!(self.out, " : SV_Target{location}")?;
647 } else {
648 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649 }
650 per_primitive
651 }
652 _ => false,
653 };
654 if is_per_primitive {
655 write!(self.out, " : primitive")?;
656 }
657
658 Ok(())
659 }
660
661 pub(super) fn write_interface_struct(
662 &mut self,
663 module: &Module,
664 shader_stage: (ShaderStage, Io),
665 struct_name: String,
666 var_name: Option<&str>,
667 mut members: Vec<EpStructMember>,
668 ) -> Result<EntryPointBinding, Error> {
669 let struct_name = self.namer.call(&struct_name);
670 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675 write!(self.out, "struct {struct_name}")?;
676 writeln!(self.out, " {{")?;
677 let mut local_invocation_index_name = None;
678 let mut subgroup_id_used = false;
679 for m in members.iter() {
680 debug_assert!(m.binding.is_some());
683
684 match m.binding {
685 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686 subgroup_id_used = true;
687 }
688 Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689 local_invocation_index_name = Some(m.name.clone());
690 }
691 _ => (),
692 }
693
694 if is_subgroup_builtin_binding(&m.binding) {
695 continue;
696 }
697 write!(self.out, "{}", back::INDENT)?;
698 if let Some(ref binding) = m.binding {
699 self.write_modifier(binding)?;
700 }
701 self.write_type(module, m.ty)?;
702 write!(self.out, " {}", &m.name)?;
703 self.write_semantic(&m.binding, Some(shader_stage))?;
704 writeln!(self.out, ";")?;
705 }
706 if subgroup_id_used && local_invocation_index_name.is_none() {
707 let name = self.namer.call("local_invocation_index");
708 writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709 local_invocation_index_name = Some(name);
710 }
711 writeln!(self.out, "}};")?;
712 writeln!(self.out)?;
713
714 match shader_stage.1 {
716 Io::Input => {
717 members.sort_by_key(|m| m.index);
719 }
720 Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721 }
723 }
724
725 Ok(EntryPointBinding {
726 arg_name: self
727 .namer
728 .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729 ty_name: struct_name,
730 members,
731 local_invocation_index_name,
732 })
733 }
734
735 fn write_ep_input_struct(
739 &mut self,
740 module: &Module,
741 func: &crate::Function,
742 stage: ShaderStage,
743 entry_point_name: &str,
744 ) -> Result<EntryPointBinding, Error> {
745 let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747 let mut fake_members = Vec::new();
748 for arg in func.arguments.iter() {
749 match module.types[arg.ty].inner {
754 TypeInner::Struct { ref members, .. } => {
755 for member in members.iter() {
756 let name = self.namer.call_or(&member.name, "member");
757 let index = fake_members.len() as u32;
758 fake_members.push(EpStructMember {
759 name,
760 ty: member.ty,
761 binding: member.binding.clone(),
762 index,
763 });
764 }
765 }
766 _ => {
767 let member_name = self.namer.call_or(&arg.name, "member");
768 let index = fake_members.len() as u32;
769 fake_members.push(EpStructMember {
770 name: member_name,
771 ty: arg.ty,
772 binding: arg.binding.clone(),
773 index,
774 });
775 }
776 }
777 }
778
779 self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780 }
781
782 fn write_ep_output_struct(
786 &mut self,
787 module: &Module,
788 result: &crate::FunctionResult,
789 stage: ShaderStage,
790 entry_point_name: &str,
791 frag_ep: Option<&FragmentEntryPoint<'_>>,
792 ) -> Result<EntryPointBinding, Error> {
793 let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795 let empty = [];
796 let members = match module.types[result.ty].inner {
797 TypeInner::Struct { ref members, .. } => members,
798 ref other => {
799 log::error!("Unexpected {other:?} output type without a binding");
800 &empty[..]
801 }
802 };
803
804 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809 let mut fs_input_locs = Vec::new();
810 for arg in frag_ep.func.arguments.iter() {
811 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813 Some(crate::Binding::BuiltIn(_)) | None => {}
814 };
815
816 match frag_ep.module.types[arg.ty].inner {
819 TypeInner::Struct { ref members, .. } => {
820 for member in members.iter() {
821 push_if_location(&member.binding);
822 }
823 }
824 _ => push_if_location(&arg.binding),
825 }
826 }
827 fs_input_locs.sort();
828 Some(fs_input_locs)
829 } else {
830 None
831 };
832
833 let mut fake_members = Vec::new();
834 for (index, member) in members.iter().enumerate() {
835 if let Some(ref fs_input_locs) = fs_input_locs {
836 match member.binding {
837 Some(crate::Binding::Location { location, .. }) => {
838 if fs_input_locs.binary_search(&location).is_err() {
839 continue;
840 }
841 }
842 Some(crate::Binding::BuiltIn(_)) | None => {}
843 }
844 }
845
846 let member_name = self.namer.call_or(&member.name, "member");
847 fake_members.push(EpStructMember {
848 name: member_name,
849 ty: member.ty,
850 binding: member.binding.clone(),
851 index: index as u32,
852 });
853 }
854
855 self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856 }
857
858 fn write_ep_interface(
862 &mut self,
863 module: &Module,
864 ep: &crate::EntryPoint,
865 ep_name: &str,
866 frag_ep: Option<&FragmentEntryPoint<'_>>,
867 ) -> Result<EntryPointInterface, Error> {
868 let func = &ep.function;
869 let stage = ep.stage;
870 Ok(EntryPointInterface {
871 input: if !func.arguments.is_empty()
872 && (stage == ShaderStage::Fragment
873 || func
874 .arguments
875 .iter()
876 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877 {
878 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879 } else {
880 None
881 },
882 output: match func.result {
883 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885 }
886 _ => None,
887 },
888 mesh_vertices: if let Some(ref info) = ep.mesh_info {
889 Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890 } else {
891 None
892 },
893 mesh_primitives: if let Some(ref info) = ep.mesh_info {
894 Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895 } else {
896 None
897 },
898 mesh_indices: if let Some(ref info) = ep.mesh_info {
899 Some(self.write_ep_mesh_output_indices(info.topology)?)
900 } else {
901 None
902 },
903 })
904 }
905
906 fn write_ep_argument_initialization(
907 &mut self,
908 ep: &crate::EntryPoint,
909 ep_input: &EntryPointBinding,
910 fake_member: &EpStructMember,
911 ) -> BackendResult {
912 match fake_member.binding {
913 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914 write!(self.out, "WaveGetLaneCount()")?
915 }
916 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917 write!(self.out, "WaveGetLaneIndex()")?
918 }
919 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920 self.out,
921 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923 )?,
924 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925 write!(
926 self.out,
927 "{}.{} / WaveGetLaneCount()",
928 ep_input.arg_name,
929 ep_input.local_invocation_index_name.as_ref().unwrap()
931 )?;
932 }
933 Some(crate::Binding::Location {
934 interpolation: Some(crate::Interpolation::PerVertex),
935 ..
936 }) => {
937 if self.options.shader_model < ShaderModel::V6_1 {
938 return Err(Error::ShaderModelTooLow(
939 "per_vertex fragment inputs".to_string(),
940 ShaderModel::V6_1,
941 ));
942 }
943 write!(
944 self.out,
945 "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946 ep_input.arg_name,
947 fake_member.name,
948 )?;
949 }
950 _ => {
951 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952 }
953 }
954 Ok(())
955 }
956
957 fn write_ep_arguments_initialization(
959 &mut self,
960 module: &Module,
961 func: &crate::Function,
962 ep_index: u16,
963 ) -> BackendResult {
964 let ep = &module.entry_points[ep_index as usize];
965 let ep_input = match self
966 .entry_point_io
967 .get_mut(&(ep_index as usize))
968 .unwrap()
969 .input
970 .take()
971 {
972 Some(ep_input) => ep_input,
973 None => return Ok(()),
974 };
975 let mut fake_iter = ep_input.members.iter();
976 for (arg_index, arg) in func.arguments.iter().enumerate() {
977 write!(self.out, "{}", back::INDENT)?;
978 self.write_type(module, arg.ty)?;
979 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980 write!(self.out, " {arg_name}")?;
981 match module.types[arg.ty].inner {
982 TypeInner::Array { base, size, .. } => {
983 self.write_array_size(module, base, size)?;
984 write!(self.out, " = ")?;
985 self.write_ep_argument_initialization(
986 ep,
987 &ep_input,
988 fake_iter.next().unwrap(),
989 )?;
990 writeln!(self.out, ";")?;
991 }
992 TypeInner::Struct { ref members, .. } => {
993 write!(self.out, " = {{ ")?;
994 for index in 0..members.len() {
995 if index != 0 {
996 write!(self.out, ", ")?;
997 }
998 self.write_ep_argument_initialization(
999 ep,
1000 &ep_input,
1001 fake_iter.next().unwrap(),
1002 )?;
1003 }
1004 writeln!(self.out, " }};")?;
1005 }
1006 _ => {
1007 write!(self.out, " = ")?;
1008 self.write_ep_argument_initialization(
1009 ep,
1010 &ep_input,
1011 fake_iter.next().unwrap(),
1012 )?;
1013 writeln!(self.out, ";")?;
1014 }
1015 }
1016 }
1017 assert!(fake_iter.next().is_none());
1018 Ok(())
1019 }
1020
    /// Writes the HLSL declaration of one global variable.
    ///
    /// Which writer handles the global:
    /// - External textures go to `write_global_external_texture`.
    /// - Samplers go to `write_global_sampler`.
    /// - Globals whose resource binding cannot be resolved are skipped (with
    ///   a debug log) and nothing is written.
    ///
    /// The declaration shape depends on the address space:
    /// - `Private`: `static` with an initializer.
    /// - `WorkGroup` / `TaskPayload`: `groupshared`.
    /// - `Uniform`: `cbuffer NAME : register(b..) { T NAME; }`.
    /// - `Storage`: `[RW]ByteAddressBuffer`.
    /// - `Handle`: the resource type itself.
    /// - `Immediate`: `ConstantBuffer<T>`.
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for an `Immediate` global with a
    /// non-struct type. Part of the declaration has already been written by
    /// then.
    ///
    /// # Panics
    /// - On `Function`, `RayPayload` and `IncomingRayPayload` address spaces.
    /// - If an `Immediate` global is declared without `immediates_target`.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify by the element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are declared by a dedicated writer.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        // Samplers are declared by a dedicated writer.
        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the declaration prefix. `register_ty` is the register class
        // letter used in `: register(...)` below; it is empty where no
        // register is used.
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // The contained type is written after the register, in braces.
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                // Writable buffers use UAV (`u`) registers; read-only ones
                // use SRV (`t`) registers.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // Storage images are UAVs.
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type parameter and closing `>` are written below.
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        if global.space == crate::AddressSpace::Immediate {
            // Only struct-typed immediates are supported.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Cannot fail: unresolvable bindings returned early above.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // Binding arrays: the bind target may override the array size.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // Unbound globals: array size after the name, and an initializer
            // for private globals (explicit init or the default value).
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Body of the cbuffer: a single member with the same name as
            // the buffer.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1223
1224 fn write_global_sampler(
1225 &mut self,
1226 module: &Module,
1227 handle: Handle<crate::GlobalVariable>,
1228 global: &crate::GlobalVariable,
1229 ) -> BackendResult {
1230 let binding = *global.binding.as_ref().unwrap();
1231
1232 let key = super::SamplerIndexBufferKey {
1233 group: binding.group,
1234 };
1235 self.write_wrapped_sampler_buffer(key)?;
1236
1237 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240 match module.types[global.ty].inner {
1241 TypeInner::Sampler { comparison } => {
1242 write!(self.out, "static const ")?;
1249 self.write_type(module, global.ty)?;
1250
1251 let heap_var = if comparison {
1252 COMPARISON_SAMPLER_HEAP_VAR
1253 } else {
1254 SAMPLER_HEAP_VAR
1255 };
1256
1257 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258 let name = &self.names[&NameKey::GlobalVariable(handle)];
1259 writeln!(
1260 self.out,
1261 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262 register = bt.register
1263 )?;
1264 }
1265 TypeInner::BindingArray { .. } => {
1266 let name = &self.names[&NameKey::GlobalVariable(handle)];
1272 writeln!(
1273 self.out,
1274 "static const uint {name} = {register};",
1275 register = bt.register
1276 )?;
1277 }
1278 _ => unreachable!(),
1279 };
1280
1281 Ok(())
1282 }
1283
1284 fn write_global_external_texture(
1288 &mut self,
1289 module: &Module,
1290 handle: Handle<crate::GlobalVariable>,
1291 global: &crate::GlobalVariable,
1292 ) -> BackendResult {
1293 let res_binding = global
1294 .binding
1295 .as_ref()
1296 .expect("External texture global variables must have a resource binding");
1297 let ext_tex_bindings = match self
1298 .options
1299 .resolve_external_texture_resource_binding(res_binding)
1300 {
1301 Ok(bindings) => bindings,
1302 Err(err) => {
1303 log::debug!(
1304 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305 handle,
1306 global.name,
1307 err,
1308 );
1309 return Ok(());
1310 }
1311 };
1312
1313 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314 write!(
1315 self.out,
1316 "Texture2D<float4> {}: register(t{}",
1317 name, bt.register
1318 )?;
1319 if bt.space != 0 {
1320 write!(self.out, ", space{}", bt.space)?;
1321 }
1322 writeln!(self.out, ");")?;
1323 Ok(())
1324 };
1325 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326 let plane_name = &self.names
1327 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328 write_plane(bt, plane_name)?;
1329 }
1330
1331 let params_name = &self.names
1332 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333 let params_ty_name =
1334 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335 write!(
1336 self.out,
1337 "cbuffer {}: register(b{}",
1338 params_name, ext_tex_bindings.params.register
1339 )?;
1340 if ext_tex_bindings.params.space != 0 {
1341 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342 }
1343 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345 Ok(())
1346 }
1347
1348 fn write_global_constant(
1353 &mut self,
1354 module: &Module,
1355 handle: Handle<crate::Constant>,
1356 ) -> BackendResult {
1357 write!(self.out, "static const ")?;
1358 let constant = &module.constants[handle];
1359 self.write_type(module, constant.ty)?;
1360 let name = &self.names[&NameKey::Constant(handle)];
1361 write!(self.out, " {name}")?;
1362 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364 self.write_array_size(module, base, size)?;
1365 }
1366 write!(self.out, " = ")?;
1367 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368 writeln!(self.out, ";")?;
1369 Ok(())
1370 }
1371
1372 pub(super) fn write_array_size(
1373 &mut self,
1374 module: &Module,
1375 base: Handle<crate::Type>,
1376 size: crate::ArraySize,
1377 ) -> BackendResult {
1378 write!(self.out, "[")?;
1379
1380 match size.resolve(module.to_ctx())? {
1381 proc::IndexableLength::Known(size) => {
1382 write!(self.out, "{size}")?;
1383 }
1384 proc::IndexableLength::Dynamic => unreachable!(),
1385 }
1386
1387 write!(self.out, "]")?;
1388
1389 if let TypeInner::Array {
1390 base: next_base,
1391 size: next_size,
1392 ..
1393 } = module.types[base].inner
1394 {
1395 self.write_array_size(module, next_base, next_size)?;
1396 }
1397
1398 Ok(())
1399 }
1400
1401 fn write_struct(
1406 &mut self,
1407 module: &Module,
1408 handle: Handle<crate::Type>,
1409 members: &[crate::StructMember],
1410 span: u32,
1411 shader_stage: Option<(ShaderStage, Io)>,
1412 ) -> BackendResult {
1413 let struct_name = &self.names[&NameKey::Type(handle)];
1415 writeln!(self.out, "struct {struct_name} {{")?;
1416
1417 let mut last_offset = 0;
1418 for (index, member) in members.iter().enumerate() {
1419 if member.binding.is_none() && member.offset > last_offset {
1420 let padding = (member.offset - last_offset) / 4;
1424 for i in 0..padding {
1425 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426 }
1427 }
1428 let ty_inner = &module.types[member.ty].inner;
1429 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431 write!(self.out, "{}", back::INDENT)?;
1433
1434 match module.types[member.ty].inner {
1435 TypeInner::Array { base, size, .. } => {
1436 self.write_global_type(module, member.ty)?;
1439
1440 write!(
1442 self.out,
1443 " {}",
1444 &self.names[&NameKey::StructMember(handle, index as u32)]
1445 )?;
1446 self.write_array_size(module, base, size)?;
1448 }
1449 TypeInner::Matrix {
1452 rows,
1453 columns,
1454 scalar,
1455 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456 let vec_ty = TypeInner::Vector { size: rows, scalar };
1457 let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459 for i in 0..columns as u8 {
1460 if i != 0 {
1461 write!(self.out, "; ")?;
1462 }
1463 self.write_value_type(module, &vec_ty)?;
1464 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465 }
1466 }
1467 _ => {
1468 if let Some(ref binding) = member.binding {
1470 self.write_modifier(binding)?;
1471 }
1472
1473 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477 write!(self.out, "row_major ")?;
1478 }
1479
1480 self.write_type(module, member.ty)?;
1482 write!(
1483 self.out,
1484 " {}",
1485 &self.names[&NameKey::StructMember(handle, index as u32)]
1486 )?;
1487 }
1488 }
1489
1490 self.write_semantic(&member.binding, shader_stage)?;
1491 writeln!(self.out, ";")?;
1492 }
1493
1494 if members.last().unwrap().binding.is_none() && span > last_offset {
1496 let padding = (span - last_offset) / 4;
1497 for i in 0..padding {
1498 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499 }
1500 }
1501
1502 writeln!(self.out, "}};")?;
1503 Ok(())
1504 }
1505
1506 pub(super) fn write_global_type(
1511 &mut self,
1512 module: &Module,
1513 ty: Handle<crate::Type>,
1514 ) -> BackendResult {
1515 let matrix_data = get_inner_matrix_data(module, ty);
1516
1517 if let Some(MatrixType {
1520 columns,
1521 rows: crate::VectorSize::Bi,
1522 width: 4,
1523 }) = matrix_data
1524 {
1525 write!(self.out, "__mat{}x2", columns as u8)?;
1526 } else {
1527 if matrix_data.is_some() {
1531 write!(self.out, "row_major ")?;
1532 }
1533
1534 self.write_type(module, ty)?;
1535 }
1536
1537 Ok(())
1538 }
1539
1540 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545 let inner = &module.types[ty].inner;
1546 match *inner {
1547 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550 self.write_type(module, base)?
1551 }
1552 ref other => self.write_value_type(module, other)?,
1553 }
1554
1555 Ok(())
1556 }
1557
1558 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563 match *inner {
1564 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566 }
1567 TypeInner::Vector { size, scalar } => {
1568 write!(
1569 self.out,
1570 "{}{}",
1571 scalar.to_hlsl_str()?,
1572 common::vector_size_str(size)
1573 )?;
1574 }
1575 TypeInner::Matrix {
1576 columns,
1577 rows,
1578 scalar,
1579 } => {
1580 write!(
1585 self.out,
1586 "{}{}x{}",
1587 scalar.to_hlsl_str()?,
1588 common::vector_size_str(columns),
1589 common::vector_size_str(rows),
1590 )?;
1591 }
1592 TypeInner::Image {
1593 dim,
1594 arrayed,
1595 class,
1596 } => {
1597 self.write_image_type(dim, arrayed, class)?;
1598 }
1599 TypeInner::Sampler { comparison } => {
1600 let sampler = if comparison {
1601 "SamplerComparisonState"
1602 } else {
1603 "SamplerState"
1604 };
1605 write!(self.out, "{sampler}")?;
1606 }
1607 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611 self.write_array_size(module, base, size)?;
1612 }
1613 TypeInner::AccelerationStructure { .. } => {
1614 write!(self.out, "RaytracingAccelerationStructure")?;
1615 }
1616 TypeInner::RayQuery { .. } => {
1617 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619 }
1620 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621 }
1622
1623 Ok(())
1624 }
1625
1626 fn write_function(
1630 &mut self,
1631 module: &Module,
1632 name: &str,
1633 func: &crate::Function,
1634 func_ctx: &back::FunctionCtx<'_>,
1635 info: &valid::FunctionInfo,
1636 header: String,
1637 ) -> BackendResult {
1638 self.update_expressions_to_bake(module, func, info);
1641 let ep = match func_ctx.ty {
1642 back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643 back::FunctionType::Function(_) => None,
1644 };
1645
1646 let nested = matches!(
1647 ep,
1648 Some(crate::EntryPoint {
1649 stage: ShaderStage::Task | ShaderStage::Mesh,
1650 ..
1651 })
1652 );
1653 if !nested {
1654 write!(self.out, "{header}")?;
1655 }
1656
1657 if let Some(ref result) = func.result {
1658 let array_return_type = match module.types[result.ty].inner {
1660 TypeInner::Array { base, size, .. } => {
1661 let array_return_type = self.namer.call(&format!("ret_{name}"));
1662 write!(self.out, "typedef ")?;
1663 self.write_type(module, result.ty)?;
1664 write!(self.out, " {array_return_type}")?;
1665 self.write_array_size(module, base, size)?;
1666 writeln!(self.out, ";")?;
1667 Some(array_return_type)
1668 }
1669 _ => None,
1670 };
1671
1672 if let Some(
1674 ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675 ) = result.binding
1676 {
1677 self.write_modifier(binding)?;
1678 }
1679
1680 match func_ctx.ty {
1682 back::FunctionType::Function(_) => {
1683 if let Some(array_return_type) = array_return_type {
1684 write!(self.out, "{array_return_type}")?;
1685 } else {
1686 self.write_type(module, result.ty)?;
1687 }
1688 }
1689 back::FunctionType::EntryPoint(index) => {
1690 if let Some(ref ep_output) =
1691 self.entry_point_io.get(&(index as usize)).unwrap().output
1692 {
1693 write!(self.out, "{}", ep_output.ty_name)?;
1694 } else {
1695 self.write_type(module, result.ty)?;
1696 }
1697 }
1698 }
1699 } else {
1700 write!(self.out, "void")?;
1701 }
1702
1703 let nested_name = if nested {
1704 self.namer.call(&format!("_{name}"))
1705 } else {
1706 name.to_string()
1707 };
1708
1709 write!(self.out, " {nested_name}(")?;
1711
1712 let need_workgroup_variables_initialization =
1713 self.need_workgroup_variables_initialization(func_ctx, module);
1714
1715 let mut any_args_written = false;
1716 let mut separator = || {
1717 if any_args_written {
1718 ", "
1719 } else {
1720 any_args_written = true;
1721 ""
1722 }
1723 };
1724
1725 let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726 let mut local_invocation_index_name = None;
1727 let mut nested_wgsl_args: Vec<String> = Vec::new();
1730 let mut nested_task_payload_name: Option<String> = None;
1731 match func_ctx.ty {
1733 back::FunctionType::Function(handle) => {
1734 for (index, arg) in func.arguments.iter().enumerate() {
1735 write!(self.out, "{}", separator())?;
1736 self.write_function_argument(module, handle, arg, index)?;
1737 }
1738 for (var_handle, var) in module.global_variables.iter() {
1740 let uses = info[var_handle];
1741 if uses.contains(valid::GlobalUse::READ)
1742 && !uses.contains(valid::GlobalUse::WRITE)
1743 && var.space == crate::AddressSpace::TaskPayload
1744 {
1745 self.function_task_payload_var.insert(handle, var_handle);
1746 write!(self.out, "{}in ", separator())?;
1747
1748 self.write_type(module, var.ty)?;
1749 let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750 write!(self.out, " {name}")?;
1751 break;
1752 }
1753 }
1754 }
1755 back::FunctionType::EntryPoint(ep_index) => {
1756 let ep = &module.entry_points[ep_index as usize];
1757 if let Some(ref ep_input) =
1758 self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759 {
1760 write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1761 separator();
1762 nested_wgsl_args.push(ep_input.arg_name.clone());
1763 } else {
1764 let stage = ep.stage;
1765 for (index, arg) in func.arguments.iter().enumerate() {
1766 write!(self.out, "{}", separator())?;
1767 self.write_type(module, arg.ty)?;
1768
1769 let argument_name =
1770 &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
1772 if arg.binding
1773 == Some(crate::Binding::BuiltIn(
1774 crate::BuiltIn::LocalInvocationIndex,
1775 ))
1776 {
1777 local_invocation_index_name = Some(argument_name.clone());
1778 }
1779
1780 nested_wgsl_args.push(argument_name.clone());
1781 write!(self.out, " {argument_name}")?;
1782 if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783 self.write_array_size(module, base, size)?;
1784 }
1785
1786 self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787 }
1788 }
1789 if ep.stage == ShaderStage::Mesh {
1790 if let Some(var_handle) = ep.task_payload {
1791 let var = &module.global_variables[var_handle];
1792 write!(self.out, "{}in ", separator())?;
1793 self.write_type(module, var.ty)?;
1794 let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795 write!(self.out, " {arg_name}")?;
1796 nested_task_payload_name = Some(arg_name.clone());
1797 if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798 self.write_array_size(module, base, size)?;
1799 }
1800 }
1801 }
1802 if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803 let name = self.namer.call("local_invocation_index");
1804 write!(self.out, "{}uint {name}", separator())?;
1805 write!(self.out, " : SV_GroupIndex")?;
1806 local_invocation_index_name = Some(name);
1807 }
1808 }
1809 }
1810 write!(self.out, ")")?;
1812
1813 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815 let stage = module.entry_points[index as usize].stage;
1816 if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817 self.write_semantic(binding, Some((stage, Io::Output)))?;
1818 }
1819 }
1820
1821 writeln!(self.out)?;
1823 writeln!(self.out, "{{")?;
1824
1825 if need_workgroup_variables_initialization && !nested {
1826 let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827 unreachable!();
1828 };
1829 writeln!(
1830 self.out,
1831 "{}if ({} == 0) {{",
1832 back::INDENT,
1833 local_invocation_index_name.as_ref().unwrap(),
1836 )?;
1837 self.write_workgroup_variables_initialization(
1838 func_ctx,
1839 module,
1840 module.entry_points[index as usize].stage,
1841 )?;
1842
1843 writeln!(self.out, "{}}}", back::INDENT)?;
1844 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845 }
1846
1847 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848 self.write_ep_arguments_initialization(module, func, index)?;
1849 }
1850
1851 for (handle, local) in func.local_variables.iter() {
1853 write!(self.out, "{}", back::INDENT)?;
1855
1856 self.write_type(module, local.ty)?;
1859 write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860 if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862 self.write_array_size(module, base, size)?;
1863 }
1864
1865 let is_ray_query = match module.types[local.ty].inner {
1866 TypeInner::RayQuery { .. } => true,
1868 _ => {
1869 write!(self.out, " = ")?;
1870 if let Some(init) = local.init {
1872 self.write_expr(module, init, func_ctx)?;
1873 } else {
1874 self.write_default_init(module, local.ty)?;
1876 }
1877 false
1878 }
1879 };
1880 writeln!(self.out, ";")?;
1882 if is_ray_query {
1884 write!(self.out, "{}", back::INDENT)?;
1885 self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886 writeln!(
1887 self.out,
1888 " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889 self.names[&func_ctx.name_key(handle)]
1890 )?;
1891 }
1892 }
1893
1894 if !func.local_variables.is_empty() {
1895 writeln!(self.out)?;
1896 }
1897
1898 for sta in func.body.iter() {
1900 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902 }
1903
1904 writeln!(self.out, "}}")?;
1905
1906 if nested {
1907 self.write_nested_function_outer(
1908 module,
1909 func_ctx,
1910 &header,
1911 name,
1912 need_workgroup_variables_initialization,
1913 &nested_name,
1914 ep.unwrap(),
1915 NestedEntryPointArgs {
1916 user_args: nested_wgsl_args,
1917 task_payload: nested_task_payload_name,
1918 local_invocation_index: local_invocation_index_name.unwrap(),
1920 },
1921 )?;
1922 }
1923
1924 self.named_expressions.clear();
1925
1926 Ok(())
1927 }
1928
1929 fn write_function_argument(
1930 &mut self,
1931 module: &Module,
1932 handle: Handle<crate::Function>,
1933 arg: &crate::FunctionArgument,
1934 index: usize,
1935 ) -> BackendResult {
1936 if let TypeInner::Image {
1939 class: crate::ImageClass::External,
1940 ..
1941 } = module.types[arg.ty].inner
1942 {
1943 return self.write_function_external_texture_argument(module, handle, index);
1944 }
1945
1946 let arg_ty = match module.types[arg.ty].inner {
1948 TypeInner::Pointer { base, .. } => {
1950 write!(self.out, "inout ")?;
1952 base
1953 }
1954 _ => arg.ty,
1955 };
1956 self.write_type(module, arg_ty)?;
1957
1958 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960 write!(self.out, " {argument_name}")?;
1962 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963 self.write_array_size(module, base, size)?;
1964 }
1965
1966 Ok(())
1967 }
1968
1969 fn write_function_external_texture_argument(
1970 &mut self,
1971 module: &Module,
1972 handle: Handle<crate::Function>,
1973 index: usize,
1974 ) -> BackendResult {
1975 let plane_names = [0, 1, 2].map(|i| {
1976 &self.names[&NameKey::ExternalTextureFunctionArgument(
1977 handle,
1978 index as u32,
1979 ExternalTextureNameKey::Plane(i),
1980 )]
1981 });
1982 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983 handle,
1984 index as u32,
1985 ExternalTextureNameKey::Params,
1986 )];
1987 let params_ty_name =
1988 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989 write!(
1990 self.out,
1991 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992 plane_names[0], plane_names[1], plane_names[2],
1993 )?;
1994 Ok(())
1995 }
1996
1997 fn need_workgroup_variables_initialization(
1998 &mut self,
1999 func_ctx: &back::FunctionCtx,
2000 module: &Module,
2001 ) -> bool {
2002 self.options.zero_initialize_workgroup_memory
2003 && func_ctx.ty.is_compute_like_entry_point(module)
2004 && module.global_variables.iter().any(|(handle, var)| {
2005 !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006 })
2007 }
2008
2009 pub(super) fn write_workgroup_variables_initialization(
2010 &mut self,
2011 func_ctx: &back::FunctionCtx,
2012 module: &Module,
2013 stage: ShaderStage,
2014 ) -> BackendResult {
2015 let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016 let task_needs_zero =
2018 (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019 !func_ctx.info[handle].is_empty()
2020 && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021 });
2022
2023 for (handle, var) in vars {
2024 let name = &self.names[&NameKey::GlobalVariable(handle)];
2025 write!(self.out, "{}{} = ", back::Level(2), name)?;
2026 self.write_default_init(module, var.ty)?;
2027 writeln!(self.out, ";")?;
2028 }
2029 Ok(())
2030 }
2031
/// Writes a Naga `Switch` statement at indentation `level`.
///
/// `selector` is the switch expression and `cases` the ordered case list.
/// - Coordinates with `self.continue_ctx`, which may declare a flag before the
///   switch and test it after the switch.
/// - Collapses a switch whose only non-empty case is the last one into a
///   `do { ... } while(false);` block.
/// - Emulates fall-through from non-empty cases by duplicating the following
///   case bodies.
2032 fn write_switch(
2034 &mut self,
2035 module: &Module,
2036 func_ctx: &back::FunctionCtx<'_>,
2037 level: back::Level,
2038 selector: Handle<crate::Expression>,
2039 cases: &[crate::SwitchCase],
2040 ) -> BackendResult {
2041 let indent_level_1 = level.next();
2043 let indent_level_2 = indent_level_1.next();
2044
// `enter_switch` may hand back a flag variable name (allocated through
// `self.namer`). If it does, declare the flag and clear it. Presumably this
// forwards a `continue` inside the switch to an enclosing loop; see the
// `exit_switch` handling at the end.
2045 if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
2047 writeln!(self.out, "{level}bool {variable} = false;",)?;
2048 };
2049
// True when every case except the last is an empty fall-through case. In
// that case every selector value ends up running the last case's body.
2050 let one_body = cases
2055 .iter()
2056 .rev()
2057 .skip(1)
2058 .all(|case| case.fall_through && case.body.is_empty());
2059 if one_body {
// Emit only the last case's body, with no `switch` and no selector
// expression. The `do { ... } while(false)` wrapper gives any `break`
// inside the body a construct to exit.
2060 writeln!(self.out, "{level}do {{")?;
2062 if let Some(case) = cases.last() {
2066 for sta in case.body.iter() {
2067 self.write_stmt(module, sta, func_ctx, indent_level_1)?;
2068 }
2069 }
2070 writeln!(self.out, "{level}}} while(false);")?;
2072 } else {
2073 write!(self.out, "{level}")?;
2075 write!(self.out, "switch(")?;
2076 self.write_expr(module, selector, func_ctx)?;
2077 writeln!(self.out, ") {{")?;
2078
2079 for (i, case) in cases.iter().enumerate() {
// Case label. `u32` selector values get the `u` literal suffix.
2080 match case.value {
2081 crate::SwitchValue::I32(value) => {
2082 write!(self.out, "{indent_level_1}case {value}:")?
2083 }
2084 crate::SwitchValue::U32(value) => {
2085 write!(self.out, "{indent_level_1}case {value}u:")?
2086 }
2087 crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
2088 }
2089
// An empty fall-through case is a bare label that stacks onto the next
// label. Every other case gets its own `{ ... }` block.
2090 let write_block_braces = !(case.fall_through && case.body.is_empty());
2097 if write_block_braces {
2098 writeln!(self.out, " {{")?;
2099 } else {
2100 writeln!(self.out)?;
2101 }
2102
2103 if case.fall_through && !case.body.is_empty() {
// Non-empty fall-through case. Instead of relying on fall-through in the
// output, write this case's body inline, then the bodies of every following
// case up to and including the first one that does not fall through.
// NOTE(review): presumably because the HLSL compilers targeted here do not
// accept fall-through out of a non-empty case; confirm.
// The `unwrap` assumes some later case does not fall through.
2121 let curr_len = i + 1;
2122 let end_case_idx = curr_len
2123 + cases
2124 .iter()
2125 .skip(curr_len)
2126 .position(|case| !case.fall_through)
2127 .unwrap();
2128 let indent_level_3 = indent_level_2.next();
2129 for case in &cases[i..=end_case_idx] {
// Each copied body gets its own `{ }` scope. Entries it adds to
// `named_expressions` are truncated away afterwards, so later copies
// start from the same set of names.
2130 writeln!(self.out, "{indent_level_2}{{")?;
2131 let prev_len = self.named_expressions.len();
2132 for sta in case.body.iter() {
2133 self.write_stmt(module, sta, func_ctx, indent_level_3)?;
2134 }
2135 self.named_expressions.truncate(prev_len);
2137 writeln!(self.out, "{indent_level_2}}}")?;
2138 }
2139
// The inlined chain ends with a non-fall-through case. Close it with
// `break` unless that case's body already ends in a terminator statement.
2140 let last_case = &cases[end_case_idx];
2141 if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
2142 writeln!(self.out, "{indent_level_2}break;")?;
2143 }
2144 } else {
// Ordinary case: write the body, then a `break`. The `break` is skipped when
// the case falls through (only possible here with an empty body) or when the
// body already ends in a terminator.
2145 for sta in case.body.iter() {
2146 self.write_stmt(module, sta, func_ctx, indent_level_2)?;
2147 }
2148 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
2149 writeln!(self.out, "{indent_level_2}break;")?;
2150 }
2151 }
2152
2153 if write_block_braces {
2154 writeln!(self.out, "{indent_level_1}}}")?;
2155 }
2156 }
2157
2158 writeln!(self.out, "{level}}}")?;
2159 }
2160
// Leaving the switch: if `continue_ctx` reports a `continue` or `break` that
// must be carried out past the switch, test the flag declared above and
// perform that control flow.
2161 use back::continue_forward::ExitControlFlow;
2163 let op = match self.continue_ctx.exit_switch() {
2164 ExitControlFlow::None => None,
2165 ExitControlFlow::Continue { variable } => Some(("continue", variable)),
2166 ExitControlFlow::Break { variable } => Some(("break", variable)),
2167 };
2168 if let Some((control_flow, variable)) = op {
2169 writeln!(self.out, "{level}if ({variable}) {{")?;
2170 writeln!(self.out, "{indent_level_1}{control_flow};")?;
2171 writeln!(self.out, "{level}}}")?;
2172 }
2173
2174 Ok(())
2175 }
2176
2177 fn write_index(
2178 &mut self,
2179 module: &Module,
2180 index: Index,
2181 func_ctx: &back::FunctionCtx<'_>,
2182 ) -> BackendResult {
2183 match index {
2184 Index::Static(index) => {
2185 write!(self.out, "{index}")?;
2186 }
2187 Index::Expression(index) => {
2188 self.write_expr(module, index, func_ctx)?;
2189 }
2190 }
2191 Ok(())
2192 }
2193
2194 fn write_stmt(
2199 &mut self,
2200 module: &Module,
2201 stmt: &crate::Statement,
2202 func_ctx: &back::FunctionCtx<'_>,
2203 level: back::Level,
2204 ) -> BackendResult {
2205 use crate::Statement;
2206
2207 match *stmt {
2208 Statement::Emit(ref range) => {
2209 for handle in range.clone() {
2210 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211 let expr_name = if ptr_class.is_some() {
2212 None
2216 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217 Some(self.namer.call(name))
2222 } else if self.need_bake_expressions.contains(&handle) {
2223 Some(Baked(handle).to_string())
2224 } else {
2225 None
2226 };
2227
2228 if let Some(name) = expr_name {
2229 write!(self.out, "{level}")?;
2230 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231 }
2232 }
2233 }
2234 Statement::Block(ref block) => {
2236 write!(self.out, "{level}")?;
2237 writeln!(self.out, "{{")?;
2238 for sta in block.iter() {
2239 self.write_stmt(module, sta, func_ctx, level.next())?
2241 }
2242 writeln!(self.out, "{level}}}")?
2243 }
2244 Statement::If {
2246 condition,
2247 ref accept,
2248 ref reject,
2249 } => {
2250 write!(self.out, "{level}")?;
2251 write!(self.out, "if (")?;
2252 self.write_expr(module, condition, func_ctx)?;
2253 writeln!(self.out, ") {{")?;
2254
2255 let l2 = level.next();
2256 for sta in accept {
2257 self.write_stmt(module, sta, func_ctx, l2)?;
2259 }
2260
2261 if !reject.is_empty() {
2264 writeln!(self.out, "{level}}} else {{")?;
2265
2266 for sta in reject {
2267 self.write_stmt(module, sta, func_ctx, l2)?;
2269 }
2270 }
2271
2272 writeln!(self.out, "{level}}}")?
2273 }
2274 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276 Statement::Return { value: None } => {
2277 writeln!(self.out, "{level}return;")?;
2278 }
2279 Statement::Return { value: Some(expr) } => {
2280 let base_ty_res = &func_ctx.info[expr].ty;
2281 let mut resolved = base_ty_res.inner_with(&module.types);
2282 if let TypeInner::Pointer { base, space: _ } = *resolved {
2283 resolved = &module.types[base].inner;
2284 }
2285
2286 if let TypeInner::Struct { .. } = *resolved {
2287 let ty = base_ty_res.handle().unwrap();
2289 let struct_name = &self.names[&NameKey::Type(ty)];
2290 let variable_name = self.namer.call(&struct_name.to_lowercase());
2291 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292 self.write_expr(module, expr, func_ctx)?;
2293 writeln!(self.out, ";")?;
2294
2295 let ep_output = match func_ctx.ty {
2297 back::FunctionType::Function(_) => None,
2298 back::FunctionType::EntryPoint(index) => self
2299 .entry_point_io
2300 .get(&(index as usize))
2301 .unwrap()
2302 .output
2303 .as_ref(),
2304 };
2305 let final_name = match ep_output {
2306 Some(ep_output) => {
2307 let final_name = self.namer.call(&variable_name);
2308 write!(
2309 self.out,
2310 "{}const {} {} = {{ ",
2311 level, ep_output.ty_name, final_name,
2312 )?;
2313 for (index, m) in ep_output.members.iter().enumerate() {
2314 if index != 0 {
2315 write!(self.out, ", ")?;
2316 }
2317 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318 write!(self.out, "{variable_name}.{member_name}")?;
2319 }
2320 writeln!(self.out, " }};")?;
2321 final_name
2322 }
2323 None => variable_name,
2324 };
2325 writeln!(self.out, "{level}return {final_name};")?;
2326 } else {
2327 write!(self.out, "{level}return ")?;
2328 self.write_expr(module, expr, func_ctx)?;
2329 writeln!(self.out, ";")?
2330 }
2331 }
2332 Statement::Store { pointer, value } => {
2333 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334 if ty_inner.is_atomic_pointer(&module.types) {
2335 let pointer_space = ty_inner.pointer_space().unwrap();
2336 let dummy = self.namer.call("dummy");
2337 write!(self.out, "{level}{{ ")?;
2338 if let TypeInner::Pointer { base, .. } = *ty_inner {
2339 self.write_value_type(module, &module.types[base].inner)?;
2340 }
2341 write!(self.out, " {dummy} = 0; ")?;
2342 match pointer_space {
2343 crate::AddressSpace::WorkGroup => {
2344 write!(self.out, "InterlockedExchange(")?;
2345 self.write_expr(module, pointer, func_ctx)?;
2346 }
2347 crate::AddressSpace::Storage { .. } => {
2348 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350 write!(self.out, "{var_name}.InterlockedExchange(")?;
2351 let chain = mem::take(&mut self.temp_access_chain);
2352 self.write_storage_address(module, &chain, func_ctx)?;
2353 self.temp_access_chain = chain;
2354 }
2355 _ => unreachable!(),
2356 }
2357 write!(self.out, ", ")?;
2358 self.write_expr(module, value, func_ctx)?;
2359 writeln!(self.out, ", {dummy}); }}")?;
2360 } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362 self.write_storage_store(
2363 module,
2364 var_handle,
2365 StoreValue::Expression(value),
2366 func_ctx,
2367 level,
2368 None,
2369 )?;
2370 } else {
2371 enum MatrixAccess {
2377 Direct {
2378 base: Handle<crate::Expression>,
2379 index: u32,
2380 },
2381 Struct {
2382 columns: crate::VectorSize,
2383 base: Handle<crate::Expression>,
2384 },
2385 }
2386
2387 let get_members = |expr: Handle<crate::Expression>| {
2388 let resolved = func_ctx.resolve_type(expr, &module.types);
2389 match *resolved {
2390 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2391 TypeInner::Struct { ref members, .. } => Some(members),
2392 _ => None,
2393 },
2394 _ => None,
2395 }
2396 };
2397
2398 write!(self.out, "{level}")?;
2399
2400 let matrix_access_on_lhs =
2401 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2402 |(matrix_expr, vector, scalar)| match (
2403 func_ctx.resolve_type(matrix_expr, &module.types),
2404 &func_ctx.expressions[matrix_expr],
2405 ) {
2406 (
2407 &TypeInner::Pointer { base: ty, .. },
2408 &crate::Expression::AccessIndex { base, index },
2409 ) if matches!(
2410 module.types[ty].inner,
2411 TypeInner::Matrix {
2412 rows: crate::VectorSize::Bi,
2413 ..
2414 }
2415 ) && get_members(base)
2416 .map(|members| members[index as usize].binding.is_none())
2417 == Some(true) =>
2418 {
2419 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2420 }
2421 _ => {
2422 if let Some(MatrixType {
2423 columns,
2424 rows: crate::VectorSize::Bi,
2425 width: 4,
2426 }) = get_inner_matrix_of_struct_array_member(
2427 module,
2428 matrix_expr,
2429 func_ctx,
2430 true,
2431 ) {
2432 Some((
2433 MatrixAccess::Struct {
2434 columns,
2435 base: matrix_expr,
2436 },
2437 vector,
2438 scalar,
2439 ))
2440 } else {
2441 None
2442 }
2443 }
2444 },
2445 );
2446
2447 match matrix_access_on_lhs {
2448 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2449 let base_ty_res = &func_ctx.info[base].ty;
2450 let resolved = base_ty_res.inner_with(&module.types);
2451 let ty = match *resolved {
2452 TypeInner::Pointer { base, .. } => base,
2453 _ => base_ty_res.handle().unwrap(),
2454 };
2455
2456 if let Some(Index::Static(vec_index)) = vector {
2457 self.write_expr(module, base, func_ctx)?;
2458 write!(
2459 self.out,
2460 ".{}_{}",
2461 &self.names[&NameKey::StructMember(ty, index)],
2462 vec_index
2463 )?;
2464
2465 if let Some(scalar_index) = scalar {
2466 write!(self.out, "[")?;
2467 self.write_index(module, scalar_index, func_ctx)?;
2468 write!(self.out, "]")?;
2469 }
2470
2471 write!(self.out, " = ")?;
2472 self.write_expr(module, value, func_ctx)?;
2473 writeln!(self.out, ";")?;
2474 } else {
2475 let access = WrappedStructMatrixAccess { ty, index };
2476 match (&vector, &scalar) {
2477 (&Some(_), &Some(_)) => {
2478 self.write_wrapped_struct_matrix_set_scalar_function_name(
2479 access,
2480 )?;
2481 }
2482 (&Some(_), &None) => {
2483 self.write_wrapped_struct_matrix_set_vec_function_name(
2484 access,
2485 )?;
2486 }
2487 (&None, _) => {
2488 self.write_wrapped_struct_matrix_set_function_name(access)?;
2489 }
2490 }
2491
2492 write!(self.out, "(")?;
2493 self.write_expr(module, base, func_ctx)?;
2494 write!(self.out, ", ")?;
2495 self.write_expr(module, value, func_ctx)?;
2496
2497 if let Some(Index::Expression(vec_index)) = vector {
2498 write!(self.out, ", ")?;
2499 self.write_expr(module, vec_index, func_ctx)?;
2500
2501 if let Some(scalar_index) = scalar {
2502 write!(self.out, ", ")?;
2503 self.write_index(module, scalar_index, func_ctx)?;
2504 }
2505 }
2506 writeln!(self.out, ");")?;
2507 }
2508 }
2509 Some((
2510 MatrixAccess::Struct { columns, base },
2511 Some(Index::Expression(vec_index)),
2512 scalar,
2513 )) => {
2514 if scalar.is_some() {
2518 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2519 } else {
2520 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2521 }
2522 write!(self.out, "(")?;
2523 self.write_expr(module, base, func_ctx)?;
2524 write!(self.out, ", ")?;
2525 self.write_expr(module, vec_index, func_ctx)?;
2526
2527 if let Some(scalar_index) = scalar {
2528 write!(self.out, ", ")?;
2529 self.write_index(module, scalar_index, func_ctx)?;
2530 }
2531
2532 write!(self.out, ", ")?;
2533 self.write_expr(module, value, func_ctx)?;
2534
2535 writeln!(self.out, ");")?;
2536 }
2537 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2538 | Some((MatrixAccess::Struct { .. }, None, _))
2539 | None => {
2540 self.write_expr(module, pointer, func_ctx)?;
2541 write!(self.out, " = ")?;
2542
2543 if let Some(MatrixType {
2548 columns,
2549 rows: crate::VectorSize::Bi,
2550 width: 4,
2551 }) = get_inner_matrix_of_struct_array_member(
2552 module, pointer, func_ctx, false,
2553 ) {
2554 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2555 if let TypeInner::Pointer { base, .. } = *resolved {
2556 resolved = &module.types[base].inner;
2557 }
2558
2559 write!(self.out, "(__mat{}x2", columns as u8)?;
2560 if let TypeInner::Array { base, size, .. } = *resolved {
2561 self.write_array_size(module, base, size)?;
2562 }
2563 write!(self.out, ")")?;
2564 }
2565
2566 self.write_expr(module, value, func_ctx)?;
2567 writeln!(self.out, ";")?
2568 }
2569 }
2570 }
2571 }
2572 Statement::Loop {
2573 ref body,
2574 ref continuing,
2575 break_if,
2576 } => {
2577 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2578 let gate_name = (!continuing.is_empty() || break_if.is_some())
2579 .then(|| self.namer.call("loop_init"));
2580
2581 if let Some((ref decl, _)) = force_loop_bound_statements {
2582 writeln!(self.out, "{decl}")?;
2583 }
2584 if let Some(ref gate_name) = gate_name {
2585 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2586 }
2587
2588 self.continue_ctx.enter_loop();
2589 writeln!(self.out, "{level}while(true) {{")?;
2590 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2591 writeln!(self.out, "{break_and_inc}")?;
2592 }
2593 let l2 = level.next();
2594 if let Some(gate_name) = gate_name {
2595 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2596 let l3 = l2.next();
2597 for sta in continuing.iter() {
2598 self.write_stmt(module, sta, func_ctx, l3)?;
2599 }
2600 if let Some(condition) = break_if {
2601 write!(self.out, "{l3}if (")?;
2602 self.write_expr(module, condition, func_ctx)?;
2603 writeln!(self.out, ") {{")?;
2604 writeln!(self.out, "{}break;", l3.next())?;
2605 writeln!(self.out, "{l3}}}")?;
2606 }
2607 writeln!(self.out, "{l2}}}")?;
2608 writeln!(self.out, "{l2}{gate_name} = false;")?;
2609 }
2610
2611 for sta in body.iter() {
2612 self.write_stmt(module, sta, func_ctx, l2)?;
2613 }
2614
2615 writeln!(self.out, "{level}}}")?;
2616 self.continue_ctx.exit_loop();
2617 }
2618 Statement::Break => writeln!(self.out, "{level}break;")?,
2619 Statement::Continue => {
2620 if let Some(variable) = self.continue_ctx.continue_encountered() {
2621 writeln!(self.out, "{level}{variable} = true;")?;
2622 writeln!(self.out, "{level}break;")?
2623 } else {
2624 writeln!(self.out, "{level}continue;")?
2625 }
2626 }
2627 Statement::ControlBarrier(barrier) => {
2628 self.write_control_barrier(barrier, level)?;
2629 }
2630 Statement::MemoryBarrier(barrier) => {
2631 self.write_memory_barrier(barrier, level)?;
2632 }
2633 Statement::ImageStore {
2634 image,
2635 coordinate,
2636 array_index,
2637 value,
2638 } => {
2639 write!(self.out, "{level}")?;
2640 self.write_expr(module, image, func_ctx)?;
2641
2642 write!(self.out, "[")?;
2643 if let Some(index) = array_index {
2644 write!(self.out, "int3(")?;
2646 self.write_expr(module, coordinate, func_ctx)?;
2647 write!(self.out, ", ")?;
2648 self.write_expr(module, index, func_ctx)?;
2649 write!(self.out, ")")?;
2650 } else {
2651 self.write_expr(module, coordinate, func_ctx)?;
2652 }
2653 write!(self.out, "]")?;
2654
2655 write!(self.out, " = ")?;
2656 self.write_expr(module, value, func_ctx)?;
2657 writeln!(self.out, ";")?;
2658 }
2659 Statement::Call {
2660 function,
2661 ref arguments,
2662 result,
2663 } => {
2664 write!(self.out, "{level}")?;
2665
2666 if let Some(expr) = result {
2667 write!(self.out, "const ")?;
2668 let name = Baked(expr).to_string();
2669 let expr_ty = &func_ctx.info[expr].ty;
2670 let ty_inner = match *expr_ty {
2671 proc::TypeResolution::Handle(handle) => {
2672 self.write_type(module, handle)?;
2673 &module.types[handle].inner
2674 }
2675 proc::TypeResolution::Value(ref value) => {
2676 self.write_value_type(module, value)?;
2677 value
2678 }
2679 };
2680 write!(self.out, " {name}")?;
2681 if let TypeInner::Array { base, size, .. } = *ty_inner {
2682 self.write_array_size(module, base, size)?;
2683 }
2684 write!(self.out, " = ")?;
2685 self.named_expressions.insert(expr, name);
2686 }
2687 let func_name = &self.names[&NameKey::Function(function)];
2688 write!(self.out, "{func_name}(")?;
2689 let mut any_args_written = false;
2690 let mut separator = || {
2691 if any_args_written {
2692 ", "
2693 } else {
2694 any_args_written = true;
2695 ""
2696 }
2697 };
2698 for argument in arguments {
2699 write!(self.out, "{}", separator())?;
2700 self.write_expr(module, *argument, func_ctx)?;
2701 }
2702 if let Some(&var) = self.function_task_payload_var.get(&function) {
2703 let name = &self.names[&NameKey::GlobalVariable(var)];
2704 write!(self.out, "{}{name}", separator())?;
2706 }
2707 writeln!(self.out, ");")?;
2708 }
2709 Statement::Atomic {
2710 pointer,
2711 ref fun,
2712 value,
2713 result,
2714 } => {
2715 write!(self.out, "{level}")?;
2716 let res_var_info = if let Some(res_handle) = result {
2717 let name = Baked(res_handle).to_string();
2718 match func_ctx.info[res_handle].ty {
2719 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2720 proc::TypeResolution::Value(ref value) => {
2721 self.write_value_type(module, value)?
2722 }
2723 };
2724 write!(self.out, " {name}; ")?;
2725 self.named_expressions.insert(res_handle, name.clone());
2726 Some((res_handle, name))
2727 } else {
2728 None
2729 };
2730 let pointer_space = func_ctx
2731 .resolve_type(pointer, &module.types)
2732 .pointer_space()
2733 .unwrap();
2734 let fun_str = fun.to_hlsl_suffix();
2735 let compare_expr = match *fun {
2736 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2737 _ => None,
2738 };
2739 match pointer_space {
2740 crate::AddressSpace::WorkGroup => {
2741 write!(self.out, "Interlocked{fun_str}(")?;
2742 self.write_expr(module, pointer, func_ctx)?;
2743 self.emit_hlsl_atomic_tail(
2744 module,
2745 func_ctx,
2746 fun,
2747 compare_expr,
2748 value,
2749 &res_var_info,
2750 )?;
2751 }
2752 crate::AddressSpace::Storage { .. } => {
2753 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2754 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2755 let width = match func_ctx.resolve_type(value, &module.types) {
2756 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2757 _ => "",
2758 };
2759 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2760 let chain = mem::take(&mut self.temp_access_chain);
2761 self.write_storage_address(module, &chain, func_ctx)?;
2762 self.temp_access_chain = chain;
2763 self.emit_hlsl_atomic_tail(
2764 module,
2765 func_ctx,
2766 fun,
2767 compare_expr,
2768 value,
2769 &res_var_info,
2770 )?;
2771 }
2772 ref other => {
2773 return Err(Error::Custom(format!(
2774 "invalid address space {other:?} for atomic statement"
2775 )))
2776 }
2777 }
2778 if let Some(cmp) = compare_expr {
2779 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2780 write!(
2781 self.out,
2782 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2783 )?;
2784 self.write_expr(module, cmp, func_ctx)?;
2785 writeln!(self.out, ");")?;
2786 }
2787 }
2788 }
2789 Statement::ImageAtomic {
2790 image,
2791 coordinate,
2792 array_index,
2793 fun,
2794 value,
2795 } => {
2796 write!(self.out, "{level}")?;
2797
2798 let fun_str = fun.to_hlsl_suffix();
2799 write!(self.out, "Interlocked{fun_str}(")?;
2800 self.write_expr(module, image, func_ctx)?;
2801 write!(self.out, "[")?;
2802 self.write_texture_coordinates(
2803 "int",
2804 coordinate,
2805 array_index,
2806 None,
2807 module,
2808 func_ctx,
2809 )?;
2810 write!(self.out, "],")?;
2811
2812 self.write_expr(module, value, func_ctx)?;
2813 writeln!(self.out, ");")?;
2814 }
2815 Statement::WorkGroupUniformLoad { pointer, result } => {
2816 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2817 write!(self.out, "{level}")?;
2818 let name = Baked(result).to_string();
2819 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2820
2821 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2822 }
2823 Statement::Switch {
2824 selector,
2825 ref cases,
2826 } => {
2827 self.write_switch(module, func_ctx, level, selector, cases)?;
2828 }
2829 Statement::RayQuery { query, ref fun } => {
2830 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2842 else {
2843 unreachable!()
2844 };
2845
2846 let tracker_expr_name = format!(
2847 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2848 self.names[&func_ctx.name_key(query_var)]
2849 );
2850
2851 match *fun {
2852 RayQueryFunction::Initialize {
2853 acceleration_structure,
2854 descriptor,
2855 } => {
2856 self.write_initialize_function(
2857 module,
2858 level,
2859 query,
2860 acceleration_structure,
2861 descriptor,
2862 &tracker_expr_name,
2863 func_ctx,
2864 )?;
2865 }
2866 RayQueryFunction::Proceed { result } => {
2867 self.write_proceed(
2868 module,
2869 level,
2870 query,
2871 result,
2872 &tracker_expr_name,
2873 func_ctx,
2874 )?;
2875 }
2876 RayQueryFunction::GenerateIntersection { hit_t } => {
2877 self.write_generate_intersection(
2878 module,
2879 level,
2880 query,
2881 hit_t,
2882 &tracker_expr_name,
2883 func_ctx,
2884 )?;
2885 }
2886 RayQueryFunction::ConfirmIntersection => {
2887 self.write_confirm_intersection(
2888 module,
2889 level,
2890 query,
2891 &tracker_expr_name,
2892 func_ctx,
2893 )?;
2894 }
2895 RayQueryFunction::Terminate => {
2896 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2897 }
2898 }
2899 }
2900 Statement::SubgroupBallot { result, predicate } => {
2901 write!(self.out, "{level}")?;
2902 let name = Baked(result).to_string();
2903 write!(self.out, "const uint4 {name} = ")?;
2904 self.named_expressions.insert(result, name);
2905
2906 write!(self.out, "WaveActiveBallot(")?;
2907 match predicate {
2908 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2909 None => write!(self.out, "true")?,
2910 }
2911 writeln!(self.out, ");")?;
2912 }
2913 Statement::SubgroupCollectiveOperation {
2914 op,
2915 collective_op,
2916 argument,
2917 result,
2918 } => {
2919 write!(self.out, "{level}")?;
2920 write!(self.out, "const ")?;
2921 let name = Baked(result).to_string();
2922 match func_ctx.info[result].ty {
2923 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2924 proc::TypeResolution::Value(ref value) => {
2925 self.write_value_type(module, value)?
2926 }
2927 };
2928 write!(self.out, " {name} = ")?;
2929 self.named_expressions.insert(result, name);
2930
2931 match (collective_op, op) {
2932 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2933 write!(self.out, "WaveActiveAllTrue(")?
2934 }
2935 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2936 write!(self.out, "WaveActiveAnyTrue(")?
2937 }
2938 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2939 write!(self.out, "WaveActiveSum(")?
2940 }
2941 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2942 write!(self.out, "WaveActiveProduct(")?
2943 }
2944 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2945 write!(self.out, "WaveActiveMax(")?
2946 }
2947 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2948 write!(self.out, "WaveActiveMin(")?
2949 }
2950 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2951 write!(self.out, "WaveActiveBitAnd(")?
2952 }
2953 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2954 write!(self.out, "WaveActiveBitOr(")?
2955 }
2956 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2957 write!(self.out, "WaveActiveBitXor(")?
2958 }
2959 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2960 write!(self.out, "WavePrefixSum(")?
2961 }
2962 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2963 write!(self.out, "WavePrefixProduct(")?
2964 }
2965 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2966 self.write_expr(module, argument, func_ctx)?;
2967 write!(self.out, " + WavePrefixSum(")?;
2968 }
2969 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2970 self.write_expr(module, argument, func_ctx)?;
2971 write!(self.out, " * WavePrefixProduct(")?;
2972 }
2973 _ => unimplemented!(),
2974 }
2975 self.write_expr(module, argument, func_ctx)?;
2976 writeln!(self.out, ");")?;
2977 }
2978 Statement::SubgroupGather {
2979 mode,
2980 argument,
2981 result,
2982 } => {
2983 write!(self.out, "{level}")?;
2984 write!(self.out, "const ")?;
2985 let name = Baked(result).to_string();
2986 match func_ctx.info[result].ty {
2987 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2988 proc::TypeResolution::Value(ref value) => {
2989 self.write_value_type(module, value)?
2990 }
2991 };
2992 write!(self.out, " {name} = ")?;
2993 self.named_expressions.insert(result, name);
2994 match mode {
2995 crate::GatherMode::BroadcastFirst => {
2996 write!(self.out, "WaveReadLaneFirst(")?;
2997 self.write_expr(module, argument, func_ctx)?;
2998 }
2999 crate::GatherMode::QuadBroadcast(index) => {
3000 write!(self.out, "QuadReadLaneAt(")?;
3001 self.write_expr(module, argument, func_ctx)?;
3002 write!(self.out, ", ")?;
3003 self.write_expr(module, index, func_ctx)?;
3004 }
3005 crate::GatherMode::QuadSwap(direction) => {
3006 match direction {
3007 crate::Direction::X => {
3008 write!(self.out, "QuadReadAcrossX(")?;
3009 }
3010 crate::Direction::Y => {
3011 write!(self.out, "QuadReadAcrossY(")?;
3012 }
3013 crate::Direction::Diagonal => {
3014 write!(self.out, "QuadReadAcrossDiagonal(")?;
3015 }
3016 }
3017 self.write_expr(module, argument, func_ctx)?;
3018 }
3019 _ => {
3020 write!(self.out, "WaveReadLaneAt(")?;
3021 self.write_expr(module, argument, func_ctx)?;
3022 write!(self.out, ", ")?;
3023 match mode {
3024 crate::GatherMode::BroadcastFirst => unreachable!(),
3025 crate::GatherMode::Broadcast(index)
3026 | crate::GatherMode::Shuffle(index) => {
3027 self.write_expr(module, index, func_ctx)?;
3028 }
3029 crate::GatherMode::ShuffleDown(index) => {
3030 write!(self.out, "WaveGetLaneIndex() + ")?;
3031 self.write_expr(module, index, func_ctx)?;
3032 }
3033 crate::GatherMode::ShuffleUp(index) => {
3034 write!(self.out, "WaveGetLaneIndex() - ")?;
3035 self.write_expr(module, index, func_ctx)?;
3036 }
3037 crate::GatherMode::ShuffleXor(index) => {
3038 write!(self.out, "WaveGetLaneIndex() ^ ")?;
3039 self.write_expr(module, index, func_ctx)?;
3040 }
3041 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3042 crate::GatherMode::QuadSwap(_) => unreachable!(),
3043 }
3044 }
3045 }
3046 writeln!(self.out, ");")?;
3047 }
3048 Statement::CooperativeStore { .. } => unimplemented!(),
3049 Statement::RayPipelineFunction(_) => unreachable!(),
3050 }
3051
3052 Ok(())
3053 }
3054
3055 fn write_const_expression(
3056 &mut self,
3057 module: &Module,
3058 expr: Handle<crate::Expression>,
3059 arena: &crate::Arena<crate::Expression>,
3060 ) -> BackendResult {
3061 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3062 writer.write_const_expression(module, expr, arena)
3063 })
3064 }
3065
3066 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3067 match literal {
3068 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3069 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3070 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3071 crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3072 crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3073 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3074 crate::Literal::I32(value) if value == i32::MIN => {
3080 write!(self.out, "int({} - 1)", value + 1)?
3081 }
3082 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3086 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3087 crate::Literal::I64(value) if value == i64::MIN => {
3089 write!(self.out, "({}L - 1L)", value + 1)?;
3090 }
3091 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3092 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3093 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3094 return Err(Error::Custom(
3095 "Abstract types should not appear in IR presented to backends".into(),
3096 ));
3097 }
3098 }
3099 Ok(())
3100 }
3101
3102 fn write_possibly_const_expression<E>(
3103 &mut self,
3104 module: &Module,
3105 expr: Handle<crate::Expression>,
3106 expressions: &crate::Arena<crate::Expression>,
3107 write_expression: E,
3108 ) -> BackendResult
3109 where
3110 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3111 {
3112 use crate::Expression;
3113
3114 match expressions[expr] {
3115 Expression::Literal(literal) => {
3116 self.write_literal(literal)?;
3117 }
3118 Expression::Constant(handle) => {
3119 let constant = &module.constants[handle];
3120 if constant.name.is_some() {
3121 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3122 } else {
3123 self.write_const_expression(module, constant.init, &module.global_expressions)?;
3124 }
3125 }
3126 Expression::ZeroValue(ty) => {
3127 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3128 write!(self.out, "()")?;
3129 }
3130 Expression::Compose { ty, ref components } => {
3131 match module.types[ty].inner {
3132 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3133 self.write_wrapped_constructor_function_name(
3134 module,
3135 WrappedConstructor { ty },
3136 )?;
3137 }
3138 _ => {
3139 self.write_type(module, ty)?;
3140 }
3141 };
3142 write!(self.out, "(")?;
3143 for (index, component) in components.iter().enumerate() {
3144 if index != 0 {
3145 write!(self.out, ", ")?;
3146 }
3147 write_expression(self, *component)?;
3148 }
3149 write!(self.out, ")")?;
3150 }
3151 Expression::Splat { size, value } => {
3152 let number_of_components = match size {
3156 crate::VectorSize::Bi => "xx",
3157 crate::VectorSize::Tri => "xxx",
3158 crate::VectorSize::Quad => "xxxx",
3159 };
3160 write!(self.out, "(")?;
3161 write_expression(self, value)?;
3162 write!(self.out, ").{number_of_components}")?
3163 }
3164 _ => {
3165 return Err(Error::Override);
3166 }
3167 }
3168
3169 Ok(())
3170 }
3171
3172 pub(super) fn write_expr(
3177 &mut self,
3178 module: &Module,
3179 expr: Handle<crate::Expression>,
3180 func_ctx: &back::FunctionCtx<'_>,
3181 ) -> BackendResult {
3182 use crate::Expression;
3183
3184 let ff_input = if self.options.special_constants_binding.is_some() {
3186 func_ctx.is_fixed_function_input(expr, module)
3187 } else {
3188 None
3189 };
3190 let closing_bracket = match ff_input {
3191 Some(crate::BuiltIn::VertexIndex) => {
3192 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3193 ")"
3194 }
3195 Some(crate::BuiltIn::InstanceIndex) => {
3196 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3197 ")"
3198 }
3199 Some(crate::BuiltIn::NumWorkGroups) => {
3200 write!(
3204 self.out,
3205 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3206 )?;
3207 return Ok(());
3208 }
3209 _ => "",
3210 };
3211
3212 if let Some(name) = self.named_expressions.get(&expr) {
3213 write!(self.out, "{name}{closing_bracket}")?;
3214 return Ok(());
3215 }
3216
3217 let expression = &func_ctx.expressions[expr];
3218
3219 match *expression {
3220 Expression::Literal(_)
3221 | Expression::Constant(_)
3222 | Expression::ZeroValue(_)
3223 | Expression::Compose { .. }
3224 | Expression::Splat { .. } => {
3225 self.write_possibly_const_expression(
3226 module,
3227 expr,
3228 func_ctx.expressions,
3229 |writer, expr| writer.write_expr(module, expr, func_ctx),
3230 )?;
3231 }
3232 Expression::Override(_) => return Err(Error::Override),
3233 Expression::Binary {
3240 op:
3241 op @ crate::BinaryOperator::Add
3242 | op @ crate::BinaryOperator::Subtract
3243 | op @ crate::BinaryOperator::Multiply,
3244 left,
3245 right,
3246 } if matches!(
3247 func_ctx.resolve_type(expr, &module.types).scalar(),
3248 Some(Scalar::I32)
3249 ) =>
3250 {
3251 write!(self.out, "asint(asuint(",)?;
3252 self.write_expr(module, left, func_ctx)?;
3253 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3254 self.write_expr(module, right, func_ctx)?;
3255 write!(self.out, "))")?;
3256 }
3257 Expression::Binary {
3260 op: crate::BinaryOperator::Multiply,
3261 left,
3262 right,
3263 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3264 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3265 {
3266 write!(self.out, "mul(")?;
3268 self.write_expr(module, right, func_ctx)?;
3269 write!(self.out, ", ")?;
3270 self.write_expr(module, left, func_ctx)?;
3271 write!(self.out, ")")?;
3272 }
3273
3274 Expression::Binary {
3286 op: crate::BinaryOperator::Divide,
3287 left,
3288 right,
3289 } if matches!(
3290 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3291 Some(ScalarKind::Sint | ScalarKind::Uint)
3292 ) =>
3293 {
3294 write!(self.out, "{DIV_FUNCTION}(")?;
3295 self.write_expr(module, left, func_ctx)?;
3296 write!(self.out, ", ")?;
3297 self.write_expr(module, right, func_ctx)?;
3298 write!(self.out, ")")?;
3299 }
3300
3301 Expression::Binary {
3302 op: crate::BinaryOperator::Modulo,
3303 left,
3304 right,
3305 } if matches!(
3306 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3307 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3308 ) =>
3309 {
3310 write!(self.out, "{MOD_FUNCTION}(")?;
3311 self.write_expr(module, left, func_ctx)?;
3312 write!(self.out, ", ")?;
3313 self.write_expr(module, right, func_ctx)?;
3314 write!(self.out, ")")?;
3315 }
3316
3317 Expression::Binary { op, left, right } => {
3318 write!(self.out, "(")?;
3319 self.write_expr(module, left, func_ctx)?;
3320 write!(self.out, " {} ", back::binary_operation_str(op))?;
3321 self.write_expr(module, right, func_ctx)?;
3322 write!(self.out, ")")?;
3323 }
3324 Expression::Access { base, index } => {
3325 if let Some(crate::AddressSpace::Storage { .. }) =
3326 func_ctx.resolve_type(expr, &module.types).pointer_space()
3327 {
3328 } else {
3330 if let Some(MatrixType {
3337 columns,
3338 rows: crate::VectorSize::Bi,
3339 width: 4,
3340 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3341 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3342 {
3343 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3344 self.write_expr(module, base, func_ctx)?;
3345 write!(self.out, ", ")?;
3346 self.write_expr(module, index, func_ctx)?;
3347 write!(self.out, ")")?;
3348 return Ok(());
3349 }
3350
3351 let resolved = func_ctx.resolve_type(base, &module.types);
3352
3353 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3354 TypeInner::BindingArray { .. } => {
3355 let uniformity = &func_ctx.info[index].uniformity;
3356
3357 (true, uniformity.non_uniform_result.is_some())
3358 }
3359 _ => (false, false),
3360 };
3361
3362 self.write_expr(module, base, func_ctx)?;
3363
3364 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3365 module, func_ctx, base, resolved,
3366 );
3367
3368 if let Some(ref info) = array_sampler_info {
3369 write!(self.out, "{}[", info.sampler_heap_name)?;
3370 } else {
3371 write!(self.out, "[")?;
3372 }
3373
3374 let needs_bound_check = self.options.restrict_indexing
3375 && !indexing_binding_array
3376 && match resolved.pointer_space() {
3377 Some(
3378 crate::AddressSpace::Function
3379 | crate::AddressSpace::Private
3380 | crate::AddressSpace::WorkGroup
3381 | crate::AddressSpace::Immediate
3382 | crate::AddressSpace::TaskPayload
3383 | crate::AddressSpace::RayPayload
3384 | crate::AddressSpace::IncomingRayPayload,
3385 )
3386 | None => true,
3387 Some(crate::AddressSpace::Uniform) => {
3388 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3390 let bind_target = self
3391 .options
3392 .resolve_resource_binding(
3393 module.global_variables[var_handle]
3394 .binding
3395 .as_ref()
3396 .unwrap(),
3397 )
3398 .unwrap();
3399 bind_target.restrict_indexing
3400 }
3401 Some(
3402 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3403 ) => unreachable!(),
3404 };
3405 let restriction_needed = if needs_bound_check {
3407 index::access_needs_check(
3408 base,
3409 index::GuardedIndex::Expression(index),
3410 module,
3411 func_ctx.expressions,
3412 func_ctx.info,
3413 )
3414 } else {
3415 None
3416 };
3417 if let Some(limit) = restriction_needed {
3418 write!(self.out, "min(uint(")?;
3419 self.write_expr(module, index, func_ctx)?;
3420 write!(self.out, "), ")?;
3421 match limit {
3422 index::IndexableLength::Known(limit) => {
3423 write!(self.out, "{}u", limit - 1)?;
3424 }
3425 index::IndexableLength::Dynamic => unreachable!(),
3426 }
3427 write!(self.out, ")")?;
3428 } else {
3429 if non_uniform_qualifier {
3430 write!(self.out, "NonUniformResourceIndex(")?;
3431 }
3432 if let Some(ref info) = array_sampler_info {
3433 write!(
3434 self.out,
3435 "{}[{} + ",
3436 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3437 )?;
3438 }
3439 self.write_expr(module, index, func_ctx)?;
3440 if array_sampler_info.is_some() {
3441 write!(self.out, "]")?;
3442 }
3443 if non_uniform_qualifier {
3444 write!(self.out, ")")?;
3445 }
3446 }
3447
3448 write!(self.out, "]")?;
3449 }
3450 }
3451 Expression::AccessIndex { base, index } => {
3452 if let Some(crate::AddressSpace::Storage { .. }) =
3453 func_ctx.resolve_type(expr, &module.types).pointer_space()
3454 {
3455 } else {
3457 if let Some(MatrixType {
3461 rows: crate::VectorSize::Bi,
3462 width: 4,
3463 ..
3464 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3465 .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3466 {
3467 self.write_expr(module, base, func_ctx)?;
3468 write!(self.out, "._{index}")?;
3469 return Ok(());
3470 }
3471
3472 let base_ty_res = &func_ctx.info[base].ty;
3473 let mut resolved = base_ty_res.inner_with(&module.types);
3474 let base_ty_handle = match *resolved {
3475 TypeInner::Pointer { base, .. } => {
3476 resolved = &module.types[base].inner;
3477 Some(base)
3478 }
3479 _ => base_ty_res.handle(),
3480 };
3481
3482 if let TypeInner::Struct { ref members, .. } = *resolved {
3488 let member = &members[index as usize];
3489
3490 match module.types[member.ty].inner {
3491 TypeInner::Matrix {
3492 rows: crate::VectorSize::Bi,
3493 ..
3494 } if member.binding.is_none() => {
3495 let ty = base_ty_handle.unwrap();
3496 self.write_wrapped_struct_matrix_get_function_name(
3497 WrappedStructMatrixAccess { ty, index },
3498 )?;
3499 write!(self.out, "(")?;
3500 self.write_expr(module, base, func_ctx)?;
3501 write!(self.out, ")")?;
3502 return Ok(());
3503 }
3504 _ => {}
3505 }
3506 }
3507
3508 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3509 module, func_ctx, base, resolved,
3510 );
3511
3512 if let Some(ref info) = array_sampler_info {
3513 write!(
3514 self.out,
3515 "{}[{}",
3516 info.sampler_heap_name, info.sampler_index_buffer_name
3517 )?;
3518 }
3519
3520 self.write_expr(module, base, func_ctx)?;
3521
3522 match *resolved {
3523 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3529 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3531 }
3532 TypeInner::Matrix { .. }
3533 | TypeInner::Array { .. }
3534 | TypeInner::BindingArray { .. } => {
3535 if let Some(ref info) = array_sampler_info {
3536 write!(
3537 self.out,
3538 "[{} + {index}]",
3539 info.binding_array_base_index_name
3540 )?;
3541 } else {
3542 write!(self.out, "[{index}]")?;
3543 }
3544 }
3545 TypeInner::Struct { .. } => {
3546 let ty = base_ty_handle.unwrap();
3549
3550 write!(
3551 self.out,
3552 ".{}",
3553 &self.names[&NameKey::StructMember(ty, index)]
3554 )?
3555 }
3556 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3557 }
3558
3559 if array_sampler_info.is_some() {
3560 write!(self.out, "]")?;
3561 }
3562 }
3563 }
3564 Expression::FunctionArgument(pos) => {
3565 let ty = func_ctx.resolve_type(expr, &module.types);
3566
3567 if let TypeInner::Image {
3573 class: crate::ImageClass::External,
3574 ..
3575 } = *ty
3576 {
3577 let plane_names = [0, 1, 2].map(|i| {
3578 &self.names[&func_ctx
3579 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3580 });
3581 let params_name = &self.names[&func_ctx
3582 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3583 write!(
3584 self.out,
3585 "{}, {}, {}, {}",
3586 plane_names[0], plane_names[1], plane_names[2], params_name
3587 )?;
3588 } else {
3589 let key = func_ctx.argument_key(pos);
3590 let name = &self.names[&key];
3591 write!(self.out, "{name}")?;
3592 }
3593 }
3594 Expression::ImageSample {
3595 coordinate,
3596 image,
3597 sampler,
3598 clamp_to_edge: true,
3599 gather: None,
3600 array_index: None,
3601 offset: None,
3602 level: crate::SampleLevel::Zero,
3603 depth_ref: None,
3604 } => {
3605 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3606 self.write_expr(module, image, func_ctx)?;
3607 write!(self.out, ", ")?;
3608 self.write_expr(module, sampler, func_ctx)?;
3609 write!(self.out, ", ")?;
3610 self.write_expr(module, coordinate, func_ctx)?;
3611 write!(self.out, ")")?;
3612 }
3613 Expression::ImageSample {
3614 image,
3615 sampler,
3616 gather,
3617 coordinate,
3618 array_index,
3619 offset,
3620 level,
3621 depth_ref,
3622 clamp_to_edge,
3623 } => {
3624 if clamp_to_edge {
3625 return Err(Error::Custom(
3626 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3627 ));
3628 }
3629
3630 use crate::SampleLevel as Sl;
3631 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3632
3633 let (base_str, component_str) = match gather {
3634 Some(component) => ("Gather", COMPONENTS[component as usize]),
3635 None => ("Sample", ""),
3636 };
3637 let cmp_str = match depth_ref {
3638 Some(_) => "Cmp",
3639 None => "",
3640 };
3641 let level_str = match level {
3642 Sl::Zero if gather.is_none() => "LevelZero",
3643 Sl::Auto | Sl::Zero => "",
3644 Sl::Exact(_) => "Level",
3645 Sl::Bias(_) => "Bias",
3646 Sl::Gradient { .. } => "Grad",
3647 };
3648
3649 self.write_expr(module, image, func_ctx)?;
3650 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3651 self.write_expr(module, sampler, func_ctx)?;
3652 write!(self.out, ", ")?;
3653 self.write_texture_coordinates(
3654 "float",
3655 coordinate,
3656 array_index,
3657 None,
3658 module,
3659 func_ctx,
3660 )?;
3661
3662 if let Some(depth_ref) = depth_ref {
3663 write!(self.out, ", ")?;
3664 self.write_expr(module, depth_ref, func_ctx)?;
3665 }
3666
3667 match level {
3668 Sl::Auto | Sl::Zero => {}
3669 Sl::Exact(expr) => {
3670 write!(self.out, ", ")?;
3671 self.write_expr(module, expr, func_ctx)?;
3672 }
3673 Sl::Bias(expr) => {
3674 write!(self.out, ", ")?;
3675 self.write_expr(module, expr, func_ctx)?;
3676 }
3677 Sl::Gradient { x, y } => {
3678 write!(self.out, ", ")?;
3679 self.write_expr(module, x, func_ctx)?;
3680 write!(self.out, ", ")?;
3681 self.write_expr(module, y, func_ctx)?;
3682 }
3683 }
3684
3685 if let Some(offset) = offset {
3686 write!(self.out, ", ")?;
3687 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3689 write!(self.out, ")")?;
3690 }
3691
3692 write!(self.out, ")")?;
3693 }
3694 Expression::ImageQuery { image, query } => {
3695 if let TypeInner::Image {
3697 dim,
3698 arrayed,
3699 class,
3700 } = *func_ctx.resolve_type(image, &module.types)
3701 {
3702 let wrapped_image_query = WrappedImageQuery {
3703 dim,
3704 arrayed,
3705 class,
3706 query: query.into(),
3707 };
3708
3709 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3710 write!(self.out, "(")?;
3711 self.write_expr(module, image, func_ctx)?;
3713 if let crate::ImageQuery::Size { level: Some(level) } = query {
3714 write!(self.out, ", ")?;
3715 self.write_expr(module, level, func_ctx)?;
3716 }
3717 write!(self.out, ")")?;
3718 }
3719 }
3720 Expression::ImageLoad {
3721 image,
3722 coordinate,
3723 array_index,
3724 sample,
3725 level,
3726 } => self.write_image_load(
3727 &module,
3728 expr,
3729 func_ctx,
3730 image,
3731 coordinate,
3732 array_index,
3733 sample,
3734 level,
3735 )?,
3736 Expression::GlobalVariable(handle) => {
3737 let global_variable = &module.global_variables[handle];
3738 let ty = &module.types[global_variable.ty].inner;
3739
3740 let is_binding_array_of_samplers = match *ty {
3745 TypeInner::BindingArray { base, .. } => {
3746 let base_ty = &module.types[base].inner;
3747 matches!(*base_ty, TypeInner::Sampler { .. })
3748 }
3749 _ => false,
3750 };
3751
3752 let is_storage_space =
3753 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3754
3755 if let TypeInner::Image {
3763 class: crate::ImageClass::External,
3764 ..
3765 } = *ty
3766 {
3767 let plane_names = [0, 1, 2].map(|i| {
3768 &self.names[&NameKey::ExternalTextureGlobalVariable(
3769 handle,
3770 ExternalTextureNameKey::Plane(i),
3771 )]
3772 });
3773 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3774 handle,
3775 ExternalTextureNameKey::Params,
3776 )];
3777 write!(
3778 self.out,
3779 "{}, {}, {}, {}",
3780 plane_names[0], plane_names[1], plane_names[2], params_name
3781 )?;
3782 } else if !is_binding_array_of_samplers && !is_storage_space {
3783 let name = &self.names[&NameKey::GlobalVariable(handle)];
3784 write!(self.out, "{name}")?;
3785 }
3786 }
3787 Expression::LocalVariable(handle) => {
3788 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3789 }
3790 Expression::Load { pointer } => {
3791 match func_ctx
3792 .resolve_type(pointer, &module.types)
3793 .pointer_space()
3794 {
3795 Some(crate::AddressSpace::Storage { .. }) => {
3796 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3797 let result_ty = func_ctx.info[expr].ty.clone();
3798 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3799 }
3800 _ => {
3801 let mut close_paren = false;
3802
3803 if let Some(MatrixType {
3808 rows: crate::VectorSize::Bi,
3809 width: 4,
3810 ..
3811 }) = get_inner_matrix_of_struct_array_member(
3812 module, pointer, func_ctx, false,
3813 )
3814 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3815 {
3816 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3817 let ptr_tr = resolved.pointer_base_type();
3818 if let Some(ptr_ty) =
3819 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3820 {
3821 resolved = ptr_ty;
3822 }
3823
3824 write!(self.out, "((")?;
3825 if let TypeInner::Array { base, size, .. } = *resolved {
3826 self.write_type(module, base)?;
3827 self.write_array_size(module, base, size)?;
3828 } else {
3829 self.write_value_type(module, resolved)?;
3830 }
3831 write!(self.out, ")")?;
3832 close_paren = true;
3833 }
3834
3835 self.write_expr(module, pointer, func_ctx)?;
3836
3837 if close_paren {
3838 write!(self.out, ")")?;
3839 }
3840 }
3841 }
3842 }
3843 Expression::Unary { op, expr } => {
3844 let op_str = match op {
3846 crate::UnaryOperator::Negate => {
3847 match func_ctx.resolve_type(expr, &module.types).scalar() {
3848 Some(Scalar::I32) => NEG_FUNCTION,
3849 _ => "-",
3850 }
3851 }
3852 crate::UnaryOperator::LogicalNot => "!",
3853 crate::UnaryOperator::BitwiseNot => "~",
3854 };
3855 write!(self.out, "{op_str}(")?;
3856 self.write_expr(module, expr, func_ctx)?;
3857 write!(self.out, ")")?;
3858 }
3859 Expression::As {
3860 expr,
3861 kind,
3862 convert,
3863 } => {
3864 let inner = func_ctx.resolve_type(expr, &module.types);
3865 if inner.scalar_kind() == Some(ScalarKind::Float)
3866 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3867 && convert.is_some()
3868 && matches!(convert, Some(4) | Some(8))
3869 {
3870 let fun_name = match (kind, convert) {
3874 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3875 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3876 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3877 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3878 _ => unreachable!(),
3879 };
3880 write!(self.out, "{fun_name}(")?;
3881 self.write_expr(module, expr, func_ctx)?;
3882 write!(self.out, ")")?;
3883 } else {
3884 let close_paren = match convert {
3885 Some(dst_width) => {
3886 let scalar = Scalar {
3887 kind,
3888 width: dst_width,
3889 };
3890 match *inner {
3891 TypeInner::Vector { size, .. } => {
3892 write!(
3893 self.out,
3894 "{}{}(",
3895 scalar.to_hlsl_str()?,
3896 common::vector_size_str(size)
3897 )?;
3898 }
3899 TypeInner::Scalar(_) => {
3900 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3901 }
3902 TypeInner::Matrix { columns, rows, .. } => {
3903 write!(
3904 self.out,
3905 "{}{}x{}(",
3906 scalar.to_hlsl_str()?,
3907 common::vector_size_str(columns),
3908 common::vector_size_str(rows)
3909 )?;
3910 }
3911 _ => {
3912 return Err(Error::Unimplemented(format!(
3913 "write_expr expression::as {inner:?}"
3914 )));
3915 }
3916 };
3917 true
3918 }
3919 None => {
3920 if inner.scalar_width() == Some(8) {
3921 false
3922 } else if inner.scalar_width() == Some(2) {
3923 let dst_scalar = Scalar { kind, width: 2 };
3926 match *inner {
3927 TypeInner::Vector { size, .. } => {
3928 write!(
3929 self.out,
3930 "{}{}(",
3931 dst_scalar.to_hlsl_str()?,
3932 common::vector_size_str(size)
3933 )?;
3934 }
3935 _ => {
3936 write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
3937 }
3938 };
3939 true
3940 } else {
3941 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3942 true
3943 }
3944 }
3945 };
3946 self.write_expr(module, expr, func_ctx)?;
3947 if close_paren {
3948 write!(self.out, ")")?;
3949 }
3950 }
3951 }
3952 Expression::Math {
3953 fun,
3954 arg,
3955 arg1,
3956 arg2,
3957 arg3,
3958 } => {
3959 use crate::MathFunction as Mf;
3960
3961 enum Function {
3962 Asincosh { is_sin: bool },
3963 Atanh,
3964 Pack2x16float,
3965 Pack2x16snorm,
3966 Pack2x16unorm,
3967 Pack4x8snorm,
3968 Pack4x8unorm,
3969 Pack4xI8,
3970 Pack4xU8,
3971 Pack4xI8Clamp,
3972 Pack4xU8Clamp,
3973 Unpack2x16float,
3974 Unpack2x16snorm,
3975 Unpack2x16unorm,
3976 Unpack4x8snorm,
3977 Unpack4x8unorm,
3978 Unpack4xI8,
3979 Unpack4xU8,
3980 Dot4I8Packed,
3981 Dot4U8Packed,
3982 QuantizeToF16,
3983 Regular(&'static str),
3984 MissingIntOverload(&'static str),
3985 MissingIntReturnType(&'static str),
3986 CountTrailingZeros,
3987 CountLeadingZeros,
3988 }
3989
3990 let fun = match fun {
3991 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3993 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3994 _ => Function::Regular("abs"),
3995 },
3996 Mf::Min => Function::Regular("min"),
3997 Mf::Max => Function::Regular("max"),
3998 Mf::Clamp => Function::Regular("clamp"),
3999 Mf::Saturate => Function::Regular("saturate"),
4000 Mf::Cos => Function::Regular("cos"),
4002 Mf::Cosh => Function::Regular("cosh"),
4003 Mf::Sin => Function::Regular("sin"),
4004 Mf::Sinh => Function::Regular("sinh"),
4005 Mf::Tan => Function::Regular("tan"),
4006 Mf::Tanh => Function::Regular("tanh"),
4007 Mf::Acos => Function::Regular("acos"),
4008 Mf::Asin => Function::Regular("asin"),
4009 Mf::Atan => Function::Regular("atan"),
4010 Mf::Atan2 => Function::Regular("atan2"),
4011 Mf::Asinh => Function::Asincosh { is_sin: true },
4012 Mf::Acosh => Function::Asincosh { is_sin: false },
4013 Mf::Atanh => Function::Atanh,
4014 Mf::Radians => Function::Regular("radians"),
4015 Mf::Degrees => Function::Regular("degrees"),
4016 Mf::Ceil => Function::Regular("ceil"),
4018 Mf::Floor => Function::Regular("floor"),
4019 Mf::Round => Function::Regular("round"),
4020 Mf::Fract => Function::Regular("frac"),
4021 Mf::Trunc => Function::Regular("trunc"),
4022 Mf::Modf => Function::Regular(MODF_FUNCTION),
4023 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4024 Mf::Ldexp => Function::Regular("ldexp"),
4025 Mf::Exp => Function::Regular("exp"),
4027 Mf::Exp2 => Function::Regular("exp2"),
4028 Mf::Log => Function::Regular("log"),
4029 Mf::Log2 => Function::Regular("log2"),
4030 Mf::Pow => Function::Regular("pow"),
4031 Mf::Dot => Function::Regular("dot"),
4033 Mf::Dot4I8Packed => Function::Dot4I8Packed,
4034 Mf::Dot4U8Packed => Function::Dot4U8Packed,
4035 Mf::Cross => Function::Regular("cross"),
4037 Mf::Distance => Function::Regular("distance"),
4038 Mf::Length => Function::Regular("length"),
4039 Mf::Normalize => Function::Regular("normalize"),
4040 Mf::FaceForward => Function::Regular("faceforward"),
4041 Mf::Reflect => Function::Regular("reflect"),
4042 Mf::Refract => Function::Regular("refract"),
4043 Mf::Sign => Function::Regular("sign"),
4045 Mf::Fma => Function::Regular("mad"),
4046 Mf::Mix => Function::Regular("lerp"),
4047 Mf::Step => Function::Regular("step"),
4048 Mf::SmoothStep => Function::Regular("smoothstep"),
4049 Mf::Sqrt => Function::Regular("sqrt"),
4050 Mf::InverseSqrt => Function::Regular("rsqrt"),
4051 Mf::Transpose => Function::Regular("transpose"),
4053 Mf::Determinant => Function::Regular("determinant"),
4054 Mf::QuantizeToF16 => Function::QuantizeToF16,
4055 Mf::CountTrailingZeros => Function::CountTrailingZeros,
4057 Mf::CountLeadingZeros => Function::CountLeadingZeros,
4058 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4059 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4060 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4061 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4062 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4063 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4064 Mf::Pack2x16float => Function::Pack2x16float,
4066 Mf::Pack2x16snorm => Function::Pack2x16snorm,
4067 Mf::Pack2x16unorm => Function::Pack2x16unorm,
4068 Mf::Pack4x8snorm => Function::Pack4x8snorm,
4069 Mf::Pack4x8unorm => Function::Pack4x8unorm,
4070 Mf::Pack4xI8 => Function::Pack4xI8,
4071 Mf::Pack4xU8 => Function::Pack4xU8,
4072 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4073 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4074 Mf::Unpack2x16float => Function::Unpack2x16float,
4076 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4077 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4078 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4079 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4080 Mf::Unpack4xI8 => Function::Unpack4xI8,
4081 Mf::Unpack4xU8 => Function::Unpack4xU8,
4082 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4083 };
4084
4085 match fun {
4086 Function::Asincosh { is_sin } => {
4087 write!(self.out, "log(")?;
4088 self.write_expr(module, arg, func_ctx)?;
4089 write!(self.out, " + sqrt(")?;
4090 self.write_expr(module, arg, func_ctx)?;
4091 write!(self.out, " * ")?;
4092 self.write_expr(module, arg, func_ctx)?;
4093 match is_sin {
4094 true => write!(self.out, " + 1.0))")?,
4095 false => write!(self.out, " - 1.0))")?,
4096 }
4097 }
4098 Function::Atanh => {
4099 write!(self.out, "0.5 * log((1.0 + ")?;
4100 self.write_expr(module, arg, func_ctx)?;
4101 write!(self.out, ") / (1.0 - ")?;
4102 self.write_expr(module, arg, func_ctx)?;
4103 write!(self.out, "))")?;
4104 }
4105 Function::Pack2x16float => {
4106 write!(self.out, "(f32tof16(")?;
4107 self.write_expr(module, arg, func_ctx)?;
4108 write!(self.out, "[0]) | f32tof16(")?;
4109 self.write_expr(module, arg, func_ctx)?;
4110 write!(self.out, "[1]) << 16)")?;
4111 }
4112 Function::Pack2x16snorm => {
4113 let scale = 32767;
4114
4115 write!(self.out, "uint((int(round(clamp(")?;
4116 self.write_expr(module, arg, func_ctx)?;
4117 write!(
4118 self.out,
4119 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4120 )?;
4121 self.write_expr(module, arg, func_ctx)?;
4122 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4123 }
4124 Function::Pack2x16unorm => {
4125 let scale = 65535;
4126
4127 write!(self.out, "(uint(round(clamp(")?;
4128 self.write_expr(module, arg, func_ctx)?;
4129 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4130 self.write_expr(module, arg, func_ctx)?;
4131 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4132 }
4133 Function::Pack4x8snorm => {
4134 let scale = 127;
4135
4136 write!(self.out, "uint((int(round(clamp(")?;
4137 self.write_expr(module, arg, func_ctx)?;
4138 write!(
4139 self.out,
4140 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4141 )?;
4142 self.write_expr(module, arg, func_ctx)?;
4143 write!(
4144 self.out,
4145 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4146 )?;
4147 self.write_expr(module, arg, func_ctx)?;
4148 write!(
4149 self.out,
4150 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4151 )?;
4152 self.write_expr(module, arg, func_ctx)?;
4153 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4154 }
4155 Function::Pack4x8unorm => {
4156 let scale = 255;
4157
4158 write!(self.out, "(uint(round(clamp(")?;
4159 self.write_expr(module, arg, func_ctx)?;
4160 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4161 self.write_expr(module, arg, func_ctx)?;
4162 write!(
4163 self.out,
4164 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4165 )?;
4166 self.write_expr(module, arg, func_ctx)?;
4167 write!(
4168 self.out,
4169 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4170 )?;
4171 self.write_expr(module, arg, func_ctx)?;
4172 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4173 }
4174 fun @ (Function::Pack4xI8
4175 | Function::Pack4xU8
4176 | Function::Pack4xI8Clamp
4177 | Function::Pack4xU8Clamp) => {
4178 let was_signed =
4179 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4180 let clamp_bounds = match fun {
4181 Function::Pack4xI8Clamp => Some(("-128", "127")),
4182 Function::Pack4xU8Clamp => Some(("0", "255")),
4183 _ => None,
4184 };
4185 if was_signed {
4186 write!(self.out, "uint(")?;
4187 }
4188 let write_arg = |this: &mut Self| -> BackendResult {
4189 if let Some((min, max)) = clamp_bounds {
4190 write!(this.out, "clamp(")?;
4191 this.write_expr(module, arg, func_ctx)?;
4192 write!(this.out, ", {min}, {max})")?;
4193 } else {
4194 this.write_expr(module, arg, func_ctx)?;
4195 }
4196 Ok(())
4197 };
4198 write!(self.out, "(")?;
4199 write_arg(self)?;
4200 write!(self.out, "[0] & 0xFF) | ((")?;
4201 write_arg(self)?;
4202 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4203 write_arg(self)?;
4204 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4205 write_arg(self)?;
4206 write!(self.out, "[3] & 0xFF) << 24)")?;
4207 if was_signed {
4208 write!(self.out, ")")?;
4209 }
4210 }
4211
4212 Function::Unpack2x16float => {
4213 write!(self.out, "float2(f16tof32(")?;
4214 self.write_expr(module, arg, func_ctx)?;
4215 write!(self.out, "), f16tof32((")?;
4216 self.write_expr(module, arg, func_ctx)?;
4217 write!(self.out, ") >> 16))")?;
4218 }
4219 Function::Unpack2x16snorm => {
4220 let scale = 32767;
4221
4222 write!(self.out, "(float2(int2(")?;
4223 self.write_expr(module, arg, func_ctx)?;
4224 write!(self.out, " << 16, ")?;
4225 self.write_expr(module, arg, func_ctx)?;
4226 write!(self.out, ") >> 16) / {scale}.0)")?;
4227 }
4228 Function::Unpack2x16unorm => {
4229 let scale = 65535;
4230
4231 write!(self.out, "(float2(")?;
4232 self.write_expr(module, arg, func_ctx)?;
4233 write!(self.out, " & 0xFFFF, ")?;
4234 self.write_expr(module, arg, func_ctx)?;
4235 write!(self.out, " >> 16) / {scale}.0)")?;
4236 }
4237 Function::Unpack4x8snorm => {
4238 let scale = 127;
4239
4240 write!(self.out, "(float4(int4(")?;
4241 self.write_expr(module, arg, func_ctx)?;
4242 write!(self.out, " << 24, ")?;
4243 self.write_expr(module, arg, func_ctx)?;
4244 write!(self.out, " << 16, ")?;
4245 self.write_expr(module, arg, func_ctx)?;
4246 write!(self.out, " << 8, ")?;
4247 self.write_expr(module, arg, func_ctx)?;
4248 write!(self.out, ") >> 24) / {scale}.0)")?;
4249 }
4250 Function::Unpack4x8unorm => {
4251 let scale = 255;
4252
4253 write!(self.out, "(float4(")?;
4254 self.write_expr(module, arg, func_ctx)?;
4255 write!(self.out, " & 0xFF, ")?;
4256 self.write_expr(module, arg, func_ctx)?;
4257 write!(self.out, " >> 8 & 0xFF, ")?;
4258 self.write_expr(module, arg, func_ctx)?;
4259 write!(self.out, " >> 16 & 0xFF, ")?;
4260 self.write_expr(module, arg, func_ctx)?;
4261 write!(self.out, " >> 24) / {scale}.0)")?;
4262 }
4263 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4264 write!(self.out, "(")?;
4265 if matches!(fun, Function::Unpack4xU8) {
4266 write!(self.out, "u")?;
4267 }
4268 write!(self.out, "int4(")?;
4269 self.write_expr(module, arg, func_ctx)?;
4270 write!(self.out, ", ")?;
4271 self.write_expr(module, arg, func_ctx)?;
4272 write!(self.out, " >> 8, ")?;
4273 self.write_expr(module, arg, func_ctx)?;
4274 write!(self.out, " >> 16, ")?;
4275 self.write_expr(module, arg, func_ctx)?;
4276 write!(self.out, " >> 24) << 24 >> 24)")?;
4277 }
4278 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4279 let arg1 = arg1.unwrap();
4280
4281 if self.options.shader_model >= ShaderModel::V6_4 {
4282 let function_name = match fun {
4284 Function::Dot4I8Packed => "dot4add_i8packed",
4285 Function::Dot4U8Packed => "dot4add_u8packed",
4286 _ => unreachable!(),
4287 };
4288 write!(self.out, "{function_name}(")?;
4289 self.write_expr(module, arg, func_ctx)?;
4290 write!(self.out, ", ")?;
4291 self.write_expr(module, arg1, func_ctx)?;
4292 write!(self.out, ", 0)")?;
4293 } else {
4294 write!(self.out, "dot(")?;
4296
4297 if matches!(fun, Function::Dot4U8Packed) {
4298 write!(self.out, "u")?;
4299 }
4300 write!(self.out, "int4(")?;
4301 self.write_expr(module, arg, func_ctx)?;
4302 write!(self.out, ", ")?;
4303 self.write_expr(module, arg, func_ctx)?;
4304 write!(self.out, " >> 8, ")?;
4305 self.write_expr(module, arg, func_ctx)?;
4306 write!(self.out, " >> 16, ")?;
4307 self.write_expr(module, arg, func_ctx)?;
4308 write!(self.out, " >> 24) << 24 >> 24, ")?;
4309
4310 if matches!(fun, Function::Dot4U8Packed) {
4311 write!(self.out, "u")?;
4312 }
4313 write!(self.out, "int4(")?;
4314 self.write_expr(module, arg1, func_ctx)?;
4315 write!(self.out, ", ")?;
4316 self.write_expr(module, arg1, func_ctx)?;
4317 write!(self.out, " >> 8, ")?;
4318 self.write_expr(module, arg1, func_ctx)?;
4319 write!(self.out, " >> 16, ")?;
4320 self.write_expr(module, arg1, func_ctx)?;
4321 write!(self.out, " >> 24) << 24 >> 24)")?;
4322 }
4323 }
4324 Function::QuantizeToF16 => {
4325 write!(self.out, "f16tof32(f32tof16(")?;
4326 self.write_expr(module, arg, func_ctx)?;
4327 write!(self.out, "))")?;
4328 }
4329 Function::Regular(fun_name) => {
4330 write!(self.out, "{fun_name}(")?;
4331 self.write_expr(module, arg, func_ctx)?;
4332 if let Some(arg) = arg1 {
4333 write!(self.out, ", ")?;
4334 self.write_expr(module, arg, func_ctx)?;
4335 }
4336 if let Some(arg) = arg2 {
4337 write!(self.out, ", ")?;
4338 self.write_expr(module, arg, func_ctx)?;
4339 }
4340 if let Some(arg) = arg3 {
4341 write!(self.out, ", ")?;
4342 self.write_expr(module, arg, func_ctx)?;
4343 }
4344 write!(self.out, ")")?
4345 }
4346 Function::MissingIntOverload(fun_name) => {
4349 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4350 if let Some(Scalar::I32) = scalar_kind {
4351 write!(self.out, "asint({fun_name}(asuint(")?;
4352 self.write_expr(module, arg, func_ctx)?;
4353 write!(self.out, ")))")?;
4354 } else {
4355 write!(self.out, "{fun_name}(")?;
4356 self.write_expr(module, arg, func_ctx)?;
4357 write!(self.out, ")")?;
4358 }
4359 }
4360 Function::MissingIntReturnType(fun_name) => {
4363 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4364 if let Some(Scalar::I32) = scalar_kind {
4365 write!(self.out, "asint({fun_name}(")?;
4366 self.write_expr(module, arg, func_ctx)?;
4367 write!(self.out, "))")?;
4368 } else {
4369 write!(self.out, "{fun_name}(")?;
4370 self.write_expr(module, arg, func_ctx)?;
4371 write!(self.out, ")")?;
4372 }
4373 }
4374 Function::CountTrailingZeros => {
4375 match *func_ctx.resolve_type(arg, &module.types) {
4376 TypeInner::Vector { size, scalar } => {
4377 let s = match size {
4378 crate::VectorSize::Bi => ".xx",
4379 crate::VectorSize::Tri => ".xxx",
4380 crate::VectorSize::Quad => ".xxxx",
4381 };
4382
4383 let scalar_width_bits = scalar.width * 8;
4384
4385 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4386 write!(
4387 self.out,
4388 "min(({scalar_width_bits}u){s}, firstbitlow("
4389 )?;
4390 self.write_expr(module, arg, func_ctx)?;
4391 write!(self.out, "))")?;
4392 } else {
4393 write!(
4395 self.out,
4396 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4397 )?;
4398 self.write_expr(module, arg, func_ctx)?;
4399 write!(self.out, ")))")?;
4400 }
4401 }
4402 TypeInner::Scalar(scalar) => {
4403 let scalar_width_bits = scalar.width * 8;
4404
4405 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4406 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4407 self.write_expr(module, arg, func_ctx)?;
4408 write!(self.out, "))")?;
4409 } else {
4410 write!(
4412 self.out,
4413 "asint(min({scalar_width_bits}u, firstbitlow("
4414 )?;
4415 self.write_expr(module, arg, func_ctx)?;
4416 write!(self.out, ")))")?;
4417 }
4418 }
4419 _ => unreachable!(),
4420 }
4421
4422 return Ok(());
4423 }
4424 Function::CountLeadingZeros => {
4425 match *func_ctx.resolve_type(arg, &module.types) {
4426 TypeInner::Vector { size, scalar } => {
4427 let s = match size {
4428 crate::VectorSize::Bi => ".xx",
4429 crate::VectorSize::Tri => ".xxx",
4430 crate::VectorSize::Quad => ".xxxx",
4431 };
4432
4433 let constant = scalar.width * 8 - 1;
4435
4436 if scalar.kind == ScalarKind::Uint {
4437 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4438 self.write_expr(module, arg, func_ctx)?;
4439 write!(self.out, "))")?;
4440 } else {
4441 let conversion_func = match scalar.width {
4442 4 => "asint",
4443 _ => "",
4444 };
4445 write!(self.out, "(")?;
4446 self.write_expr(module, arg, func_ctx)?;
4447 write!(
4448 self.out,
4449 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4450 )?;
4451 self.write_expr(module, arg, func_ctx)?;
4452 write!(self.out, ")))")?;
4453 }
4454 }
4455 TypeInner::Scalar(scalar) => {
4456 let constant = scalar.width * 8 - 1;
4458
4459 if let ScalarKind::Uint = scalar.kind {
4460 write!(self.out, "({constant}u - firstbithigh(")?;
4461 self.write_expr(module, arg, func_ctx)?;
4462 write!(self.out, "))")?;
4463 } else {
4464 let conversion_func = match scalar.width {
4465 4 => "asint",
4466 _ => "",
4467 };
4468 write!(self.out, "(")?;
4469 self.write_expr(module, arg, func_ctx)?;
4470 write!(
4471 self.out,
4472 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4473 )?;
4474 self.write_expr(module, arg, func_ctx)?;
4475 write!(self.out, ")))")?;
4476 }
4477 }
4478 _ => unreachable!(),
4479 }
4480
4481 return Ok(());
4482 }
4483 }
4484 }
4485 Expression::Swizzle {
4486 size,
4487 vector,
4488 pattern,
4489 } => {
4490 self.write_expr(module, vector, func_ctx)?;
4491 write!(self.out, ".")?;
4492 for &sc in pattern[..size as usize].iter() {
4493 self.out.write_char(back::COMPONENTS[sc as usize])?;
4494 }
4495 }
4496 Expression::ArrayLength(expr) => {
4497 let var_handle = match func_ctx.expressions[expr] {
4498 Expression::AccessIndex { base, index: _ } => {
4499 match func_ctx.expressions[base] {
4500 Expression::GlobalVariable(handle) => handle,
4501 _ => unreachable!(),
4502 }
4503 }
4504 Expression::GlobalVariable(handle) => handle,
4505 _ => unreachable!(),
4506 };
4507
4508 let var = &module.global_variables[var_handle];
4509 let (offset, stride) = match module.types[var.ty].inner {
4510 TypeInner::Array { stride, .. } => (0, stride),
4511 TypeInner::Struct { ref members, .. } => {
4512 let last = members.last().unwrap();
4513 let stride = match module.types[last.ty].inner {
4514 TypeInner::Array { stride, .. } => stride,
4515 _ => unreachable!(),
4516 };
4517 (last.offset, stride)
4518 }
4519 _ => unreachable!(),
4520 };
4521
4522 let storage_access = match var.space {
4523 crate::AddressSpace::Storage { access } => access,
4524 _ => crate::StorageAccess::default(),
4525 };
4526 let wrapped_array_length = WrappedArrayLength {
4527 writable: storage_access.contains(crate::StorageAccess::STORE),
4528 };
4529
4530 write!(self.out, "((")?;
4531 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4532 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4533 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4534 }
4535 Expression::Derivative { axis, ctrl, expr } => {
4536 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4537 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4538 let tail = match ctrl {
4539 Ctrl::Coarse => "coarse",
4540 Ctrl::Fine => "fine",
4541 Ctrl::None => unreachable!(),
4542 };
4543 write!(self.out, "abs(ddx_{tail}(")?;
4544 self.write_expr(module, expr, func_ctx)?;
4545 write!(self.out, ")) + abs(ddy_{tail}(")?;
4546 self.write_expr(module, expr, func_ctx)?;
4547 write!(self.out, "))")?
4548 } else {
4549 let fun_str = match (axis, ctrl) {
4550 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4551 (Axis::X, Ctrl::Fine) => "ddx_fine",
4552 (Axis::X, Ctrl::None) => "ddx",
4553 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4554 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4555 (Axis::Y, Ctrl::None) => "ddy",
4556 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4557 (Axis::Width, Ctrl::None) => "fwidth",
4558 };
4559 write!(self.out, "{fun_str}(")?;
4560 self.write_expr(module, expr, func_ctx)?;
4561 write!(self.out, ")")?
4562 }
4563 }
4564 Expression::Relational { fun, argument } => {
4565 use crate::RelationalFunction as Rf;
4566
4567 let fun_str = match fun {
4568 Rf::All => "all",
4569 Rf::Any => "any",
4570 Rf::IsNan => "isnan",
4571 Rf::IsInf => "isinf",
4572 };
4573 write!(self.out, "{fun_str}(")?;
4574 self.write_expr(module, argument, func_ctx)?;
4575 write!(self.out, ")")?
4576 }
4577 Expression::Select {
4578 condition,
4579 accept,
4580 reject,
4581 } => {
4582 write!(self.out, "(")?;
4583 self.write_expr(module, condition, func_ctx)?;
4584 write!(self.out, " ? ")?;
4585 self.write_expr(module, accept, func_ctx)?;
4586 write!(self.out, " : ")?;
4587 self.write_expr(module, reject, func_ctx)?;
4588 write!(self.out, ")")?
4589 }
4590 Expression::RayQueryGetIntersection { query, committed } => {
4591 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4593 unreachable!()
4594 };
4595
4596 let tracker_expr_name = format!(
4597 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4598 self.names[&func_ctx.name_key(query_var)]
4599 );
4600
4601 if committed {
4602 write!(self.out, "GetCommittedIntersection(")?;
4603 self.write_expr(module, query, func_ctx)?;
4604 write!(self.out, ", {tracker_expr_name})")?;
4605 } else {
4606 write!(self.out, "GetCandidateIntersection(")?;
4607 self.write_expr(module, query, func_ctx)?;
4608 write!(self.out, ", {tracker_expr_name})")?;
4609 }
4610 }
4611 Expression::RayQueryVertexPositions { .. }
4613 | Expression::CooperativeLoad { .. }
4614 | Expression::CooperativeMultiplyAdd { .. } => {
4615 unreachable!()
4616 }
4617 Expression::CallResult(_)
4619 | Expression::AtomicResult { .. }
4620 | Expression::WorkGroupUniformLoadResult { .. }
4621 | Expression::RayQueryProceedResult
4622 | Expression::SubgroupBallotResult
4623 | Expression::SubgroupOperationResult { .. } => {}
4624 }
4625
4626 if !closing_bracket.is_empty() {
4627 write!(self.out, "{closing_bracket}")?;
4628 }
4629 Ok(())
4630 }
4631
4632 #[allow(clippy::too_many_arguments)]
4633 fn write_image_load(
4634 &mut self,
4635 module: &&Module,
4636 expr: Handle<crate::Expression>,
4637 func_ctx: &back::FunctionCtx,
4638 image: Handle<crate::Expression>,
4639 coordinate: Handle<crate::Expression>,
4640 array_index: Option<Handle<crate::Expression>>,
4641 sample: Option<Handle<crate::Expression>>,
4642 level: Option<Handle<crate::Expression>>,
4643 ) -> Result<(), Error> {
4644 let mut wrapping_type = None;
4645 match *func_ctx.resolve_type(image, &module.types) {
4646 TypeInner::Image {
4647 class: crate::ImageClass::External,
4648 ..
4649 } => {
4650 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4651 self.write_expr(module, image, func_ctx)?;
4652 write!(self.out, ", ")?;
4653 self.write_expr(module, coordinate, func_ctx)?;
4654 write!(self.out, ")")?;
4655 return Ok(());
4656 }
4657 TypeInner::Image {
4658 class: crate::ImageClass::Storage { format, .. },
4659 ..
4660 } => {
4661 if format.single_component() {
4662 wrapping_type = Some(Scalar::from(format));
4663 }
4664 }
4665 _ => {}
4666 }
4667 if let Some(scalar) = wrapping_type {
4668 write!(
4669 self.out,
4670 "{}{}(",
4671 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4672 scalar.to_hlsl_str()?
4673 )?;
4674 }
4675 self.write_expr(module, image, func_ctx)?;
4677 write!(self.out, ".Load(")?;
4678
4679 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4680
4681 if let Some(sample) = sample {
4682 write!(self.out, ", ")?;
4683 self.write_expr(module, sample, func_ctx)?;
4684 }
4685
4686 write!(self.out, ")")?;
4688
4689 if wrapping_type.is_some() {
4690 write!(self.out, ")")?;
4691 }
4692
4693 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4695 write!(self.out, ".x")?;
4696 }
4697 Ok(())
4698 }
4699
4700 fn sampler_binding_array_info_from_expression(
4703 &mut self,
4704 module: &Module,
4705 func_ctx: &back::FunctionCtx<'_>,
4706 base: Handle<crate::Expression>,
4707 resolved: &TypeInner,
4708 ) -> Option<BindingArraySamplerInfo> {
4709 if let TypeInner::BindingArray {
4710 base: base_ty_handle,
4711 ..
4712 } = *resolved
4713 {
4714 let base_ty = &module.types[base_ty_handle].inner;
4715 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4716 let base = &func_ctx.expressions[base];
4717
4718 if let crate::Expression::GlobalVariable(handle) = *base {
4719 let variable = &module.global_variables[handle];
4720
4721 let sampler_heap_name = match comparison {
4722 true => COMPARISON_SAMPLER_HEAP_VAR,
4723 false => SAMPLER_HEAP_VAR,
4724 };
4725
4726 return Some(BindingArraySamplerInfo {
4727 sampler_heap_name,
4728 sampler_index_buffer_name: self
4729 .wrapped
4730 .sampler_index_buffers
4731 .get(&super::SamplerIndexBufferKey {
4732 group: variable.binding.unwrap().group,
4733 })
4734 .unwrap()
4735 .clone(),
4736 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4737 .clone(),
4738 });
4739 }
4740 }
4741 }
4742
4743 None
4744 }
4745
4746 fn write_named_expr(
4747 &mut self,
4748 module: &Module,
4749 handle: Handle<crate::Expression>,
4750 name: String,
4751 expr: Handle<crate::Expression>,
4754 func_ctx: &back::FunctionCtx,
4755 ) -> BackendResult {
4756 if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4757 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4758 if ty_inner.is_atomic_pointer(&module.types) {
4759 let pointer_space = ty_inner.pointer_space().unwrap();
4760 self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4761 write!(self.out, " {name}; ")?;
4762 match pointer_space {
4763 crate::AddressSpace::WorkGroup => {
4764 write!(self.out, "InterlockedOr(")?;
4765 self.write_expr(module, pointer, func_ctx)?;
4766 }
4767 crate::AddressSpace::Storage { .. } => {
4768 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4769 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4770 write!(self.out, "{var_name}.InterlockedOr(")?;
4771 let chain = mem::take(&mut self.temp_access_chain);
4772 self.write_storage_address(module, &chain, func_ctx)?;
4773 self.temp_access_chain = chain;
4774 }
4775 _ => unreachable!(),
4776 }
4777 writeln!(self.out, ", 0, {name});")?;
4778 self.named_expressions.insert(expr, name);
4779 return Ok(());
4780 }
4781 }
4782 match func_ctx.info[expr].ty {
4783 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4784 TypeInner::Struct { .. } => {
4785 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4786 write!(self.out, "{ty_name}")?;
4787 }
4788 _ => {
4789 self.write_type(module, ty_handle)?;
4790 }
4791 },
4792 proc::TypeResolution::Value(ref inner) => {
4793 self.write_value_type(module, inner)?;
4794 }
4795 }
4796
4797 let resolved = func_ctx.resolve_type(expr, &module.types);
4798
4799 write!(self.out, " {name}")?;
4800 if let TypeInner::Array { base, size, .. } = *resolved {
4802 self.write_array_size(module, base, size)?;
4803 }
4804 write!(self.out, " = ")?;
4805 self.write_expr(module, handle, func_ctx)?;
4806 writeln!(self.out, ";")?;
4807 self.named_expressions.insert(expr, name);
4808
4809 Ok(())
4810 }
4811
4812 pub(super) fn write_default_init(
4814 &mut self,
4815 module: &Module,
4816 ty: Handle<crate::Type>,
4817 ) -> BackendResult {
4818 write!(self.out, "(")?;
4819 self.write_type(module, ty)?;
4820 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4821 self.write_array_size(module, base, size)?;
4822 }
4823 write!(self.out, ")0")?;
4824 Ok(())
4825 }
4826
4827 pub(super) fn write_control_barrier(
4828 &mut self,
4829 barrier: crate::Barrier,
4830 level: back::Level,
4831 ) -> BackendResult {
4832 if barrier.contains(crate::Barrier::STORAGE) {
4833 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4834 }
4835 if barrier.contains(crate::Barrier::WORK_GROUP) {
4836 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4837 }
4838 if barrier.contains(crate::Barrier::SUB_GROUP) {
4839 }
4841 if barrier.contains(crate::Barrier::TEXTURE) {
4842 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4843 }
4844 Ok(())
4845 }
4846
4847 fn write_memory_barrier(
4848 &mut self,
4849 barrier: crate::Barrier,
4850 level: back::Level,
4851 ) -> BackendResult {
4852 if barrier.contains(crate::Barrier::STORAGE) {
4853 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4854 }
4855 if barrier.contains(crate::Barrier::WORK_GROUP) {
4856 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4857 }
4858 if barrier.contains(crate::Barrier::SUB_GROUP) {
4859 }
4861 if barrier.contains(crate::Barrier::TEXTURE) {
4862 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4863 }
4864 Ok(())
4865 }
4866
4867 fn emit_hlsl_atomic_tail(
4869 &mut self,
4870 module: &Module,
4871 func_ctx: &back::FunctionCtx<'_>,
4872 fun: &crate::AtomicFunction,
4873 compare_expr: Option<Handle<crate::Expression>>,
4874 value: Handle<crate::Expression>,
4875 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4876 ) -> BackendResult {
4877 if let Some(cmp) = compare_expr {
4878 write!(self.out, ", ")?;
4879 self.write_expr(module, cmp, func_ctx)?;
4880 }
4881 write!(self.out, ", ")?;
4882 if let crate::AtomicFunction::Subtract = *fun {
4883 write!(self.out, "-")?;
4885 }
4886 self.write_expr(module, value, func_ctx)?;
4887 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4888 write!(self.out, ", ")?;
4889 if compare_expr.is_some() {
4890 write!(self.out, "{res_name}.old_value")?;
4891 } else {
4892 write!(self.out, "{res_name}")?;
4893 }
4894 }
4895 writeln!(self.out, ");")?;
4896 Ok(())
4897 }
4898}
4899
4900pub(super) struct MatrixType {
4901 pub(super) columns: crate::VectorSize,
4902 pub(super) rows: crate::VectorSize,
4903 pub(super) width: crate::Bytes,
4904}
4905
4906pub(super) fn get_inner_matrix_data(
4907 module: &Module,
4908 handle: Handle<crate::Type>,
4909) -> Option<MatrixType> {
4910 match module.types[handle].inner {
4911 TypeInner::Matrix {
4912 columns,
4913 rows,
4914 scalar,
4915 } => Some(MatrixType {
4916 columns,
4917 rows,
4918 width: scalar.width,
4919 }),
4920 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4921 _ => None,
4922 }
4923}
4924
4925fn find_matrix_in_access_chain(
4929 module: &Module,
4930 base: Handle<crate::Expression>,
4931 func_ctx: &back::FunctionCtx<'_>,
4932) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4933 let mut current_base = base;
4934 let mut vector = None;
4935 let mut scalar = None;
4936 loop {
4937 let resolved_tr = func_ctx
4938 .resolve_type(current_base, &module.types)
4939 .pointer_base_type();
4940 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4941
4942 match *resolved {
4943 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4944 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4945 _ => return None,
4946 }
4947
4948 let index;
4949 (current_base, index) = match func_ctx.expressions[current_base] {
4950 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4951 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4952 _ => return None,
4953 };
4954
4955 match *resolved {
4956 TypeInner::Scalar(_) => scalar = Some(index),
4957 TypeInner::Vector { .. } => vector = Some(index),
4958 _ => unreachable!(),
4959 }
4960 }
4961}
4962
4963pub(super) fn get_inner_matrix_of_struct_array_member(
4968 module: &Module,
4969 base: Handle<crate::Expression>,
4970 func_ctx: &back::FunctionCtx<'_>,
4971 direct: bool,
4972) -> Option<MatrixType> {
4973 let mut mat_data = None;
4974 let mut array_base = None;
4975
4976 let mut current_base = base;
4977 loop {
4978 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4979 if let TypeInner::Pointer { base, .. } = *resolved {
4980 resolved = &module.types[base].inner;
4981 };
4982
4983 match *resolved {
4984 TypeInner::Matrix {
4985 columns,
4986 rows,
4987 scalar,
4988 } => {
4989 mat_data = Some(MatrixType {
4990 columns,
4991 rows,
4992 width: scalar.width,
4993 })
4994 }
4995 TypeInner::Array { base, .. } => {
4996 array_base = Some(base);
4997 }
4998 TypeInner::Struct { .. } => {
4999 if let Some(array_base) = array_base {
5000 if direct {
5001 return mat_data;
5002 } else {
5003 return get_inner_matrix_data(module, array_base);
5004 }
5005 }
5006
5007 break;
5008 }
5009 _ => break,
5010 }
5011
5012 current_base = match func_ctx.expressions[current_base] {
5013 crate::Expression::Access { base, .. } => base,
5014 crate::Expression::AccessIndex { base, .. } => base,
5015 _ => break,
5016 };
5017 }
5018 None
5019}
5020
5021fn get_global_uniform_matrix(
5024 module: &Module,
5025 base: Handle<crate::Expression>,
5026 func_ctx: &back::FunctionCtx<'_>,
5027) -> Option<MatrixType> {
5028 let base_tr = func_ctx
5029 .resolve_type(base, &module.types)
5030 .pointer_base_type();
5031 let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
5032 match (&func_ctx.expressions[base], base_ty) {
5033 (
5034 &crate::Expression::GlobalVariable(handle),
5035 Some(&TypeInner::Matrix {
5036 columns,
5037 rows,
5038 scalar,
5039 }),
5040 ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
5041 Some(MatrixType {
5042 columns,
5043 rows,
5044 width: scalar.width,
5045 })
5046 }
5047 _ => None,
5048 }
5049}
5050
5051fn get_inner_matrix_of_global_uniform(
5056 module: &Module,
5057 base: Handle<crate::Expression>,
5058 func_ctx: &back::FunctionCtx<'_>,
5059) -> Option<MatrixType> {
5060 let mut mat_data = None;
5061 let mut array_base = None;
5062
5063 let mut current_base = base;
5064 loop {
5065 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5066 if let TypeInner::Pointer { base, .. } = *resolved {
5067 resolved = &module.types[base].inner;
5068 };
5069
5070 match *resolved {
5071 TypeInner::Matrix {
5072 columns,
5073 rows,
5074 scalar,
5075 } => {
5076 mat_data = Some(MatrixType {
5077 columns,
5078 rows,
5079 width: scalar.width,
5080 })
5081 }
5082 TypeInner::Array { base, .. } => {
5083 array_base = Some(base);
5084 }
5085 _ => break,
5086 }
5087
5088 current_base = match func_ctx.expressions[current_base] {
5089 crate::Expression::Access { base, .. } => base,
5090 crate::Expression::AccessIndex { base, .. } => base,
5091 crate::Expression::GlobalVariable(handle)
5092 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5093 {
5094 return mat_data.or_else(|| {
5095 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5096 })
5097 }
5098 _ => break,
5099 };
5100 }
5101 None
5102}