1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Write as _},
8 mem,
9};
10
11use super::{
12 help,
13 help::{
14 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15 WrappedZeroValue,
16 },
17 storage::StoreValue,
18 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21 back::{self, get_entry_points, Baked},
22 common,
23 proc::{self, index, ExternalTextureNameKey, NameKey},
24 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Prefix of the HLSL semantic written for `@location` bindings; the
/// location number is appended directly after it.
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special-constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable that holds the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the first `int` member of the special-constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the second `int` member of the special-constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the trailing `uint` member of the special-constants struct.
const SPECIAL_OTHER: &str = "other";
33
// Identifiers that are emitted verbatim into the generated HLSL.
//
// NOTE(review): the code that declares and calls these helpers lives outside
// this part of the file, so the groupings below follow the identifiers
// themselves rather than their use sites.

// Helper functions named after math built-ins.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";

// Sampler heap variables; `Writer::write_global_sampler` indexes into these
// when it defines a sampler global.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";

pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";

// Helper functions named after arithmetic operations.
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";

// Helper functions named after float-to-integer conversions.
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";

// Helper functions named after texture operations.
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";

/// Prefix for ray-query tracker variable names.
/// NOTE(review): the suffix appended to it is chosen elsewhere — confirm.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Common prefix shared by the `naga_`-style identifiers above.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
56enum Index {
57 Expression(Handle<crate::Expression>),
58 Static(u32),
59}
60
61pub(super) struct EpStructMember {
62 pub(super) name: String,
63 pub(super) ty: Handle<crate::Type>,
64 pub(super) binding: Option<crate::Binding>,
67 pub(super) index: u32,
68}
69
70pub(super) struct EntryPointBinding {
73 pub(super) arg_name: String,
76 pub(super) ty_name: String,
78 pub(super) members: Vec<EpStructMember>,
80 pub(super) local_invocation_index_name: Option<String>,
81}
82
83pub(super) struct EntryPointInterface {
84 pub(crate) input: Option<EntryPointBinding>,
89 pub(crate) output: Option<EntryPointBinding>,
93 pub(crate) mesh_vertices: Option<EntryPointBinding>,
94 pub(crate) mesh_primitives: Option<EntryPointBinding>,
95 pub(crate) mesh_indices: Option<EntryPointBinding>,
96}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100 Location(u32),
101 BuiltIn(crate::BuiltIn),
102 Other,
103}
104
105impl InterfaceKey {
106 const fn new(binding: Option<&crate::Binding>) -> Self {
107 match binding {
108 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110 None => Self::Other,
111 }
112 }
113}
114
115#[derive(Copy, Clone, PartialEq)]
116pub(super) enum Io {
117 Input,
118 Output,
119 MeshVertices,
120 MeshPrimitives,
121}
122
123pub(super) struct NestedEntryPointArgs {
125 pub user_args: Vec<String>,
127 pub task_payload: Option<String>,
128 pub local_invocation_index: String,
129}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133 return false;
134 };
135 matches!(
136 builtin,
137 crate::BuiltIn::SubgroupSize
138 | crate::BuiltIn::SubgroupInvocationId
139 | crate::BuiltIn::NumSubgroups
140 | crate::BuiltIn::SubgroupId
141 )
142}
143
/// Names needed to access a sampler that lives in a binding array.
///
/// NOTE(review): the code that fills and reads this struct is outside this
/// part of the file; the field descriptions follow the field names.
struct BindingArraySamplerInfo {
    /// Name of the sampler heap variable.
    sampler_heap_name: &'static str,
    /// Name of the sampler index buffer.
    sampler_index_buffer_name: String,
    /// Name of the binding array base index.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156 Self {
157 out,
158 names: crate::FastHashMap::default(),
159 namer: proc::Namer::default(),
160 options,
161 pipeline_options,
162 entry_point_io: crate::FastHashMap::default(),
163 named_expressions: crate::NamedExpressions::default(),
164 wrapped: super::Wrapped::default(),
165 written_committed_intersection: false,
166 written_candidate_intersection: false,
167 continue_ctx: back::continue_forward::ContinueCtx::default(),
168 temp_access_chain: Vec::new(),
169 need_bake_expressions: Default::default(),
170 function_task_payload_var: Default::default(),
171 }
172 }
173
174 fn reset(&mut self, module: &Module) {
175 self.names.clear();
176 self.namer.reset(
177 module,
178 &super::keywords::RESERVED_SET,
179 proc::KeywordSet::empty(),
180 &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181 super::keywords::RESERVED_PREFIXES,
182 &mut self.names,
183 );
184 self.entry_point_io.clear();
185 self.named_expressions.clear();
186 self.wrapped.clear();
187 self.written_committed_intersection = false;
188 self.written_candidate_intersection = false;
189 self.continue_ctx.clear();
190 self.need_bake_expressions.clear();
191 self.function_task_payload_var.clear();
192 }
193
194 fn gen_force_bounded_loop_statements(
202 &mut self,
203 level: back::Level,
204 ) -> Option<(String, String)> {
205 if !self.options.force_loop_bounding {
206 return None;
207 }
208
209 let loop_bound_name = self.namer.call("loop_bound");
210 let max = u32::MAX;
211 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214 let level = level.next();
215 let break_and_inc = format!(
216 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218 );
219
220 Some((decl, break_and_inc))
221 }
222
223 fn update_expressions_to_bake(
228 &mut self,
229 module: &Module,
230 func: &crate::Function,
231 info: &valid::FunctionInfo,
232 ) {
233 use crate::Expression;
234 self.need_bake_expressions.clear();
235 for (exp_handle, expr) in func.expressions.iter() {
236 let expr_info = &info[exp_handle];
237 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238 if min_ref_count <= expr_info.ref_count {
239 self.need_bake_expressions.insert(exp_handle);
240 }
241 if let Expression::Load { pointer } = *expr {
242 if info[pointer]
243 .ty
244 .inner_with(&module.types)
245 .is_atomic_pointer(&module.types)
246 {
247 self.need_bake_expressions.insert(exp_handle);
248 }
249 }
250
251 if let Expression::Math { fun, arg, arg1, .. } = *expr {
252 match fun {
253 crate::MathFunction::Asinh
254 | crate::MathFunction::Acosh
255 | crate::MathFunction::Atanh
256 | crate::MathFunction::Unpack2x16float
257 | crate::MathFunction::Unpack2x16snorm
258 | crate::MathFunction::Unpack2x16unorm
259 | crate::MathFunction::Unpack4x8snorm
260 | crate::MathFunction::Unpack4x8unorm
261 | crate::MathFunction::Unpack4xI8
262 | crate::MathFunction::Unpack4xU8
263 | crate::MathFunction::Pack2x16float
264 | crate::MathFunction::Pack2x16snorm
265 | crate::MathFunction::Pack2x16unorm
266 | crate::MathFunction::Pack4x8snorm
267 | crate::MathFunction::Pack4x8unorm
268 | crate::MathFunction::Pack4xI8
269 | crate::MathFunction::Pack4xU8
270 | crate::MathFunction::Pack4xI8Clamp
271 | crate::MathFunction::Pack4xU8Clamp => {
272 self.need_bake_expressions.insert(arg);
273 }
274 crate::MathFunction::CountLeadingZeros => {
275 let inner = info[exp_handle].ty.inner_with(&module.types);
276 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277 self.need_bake_expressions.insert(arg);
278 }
279 }
280 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281 self.need_bake_expressions.insert(arg);
282 self.need_bake_expressions.insert(arg1.unwrap());
283 }
284 _ => {}
285 }
286 }
287
288 if let Expression::Derivative { axis, ctrl, expr } = *expr {
289 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291 self.need_bake_expressions.insert(expr);
292 }
293 }
294
295 if let Expression::GlobalVariable(_) = *expr {
296 let inner = info[exp_handle].ty.inner_with(&module.types);
297
298 if let TypeInner::Sampler { .. } = *inner {
299 self.need_bake_expressions.insert(exp_handle);
300 }
301 }
302 }
303 for statement in func.body.iter() {
304 match *statement {
305 crate::Statement::SubgroupCollectiveOperation {
306 op: _,
307 collective_op: crate::CollectiveOperation::InclusiveScan,
308 argument,
309 result: _,
310 } => {
311 self.need_bake_expressions.insert(argument);
312 }
313 crate::Statement::Atomic {
314 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315 ..
316 } => {
317 self.need_bake_expressions.insert(cmp);
318 }
319 _ => {}
320 }
321 }
322 }
323
324 pub fn write(
325 &mut self,
326 module: &Module,
327 module_info: &valid::ModuleInfo,
328 fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
329 ) -> Result<super::ReflectionInfo, Error> {
330 self.reset(module);
331
332 if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
333 return Err(Error::ShaderModelTooLow(
334 "mesh shaders".to_string(),
335 ShaderModel::V6_5,
336 ));
337 }
338
339 if let Some(ref bt) = self.options.special_constants_binding {
341 writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
342 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
343 writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
344 writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
345 writeln!(self.out, "}};")?;
346 write!(
347 self.out,
348 "ConstantBuffer<{}> {}: register(b{}",
349 SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
350 )?;
351 if bt.space != 0 {
352 write!(self.out, ", space{}", bt.space)?;
353 }
354 writeln!(self.out, ");")?;
355
356 writeln!(self.out)?;
358 }
359
360 for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
361 writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
362 for i in 0..bt.size {
363 writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
364 }
365 writeln!(self.out, "}};")?;
366 writeln!(
367 self.out,
368 "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
369 group, group, bt.register, bt.space
370 )?;
371
372 writeln!(self.out)?;
374 }
375
376 let ep_results = module
378 .entry_points
379 .iter()
380 .map(|ep| (ep.stage, ep.function.result.clone()))
381 .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
382
383 self.write_all_mat_cx2_typedefs_and_functions(module)?;
384
385 for (handle, ty) in module.types.iter() {
387 if let TypeInner::Struct { ref members, span } = ty.inner {
388 if module.types[members.last().unwrap().ty]
389 .inner
390 .is_dynamically_sized(&module.types)
391 {
392 continue;
395 }
396
397 let ep_result = ep_results.iter().find(|e| {
398 if let Some(ref result) = e.1 {
399 result.ty == handle
400 } else {
401 false
402 }
403 });
404
405 self.write_struct(
406 module,
407 handle,
408 members,
409 span,
410 ep_result.map(|r| (r.0, Io::Output)),
411 )?;
412 writeln!(self.out)?;
413 }
414 }
415
416 self.write_special_functions(module)?;
417
418 self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
419 self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
420
421 let mut constants = module
423 .constants
424 .iter()
425 .filter(|&(_, c)| c.name.is_some())
426 .peekable();
427 while let Some((handle, _)) = constants.next() {
428 self.write_global_constant(module, handle)?;
429 if constants.peek().is_none() {
431 writeln!(self.out)?;
432 }
433 }
434
435 for (global, _) in module.global_variables.iter() {
437 self.write_global(module, global)?;
438 }
439
440 if !module.global_variables.is_empty() {
441 writeln!(self.out)?;
443 }
444
445 let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
446 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
447
448 for index in ep_range.clone() {
450 let ep = &module.entry_points[index];
451 let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
452 let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
453 self.entry_point_io.insert(index, ep_io);
454 }
455
456 for (handle, function) in module.functions.iter() {
458 let info = &module_info[handle];
459
460 if !self.options.fake_missing_bindings {
462 if let Some((var_handle, _)) =
463 module
464 .global_variables
465 .iter()
466 .find(|&(var_handle, var)| match var.binding {
467 Some(ref binding) if !info[var_handle].is_empty() => {
468 self.options.resolve_resource_binding(binding).is_err()
469 && self
470 .options
471 .resolve_external_texture_resource_binding(binding)
472 .is_err()
473 }
474 _ => false,
475 })
476 {
477 log::debug!(
478 "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
479 handle,
480 function.name,
481 var_handle
482 );
483 continue;
484 }
485 }
486
487 let ctx = back::FunctionCtx {
488 ty: back::FunctionType::Function(handle),
489 info,
490 expressions: &function.expressions,
491 named_expressions: &function.named_expressions,
492 };
493 let name = self.names[&NameKey::Function(handle)].clone();
494
495 self.write_wrapped_functions(module, &ctx)?;
496
497 self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;
498
499 writeln!(self.out)?;
500 }
501
502 let mut translated_ep_names = Vec::with_capacity(ep_range.len());
503
504 for index in ep_range {
506 let ep = &module.entry_points[index];
507 let info = module_info.get_entry_point(index);
508
509 if !self.options.fake_missing_bindings {
510 let mut ep_error = None;
511 for (var_handle, var) in module.global_variables.iter() {
512 match var.binding {
513 Some(ref binding) if !info[var_handle].is_empty() => {
514 if let Err(err) = self.options.resolve_resource_binding(binding) {
515 if self
516 .options
517 .resolve_external_texture_resource_binding(binding)
518 .is_err()
519 {
520 ep_error = Some(err);
521 break;
522 }
523 }
524 }
525 _ => {}
526 }
527 }
528 if let Some(err) = ep_error {
529 translated_ep_names.push(Err(err));
530 continue;
531 }
532 }
533
534 let ctx = back::FunctionCtx {
535 ty: back::FunctionType::EntryPoint(index as u16),
536 info,
537 expressions: &ep.function.expressions,
538 named_expressions: &ep.function.named_expressions,
539 };
540
541 self.write_wrapped_functions(module, &ctx)?;
542
543 let mut attribute_string = String::new();
546 if ep.stage.compute_like() {
547 let num_threads = ep.workgroup_size;
549 writeln!(
550 attribute_string,
551 "[numthreads({}, {}, {})]",
552 num_threads[0], num_threads[1], num_threads[2]
553 )?;
554 }
555 if let Some(ref info) = ep.mesh_info {
556 let topology_str = match info.topology {
557 crate::MeshOutputTopology::Points => unreachable!(),
558 crate::MeshOutputTopology::Lines => "line",
559 crate::MeshOutputTopology::Triangles => "triangle",
560 };
561 writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
562 }
563
564 let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
565 self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;
566
567 if index < module.entry_points.len() - 1 {
568 writeln!(self.out)?;
569 }
570
571 translated_ep_names.push(Ok(name));
572 }
573
574 Ok(super::ReflectionInfo {
575 entry_point_names: translated_ep_names,
576 })
577 }
578
579 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580 match *binding {
581 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582 write!(self.out, "precise ")?;
583 }
584 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585 write!(self.out, "noperspective ")?;
586 }
587 crate::Binding::Location {
588 interpolation,
589 sampling,
590 ..
591 } => {
592 if let Some(interpolation) = interpolation {
593 if let Some(string) = interpolation.to_hlsl_str() {
594 write!(self.out, "{string} ")?
595 }
596 }
597
598 if let Some(sampling) = sampling {
599 if let Some(string) = sampling.to_hlsl_str() {
600 write!(self.out, "{string} ")?
601 }
602 }
603 }
604 crate::Binding::BuiltIn(_) => {}
605 }
606
607 Ok(())
608 }
609
610 pub(super) fn write_semantic(
613 &mut self,
614 binding: &Option<crate::Binding>,
615 stage: Option<(ShaderStage, Io)>,
616 ) -> BackendResult {
617 let is_per_primitive = match *binding {
618 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619 if builtin == crate::BuiltIn::ViewIndex
620 && self.options.shader_model < ShaderModel::V6_1
621 {
622 return Err(Error::ShaderModelTooLow(
623 "used @builtin(view_index) or SV_ViewID".to_string(),
624 ShaderModel::V6_1,
625 ));
626 }
627 if let Some(builtin_str) = builtin.to_hlsl_str()? {
628 write!(self.out, " : {builtin_str}")?;
629 }
630 false
631 }
632 Some(crate::Binding::Location {
633 blend_src: Some(1),
634 per_primitive,
635 ..
636 }) => {
637 write!(self.out, " : SV_Target1")?;
638 per_primitive
639 }
640 Some(crate::Binding::Location {
641 location,
642 per_primitive,
643 ..
644 }) => {
645 if stage == Some((ShaderStage::Fragment, Io::Output)) {
646 write!(self.out, " : SV_Target{location}")?;
647 } else {
648 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649 }
650 per_primitive
651 }
652 _ => false,
653 };
654 if is_per_primitive {
655 write!(self.out, " : primitive")?;
656 }
657
658 Ok(())
659 }
660
661 pub(super) fn write_interface_struct(
662 &mut self,
663 module: &Module,
664 shader_stage: (ShaderStage, Io),
665 struct_name: String,
666 var_name: Option<&str>,
667 mut members: Vec<EpStructMember>,
668 ) -> Result<EntryPointBinding, Error> {
669 let struct_name = self.namer.call(&struct_name);
670 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675 write!(self.out, "struct {struct_name}")?;
676 writeln!(self.out, " {{")?;
677 let mut local_invocation_index_name = None;
678 let mut subgroup_id_used = false;
679 for m in members.iter() {
680 debug_assert!(m.binding.is_some());
683
684 match m.binding {
685 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686 subgroup_id_used = true;
687 }
688 Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689 local_invocation_index_name = Some(m.name.clone());
690 }
691 _ => (),
692 }
693
694 if is_subgroup_builtin_binding(&m.binding) {
695 continue;
696 }
697 write!(self.out, "{}", back::INDENT)?;
698 if let Some(ref binding) = m.binding {
699 self.write_modifier(binding)?;
700 }
701 self.write_type(module, m.ty)?;
702 write!(self.out, " {}", &m.name)?;
703 self.write_semantic(&m.binding, Some(shader_stage))?;
704 writeln!(self.out, ";")?;
705 }
706 if subgroup_id_used && local_invocation_index_name.is_none() {
707 let name = self.namer.call("local_invocation_index");
708 writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709 local_invocation_index_name = Some(name);
710 }
711 writeln!(self.out, "}};")?;
712 writeln!(self.out)?;
713
714 match shader_stage.1 {
716 Io::Input => {
717 members.sort_by_key(|m| m.index);
719 }
720 Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721 }
723 }
724
725 Ok(EntryPointBinding {
726 arg_name: self
727 .namer
728 .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729 ty_name: struct_name,
730 members,
731 local_invocation_index_name,
732 })
733 }
734
735 fn write_ep_input_struct(
739 &mut self,
740 module: &Module,
741 func: &crate::Function,
742 stage: ShaderStage,
743 entry_point_name: &str,
744 ) -> Result<EntryPointBinding, Error> {
745 let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747 let mut fake_members = Vec::new();
748 for arg in func.arguments.iter() {
749 match module.types[arg.ty].inner {
754 TypeInner::Struct { ref members, .. } => {
755 for member in members.iter() {
756 let name = self.namer.call_or(&member.name, "member");
757 let index = fake_members.len() as u32;
758 fake_members.push(EpStructMember {
759 name,
760 ty: member.ty,
761 binding: member.binding.clone(),
762 index,
763 });
764 }
765 }
766 _ => {
767 let member_name = self.namer.call_or(&arg.name, "member");
768 let index = fake_members.len() as u32;
769 fake_members.push(EpStructMember {
770 name: member_name,
771 ty: arg.ty,
772 binding: arg.binding.clone(),
773 index,
774 });
775 }
776 }
777 }
778
779 self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780 }
781
782 fn write_ep_output_struct(
786 &mut self,
787 module: &Module,
788 result: &crate::FunctionResult,
789 stage: ShaderStage,
790 entry_point_name: &str,
791 frag_ep: Option<&FragmentEntryPoint<'_>>,
792 ) -> Result<EntryPointBinding, Error> {
793 let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795 let empty = [];
796 let members = match module.types[result.ty].inner {
797 TypeInner::Struct { ref members, .. } => members,
798 ref other => {
799 log::error!("Unexpected {other:?} output type without a binding");
800 &empty[..]
801 }
802 };
803
804 let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809 let mut fs_input_locs = Vec::new();
810 for arg in frag_ep.func.arguments.iter() {
811 let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812 Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813 Some(crate::Binding::BuiltIn(_)) | None => {}
814 };
815
816 match frag_ep.module.types[arg.ty].inner {
819 TypeInner::Struct { ref members, .. } => {
820 for member in members.iter() {
821 push_if_location(&member.binding);
822 }
823 }
824 _ => push_if_location(&arg.binding),
825 }
826 }
827 fs_input_locs.sort();
828 Some(fs_input_locs)
829 } else {
830 None
831 };
832
833 let mut fake_members = Vec::new();
834 for (index, member) in members.iter().enumerate() {
835 if let Some(ref fs_input_locs) = fs_input_locs {
836 match member.binding {
837 Some(crate::Binding::Location { location, .. }) => {
838 if fs_input_locs.binary_search(&location).is_err() {
839 continue;
840 }
841 }
842 Some(crate::Binding::BuiltIn(_)) | None => {}
843 }
844 }
845
846 let member_name = self.namer.call_or(&member.name, "member");
847 fake_members.push(EpStructMember {
848 name: member_name,
849 ty: member.ty,
850 binding: member.binding.clone(),
851 index: index as u32,
852 });
853 }
854
855 self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856 }
857
858 fn write_ep_interface(
862 &mut self,
863 module: &Module,
864 ep: &crate::EntryPoint,
865 ep_name: &str,
866 frag_ep: Option<&FragmentEntryPoint<'_>>,
867 ) -> Result<EntryPointInterface, Error> {
868 let func = &ep.function;
869 let stage = ep.stage;
870 Ok(EntryPointInterface {
871 input: if !func.arguments.is_empty()
872 && (stage == ShaderStage::Fragment
873 || func
874 .arguments
875 .iter()
876 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877 {
878 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879 } else {
880 None
881 },
882 output: match func.result {
883 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885 }
886 _ => None,
887 },
888 mesh_vertices: if let Some(ref info) = ep.mesh_info {
889 Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890 } else {
891 None
892 },
893 mesh_primitives: if let Some(ref info) = ep.mesh_info {
894 Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895 } else {
896 None
897 },
898 mesh_indices: if let Some(ref info) = ep.mesh_info {
899 Some(self.write_ep_mesh_output_indices(info.topology)?)
900 } else {
901 None
902 },
903 })
904 }
905
906 fn write_ep_argument_initialization(
907 &mut self,
908 ep: &crate::EntryPoint,
909 ep_input: &EntryPointBinding,
910 fake_member: &EpStructMember,
911 ) -> BackendResult {
912 match fake_member.binding {
913 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914 write!(self.out, "WaveGetLaneCount()")?
915 }
916 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917 write!(self.out, "WaveGetLaneIndex()")?
918 }
919 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920 self.out,
921 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923 )?,
924 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925 write!(
926 self.out,
927 "{}.{} / WaveGetLaneCount()",
928 ep_input.arg_name,
929 ep_input.local_invocation_index_name.as_ref().unwrap()
931 )?;
932 }
933 Some(crate::Binding::Location {
934 interpolation: Some(crate::Interpolation::PerVertex),
935 ..
936 }) => {
937 if self.options.shader_model < ShaderModel::V6_1 {
938 return Err(Error::ShaderModelTooLow(
939 "per_vertex fragment inputs".to_string(),
940 ShaderModel::V6_1,
941 ));
942 }
943 write!(
944 self.out,
945 "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946 ep_input.arg_name,
947 fake_member.name,
948 )?;
949 }
950 _ => {
951 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952 }
953 }
954 Ok(())
955 }
956
957 fn write_ep_arguments_initialization(
959 &mut self,
960 module: &Module,
961 func: &crate::Function,
962 ep_index: u16,
963 ) -> BackendResult {
964 let ep = &module.entry_points[ep_index as usize];
965 let ep_input = match self
966 .entry_point_io
967 .get_mut(&(ep_index as usize))
968 .unwrap()
969 .input
970 .take()
971 {
972 Some(ep_input) => ep_input,
973 None => return Ok(()),
974 };
975 let mut fake_iter = ep_input.members.iter();
976 for (arg_index, arg) in func.arguments.iter().enumerate() {
977 write!(self.out, "{}", back::INDENT)?;
978 self.write_type(module, arg.ty)?;
979 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980 write!(self.out, " {arg_name}")?;
981 match module.types[arg.ty].inner {
982 TypeInner::Array { base, size, .. } => {
983 self.write_array_size(module, base, size)?;
984 write!(self.out, " = ")?;
985 self.write_ep_argument_initialization(
986 ep,
987 &ep_input,
988 fake_iter.next().unwrap(),
989 )?;
990 writeln!(self.out, ";")?;
991 }
992 TypeInner::Struct { ref members, .. } => {
993 write!(self.out, " = {{ ")?;
994 for index in 0..members.len() {
995 if index != 0 {
996 write!(self.out, ", ")?;
997 }
998 self.write_ep_argument_initialization(
999 ep,
1000 &ep_input,
1001 fake_iter.next().unwrap(),
1002 )?;
1003 }
1004 writeln!(self.out, " }};")?;
1005 }
1006 _ => {
1007 write!(self.out, " = ")?;
1008 self.write_ep_argument_initialization(
1009 ep,
1010 &ep_input,
1011 fake_iter.next().unwrap(),
1012 )?;
1013 writeln!(self.out, ";")?;
1014 }
1015 }
1016 }
1017 assert!(fake_iter.next().is_none());
1018 Ok(())
1019 }
1020
1021 fn write_global(
1025 &mut self,
1026 module: &Module,
1027 handle: Handle<crate::GlobalVariable>,
1028 ) -> BackendResult {
1029 let global = &module.global_variables[handle];
1030 let inner = &module.types[global.ty].inner;
1031
1032 let handle_ty = match *inner {
1033 TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
1034 _ => inner,
1035 };
1036
1037 let is_external_texture = matches!(
1041 *handle_ty,
1042 TypeInner::Image {
1043 class: crate::ImageClass::External,
1044 ..
1045 }
1046 );
1047 if is_external_texture {
1048 return self.write_global_external_texture(module, handle, global);
1049 }
1050
1051 if let Some(ref binding) = global.binding {
1052 if let Err(err) = self.options.resolve_resource_binding(binding) {
1053 log::debug!(
1054 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1055 handle,
1056 global.name,
1057 err,
1058 );
1059 return Ok(());
1060 }
1061 }
1062
1063 let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
1065
1066 if is_sampler {
1067 return self.write_global_sampler(module, handle, global);
1068 }
1069
1070 let register_ty = match global.space {
1072 crate::AddressSpace::Function => unreachable!("Function address space"),
1073 crate::AddressSpace::Private => {
1074 write!(self.out, "static ")?;
1075 self.write_type(module, global.ty)?;
1076 ""
1077 }
1078 crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
1079 write!(self.out, "groupshared ")?;
1080 self.write_type(module, global.ty)?;
1081 ""
1082 }
1083 crate::AddressSpace::Uniform => {
1084 write!(self.out, "cbuffer")?;
1087 "b"
1088 }
1089 crate::AddressSpace::Storage { access } => {
1090 if global
1091 .memory_decorations
1092 .contains(crate::MemoryDecorations::COHERENT)
1093 {
1094 write!(self.out, "globallycoherent ")?;
1095 }
1096 let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
1097 ("RW", "u")
1098 } else {
1099 ("", "t")
1100 };
1101 write!(self.out, "{prefix}ByteAddressBuffer")?;
1102 register
1103 }
1104 crate::AddressSpace::Handle => {
1105 let register = match *handle_ty {
1106 TypeInner::Image {
1108 class: crate::ImageClass::Storage { .. },
1109 ..
1110 } => "u",
1111 _ => "t",
1112 };
1113 self.write_type(module, global.ty)?;
1114 register
1115 }
1116 crate::AddressSpace::Immediate => {
1117 write!(self.out, "ConstantBuffer<")?;
1119 "b"
1120 }
1121 crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
1122 unimplemented!()
1123 }
1124 };
1125
1126 if global.space == crate::AddressSpace::Immediate {
1129 self.write_global_type(module, global.ty)?;
1130
1131 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1133 self.write_array_size(module, base, size)?;
1134 }
1135
1136 write!(self.out, ">")?;
1138 }
1139
1140 let name = &self.names[&NameKey::GlobalVariable(handle)];
1141 write!(self.out, " {name}")?;
1142
1143 if global.space == crate::AddressSpace::Immediate {
1146 match module.types[global.ty].inner {
1147 TypeInner::Struct { .. } => {}
1148 _ => {
1149 return Err(Error::Unimplemented(format!(
1150 "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1151 )));
1152 }
1153 }
1154
1155 let target = self
1156 .options
1157 .immediates_target
1158 .as_ref()
1159 .expect("No bind target was defined for the immediates block");
1160 write!(self.out, ": register(b{}", target.register)?;
1161 if target.space != 0 {
1162 write!(self.out, ", space{}", target.space)?;
1163 }
1164 write!(self.out, ")")?;
1165 }
1166
1167 if let Some(ref binding) = global.binding {
1168 let bt = self.options.resolve_resource_binding(binding).unwrap();
1170
1171 if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1173 if let Some(overridden_size) = bt.binding_array_size {
1174 write!(self.out, "[{overridden_size}]")?;
1175 } else {
1176 self.write_array_size(module, base, size)?;
1177 }
1178 }
1179
1180 write!(self.out, " : register({}{}", register_ty, bt.register)?;
1181 if bt.space != 0 {
1182 write!(self.out, ", space{}", bt.space)?;
1183 }
1184 write!(self.out, ")")?;
1185 } else {
1186 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1188 self.write_array_size(module, base, size)?;
1189 }
1190 if global.space == crate::AddressSpace::Private {
1191 write!(self.out, " = ")?;
1192 if let Some(init) = global.init {
1193 self.write_const_expression(module, init, &module.global_expressions)?;
1194 } else {
1195 self.write_default_init(module, global.ty)?;
1196 }
1197 }
1198 }
1199
1200 if global.space == crate::AddressSpace::Uniform {
1201 write!(self.out, " {{ ")?;
1202
1203 self.write_global_type(module, global.ty)?;
1204
1205 write!(
1206 self.out,
1207 " {}",
1208 &self.names[&NameKey::GlobalVariable(handle)]
1209 )?;
1210
1211 if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1213 self.write_array_size(module, base, size)?;
1214 }
1215
1216 writeln!(self.out, "; }}")?;
1217 } else {
1218 writeln!(self.out, ";")?;
1219 }
1220
1221 Ok(())
1222 }
1223
1224 fn write_global_sampler(
1225 &mut self,
1226 module: &Module,
1227 handle: Handle<crate::GlobalVariable>,
1228 global: &crate::GlobalVariable,
1229 ) -> BackendResult {
1230 let binding = *global.binding.as_ref().unwrap();
1231
1232 let key = super::SamplerIndexBufferKey {
1233 group: binding.group,
1234 };
1235 self.write_wrapped_sampler_buffer(key)?;
1236
1237 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240 match module.types[global.ty].inner {
1241 TypeInner::Sampler { comparison } => {
1242 write!(self.out, "static const ")?;
1249 self.write_type(module, global.ty)?;
1250
1251 let heap_var = if comparison {
1252 COMPARISON_SAMPLER_HEAP_VAR
1253 } else {
1254 SAMPLER_HEAP_VAR
1255 };
1256
1257 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258 let name = &self.names[&NameKey::GlobalVariable(handle)];
1259 writeln!(
1260 self.out,
1261 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262 register = bt.register
1263 )?;
1264 }
1265 TypeInner::BindingArray { .. } => {
1266 let name = &self.names[&NameKey::GlobalVariable(handle)];
1272 writeln!(
1273 self.out,
1274 "static const uint {name} = {register};",
1275 register = bt.register
1276 )?;
1277 }
1278 _ => unreachable!(),
1279 };
1280
1281 Ok(())
1282 }
1283
1284 fn write_global_external_texture(
1288 &mut self,
1289 module: &Module,
1290 handle: Handle<crate::GlobalVariable>,
1291 global: &crate::GlobalVariable,
1292 ) -> BackendResult {
1293 let res_binding = global
1294 .binding
1295 .as_ref()
1296 .expect("External texture global variables must have a resource binding");
1297 let ext_tex_bindings = match self
1298 .options
1299 .resolve_external_texture_resource_binding(res_binding)
1300 {
1301 Ok(bindings) => bindings,
1302 Err(err) => {
1303 log::debug!(
1304 "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305 handle,
1306 global.name,
1307 err,
1308 );
1309 return Ok(());
1310 }
1311 };
1312
1313 let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314 write!(
1315 self.out,
1316 "Texture2D<float4> {}: register(t{}",
1317 name, bt.register
1318 )?;
1319 if bt.space != 0 {
1320 write!(self.out, ", space{}", bt.space)?;
1321 }
1322 writeln!(self.out, ");")?;
1323 Ok(())
1324 };
1325 for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326 let plane_name = &self.names
1327 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328 write_plane(bt, plane_name)?;
1329 }
1330
1331 let params_name = &self.names
1332 [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333 let params_ty_name =
1334 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335 write!(
1336 self.out,
1337 "cbuffer {}: register(b{}",
1338 params_name, ext_tex_bindings.params.register
1339 )?;
1340 if ext_tex_bindings.params.space != 0 {
1341 write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342 }
1343 writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345 Ok(())
1346 }
1347
1348 fn write_global_constant(
1353 &mut self,
1354 module: &Module,
1355 handle: Handle<crate::Constant>,
1356 ) -> BackendResult {
1357 write!(self.out, "static const ")?;
1358 let constant = &module.constants[handle];
1359 self.write_type(module, constant.ty)?;
1360 let name = &self.names[&NameKey::Constant(handle)];
1361 write!(self.out, " {name}")?;
1362 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364 self.write_array_size(module, base, size)?;
1365 }
1366 write!(self.out, " = ")?;
1367 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368 writeln!(self.out, ";")?;
1369 Ok(())
1370 }
1371
1372 pub(super) fn write_array_size(
1373 &mut self,
1374 module: &Module,
1375 base: Handle<crate::Type>,
1376 size: crate::ArraySize,
1377 ) -> BackendResult {
1378 write!(self.out, "[")?;
1379
1380 match size.resolve(module.to_ctx())? {
1381 proc::IndexableLength::Known(size) => {
1382 write!(self.out, "{size}")?;
1383 }
1384 proc::IndexableLength::Dynamic => unreachable!(),
1385 }
1386
1387 write!(self.out, "]")?;
1388
1389 if let TypeInner::Array {
1390 base: next_base,
1391 size: next_size,
1392 ..
1393 } = module.types[base].inner
1394 {
1395 self.write_array_size(module, next_base, next_size)?;
1396 }
1397
1398 Ok(())
1399 }
1400
1401 fn write_struct(
1406 &mut self,
1407 module: &Module,
1408 handle: Handle<crate::Type>,
1409 members: &[crate::StructMember],
1410 span: u32,
1411 shader_stage: Option<(ShaderStage, Io)>,
1412 ) -> BackendResult {
1413 let struct_name = &self.names[&NameKey::Type(handle)];
1415 writeln!(self.out, "struct {struct_name} {{")?;
1416
1417 let mut last_offset = 0;
1418 for (index, member) in members.iter().enumerate() {
1419 if member.binding.is_none() && member.offset > last_offset {
1420 let padding = (member.offset - last_offset) / 4;
1424 for i in 0..padding {
1425 writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426 }
1427 }
1428 let ty_inner = &module.types[member.ty].inner;
1429 last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431 write!(self.out, "{}", back::INDENT)?;
1433
1434 match module.types[member.ty].inner {
1435 TypeInner::Array { base, size, .. } => {
1436 self.write_global_type(module, member.ty)?;
1439
1440 write!(
1442 self.out,
1443 " {}",
1444 &self.names[&NameKey::StructMember(handle, index as u32)]
1445 )?;
1446 self.write_array_size(module, base, size)?;
1448 }
1449 TypeInner::Matrix {
1452 rows,
1453 columns,
1454 scalar,
1455 } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456 let vec_ty = TypeInner::Vector { size: rows, scalar };
1457 let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459 for i in 0..columns as u8 {
1460 if i != 0 {
1461 write!(self.out, "; ")?;
1462 }
1463 self.write_value_type(module, &vec_ty)?;
1464 write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465 }
1466 }
1467 _ => {
1468 if let Some(ref binding) = member.binding {
1470 self.write_modifier(binding)?;
1471 }
1472
1473 if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477 write!(self.out, "row_major ")?;
1478 }
1479
1480 self.write_type(module, member.ty)?;
1482 write!(
1483 self.out,
1484 " {}",
1485 &self.names[&NameKey::StructMember(handle, index as u32)]
1486 )?;
1487 }
1488 }
1489
1490 self.write_semantic(&member.binding, shader_stage)?;
1491 writeln!(self.out, ";")?;
1492 }
1493
1494 if members.last().unwrap().binding.is_none() && span > last_offset {
1496 let padding = (span - last_offset) / 4;
1497 for i in 0..padding {
1498 writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499 }
1500 }
1501
1502 writeln!(self.out, "}};")?;
1503 Ok(())
1504 }
1505
1506 pub(super) fn write_global_type(
1511 &mut self,
1512 module: &Module,
1513 ty: Handle<crate::Type>,
1514 ) -> BackendResult {
1515 let matrix_data = get_inner_matrix_data(module, ty);
1516
1517 if let Some(MatrixType {
1520 columns,
1521 rows: crate::VectorSize::Bi,
1522 width,
1523 }) = matrix_data
1524 {
1525 write!(self.out, "__mat{}x2_f{}", columns as u8, width * 8)?;
1526 } else {
1527 if matrix_data.is_some() {
1531 write!(self.out, "row_major ")?;
1532 }
1533
1534 self.write_type(module, ty)?;
1535 }
1536
1537 Ok(())
1538 }
1539
1540 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545 let inner = &module.types[ty].inner;
1546 match *inner {
1547 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550 self.write_type(module, base)?
1551 }
1552 ref other => self.write_value_type(module, other)?,
1553 }
1554
1555 Ok(())
1556 }
1557
1558 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563 match *inner {
1564 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566 }
1567 TypeInner::Vector { size, scalar } => {
1568 write!(
1569 self.out,
1570 "{}{}",
1571 scalar.to_hlsl_str()?,
1572 common::vector_size_str(size)
1573 )?;
1574 }
1575 TypeInner::Matrix {
1576 columns,
1577 rows,
1578 scalar,
1579 } => {
1580 write!(
1585 self.out,
1586 "{}{}x{}",
1587 scalar.to_hlsl_str()?,
1588 common::vector_size_str(columns),
1589 common::vector_size_str(rows),
1590 )?;
1591 }
1592 TypeInner::Image {
1593 dim,
1594 arrayed,
1595 class,
1596 } => {
1597 self.write_image_type(dim, arrayed, class)?;
1598 }
1599 TypeInner::Sampler { comparison } => {
1600 let sampler = if comparison {
1601 "SamplerComparisonState"
1602 } else {
1603 "SamplerState"
1604 };
1605 write!(self.out, "{sampler}")?;
1606 }
1607 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611 self.write_array_size(module, base, size)?;
1612 }
1613 TypeInner::AccelerationStructure { .. } => {
1614 write!(self.out, "RaytracingAccelerationStructure")?;
1615 }
1616 TypeInner::RayQuery { .. } => {
1617 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619 }
1620 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621 }
1622
1623 Ok(())
1624 }
1625
/// Writes the complete HLSL definition of `func` under the name `name`.
///
/// The output consists of, in order: the caller-supplied `header`, an optional
/// `typedef` for an array return type, the return type, the function name, the
/// parameter list, the output semantic (entry points only), and the body
/// (optional workgroup zero-initialization, entry-point argument
/// initialization, local variable declarations, then the statements).
///
/// Task and mesh entry points are treated as "nested": `header` is not written
/// here, the function gets a fresh name derived from `_{name}`, and after the
/// body `write_nested_function_outer` is called with `header`, the original
/// `name`, and the argument names collected while writing the parameter list.
///
/// `func_ctx.ty` selects between regular-function and entry-point handling;
/// `info` is used for expression baking and to look up how the function uses
/// each global variable.
///
/// `named_expressions` is cleared before returning successfully.
1626 fn write_function(
1630 &mut self,
1631 module: &Module,
1632 name: &str,
1633 func: &crate::Function,
1634 func_ctx: &back::FunctionCtx<'_>,
1635 info: &valid::FunctionInfo,
1636 header: String,
1637 ) -> BackendResult {
1638 self.update_expressions_to_bake(module, func, info);
// `ep` is the entry point record when an entry point is being written,
// and `None` for a regular function.
1641 let ep = match func_ctx.ty {
1642 back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643 back::FunctionType::Function(_) => None,
1644 };
1645
// Task and mesh entry points are written as an inner function; their header
// is passed on to `write_nested_function_outer` at the end instead.
1646 let nested = matches!(
1647 ep,
1648 Some(crate::EntryPoint {
1649 stage: ShaderStage::Task | ShaderStage::Mesh,
1650 ..
1651 })
1652 );
1653 if !nested {
1654 write!(self.out, "{header}")?;
1655 }
1656
1657 if let Some(ref result) = func.result {
// An array result gets a `typedef <element> <fresh name>[dims];` line, so
// that the typedef name can be used as the return type below.
1658 let array_return_type = match module.types[result.ty].inner {
1660 TypeInner::Array { base, size, .. } => {
1661 let array_return_type = self.namer.call(&format!("ret_{name}"));
1662 write!(self.out, "typedef ")?;
1663 self.write_type(module, result.ty)?;
1664 write!(self.out, " {array_return_type}")?;
1665 self.write_array_size(module, base, size)?;
1666 writeln!(self.out, ";")?;
1667 Some(array_return_type)
1668 }
1669 _ => None,
1670 };
1671
// Only an invariant `Position` built-in result gets a modifier written
// ahead of the return type.
1672 if let Some(
1674 ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675 ) = result.binding
1676 {
1677 self.write_modifier(binding)?;
1678 }
1679
// Return type: regular functions use the array typedef when there is one;
// entry points use the generated output struct name when one is registered
// in `entry_point_io`. Both fall back to the plain result type.
1680 match func_ctx.ty {
1682 back::FunctionType::Function(_) => {
1683 if let Some(array_return_type) = array_return_type {
1684 write!(self.out, "{array_return_type}")?;
1685 } else {
1686 self.write_type(module, result.ty)?;
1687 }
1688 }
1689 back::FunctionType::EntryPoint(index) => {
1690 if let Some(ref ep_output) =
1691 self.entry_point_io.get(&(index as usize)).unwrap().output
1692 {
1693 write!(self.out, "{}", ep_output.ty_name)?;
1694 } else {
1695 self.write_type(module, result.ty)?;
1696 }
1697 }
1698 }
1699 } else {
1700 write!(self.out, "void")?;
1701 }
1702
// Nested entry points are emitted under a fresh name derived from `_{name}`;
// everything else keeps `name`.
1703 let nested_name = if nested {
1704 self.namer.call(&format!("_{name}"))
1705 } else {
1706 name.to_string()
1707 };
1708
1709 write!(self.out, " {nested_name}(")?;
1711
1712 let need_workgroup_variables_initialization =
1713 self.need_workgroup_variables_initialization(func_ctx, module);
1714
// `separator()` returns "" on its first call and ", " on every later call.
1715 let mut any_args_written = false;
1716 let mut separator = || {
1717 if any_args_written {
1718 ", "
1719 } else {
1720 any_args_written = true;
1721 ""
1722 }
1723 };
1724
1725 let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726 let mut local_invocation_index_name = None;
// Argument names collected for the `write_nested_function_outer` call below.
1727 let mut nested_wgsl_args: Vec<String> = Vec::new();
1730 let mut nested_task_payload_name: Option<String> = None;
1731 match func_ctx.ty {
1733 back::FunctionType::Function(handle) => {
1734 for (index, arg) in func.arguments.iter().enumerate() {
1735 write!(self.out, "{}", separator())?;
1736 self.write_function_argument(module, handle, arg, index)?;
1737 }
// At most one extra `in` parameter is added: the first task-payload global
// that `info` reports as read but not written by this function. The choice
// is recorded in `function_task_payload_var`.
1738 for (var_handle, var) in module.global_variables.iter() {
1740 let uses = info[var_handle];
1741 if uses.contains(valid::GlobalUse::READ)
1742 && !uses.contains(valid::GlobalUse::WRITE)
1743 && var.space == crate::AddressSpace::TaskPayload
1744 {
1745 self.function_task_payload_var.insert(handle, var_handle);
1746 write!(self.out, "{}in ", separator())?;
1747
1748 self.write_type(module, var.ty)?;
1749 let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750 write!(self.out, " {name}")?;
1751 break;
1752 }
1753 }
1754 }
1755 back::FunctionType::EntryPoint(ep_index) => {
1756 let ep = &module.entry_points[ep_index as usize];
// With a generated input struct, the entry point takes that struct as its
// single user argument; otherwise each argument is written individually
// with its input semantic.
1757 if let Some(ref ep_input) =
1758 self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759 {
1760 write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
// Called only for its side effect: records that an argument was written.
1761 separator();
1762 nested_wgsl_args.push(ep_input.arg_name.clone());
1763 } else {
1764 let stage = ep.stage;
1765 for (index, arg) in func.arguments.iter().enumerate() {
1766 write!(self.out, "{}", separator())?;
1767 self.write_type(module, arg.ty)?;
1768
1769 let argument_name =
1770 &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
// Remember an argument that already carries the local invocation index so
// that no extra parameter has to be synthesized below.
1772 if arg.binding
1773 == Some(crate::Binding::BuiltIn(
1774 crate::BuiltIn::LocalInvocationIndex,
1775 ))
1776 {
1777 local_invocation_index_name = Some(argument_name.clone());
1778 }
1779
1780 nested_wgsl_args.push(argument_name.clone());
1781 write!(self.out, " {argument_name}")?;
1782 if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783 self.write_array_size(module, base, size)?;
1784 }
1785
1786 self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787 }
1788 }
// Mesh entry points with a task payload receive it as an `in` parameter.
1789 if ep.stage == ShaderStage::Mesh {
1790 if let Some(var_handle) = ep.task_payload {
1791 let var = &module.global_variables[var_handle];
1792 write!(self.out, "{}in ", separator())?;
1793 self.write_type(module, var.ty)?;
1794 let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795 write!(self.out, " {arg_name}")?;
1796 nested_task_payload_name = Some(arg_name.clone());
1797 if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798 self.write_array_size(module, base, size)?;
1799 }
1800 }
1801 }
// If a local invocation index is required (workgroup zero-initialization or
// a nested entry point) and no argument provides it, append a synthesized
// `uint <name> : SV_GroupIndex` parameter.
1802 if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803 let name = self.namer.call("local_invocation_index");
1804 write!(self.out, "{}uint {name}", separator())?;
1805 write!(self.out, " : SV_GroupIndex")?;
1806 local_invocation_index_name = Some(name);
1807 }
1808 }
1809 }
1810 write!(self.out, ")")?;
1812
// Entry points get the output semantic of their result after the parameter list.
1813 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815 let stage = module.entry_points[index as usize].stage;
1816 if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817 self.write_semantic(binding, Some((stage, Io::Output)))?;
1818 }
1819 }
1820
1821 writeln!(self.out)?;
1823 writeln!(self.out, "{{")?;
1824
// Non-nested entry points zero-initialize workgroup variables inside
// `if (<local invocation index> == 0) { ... }` and then emit a workgroup
// control barrier.
1825 if need_workgroup_variables_initialization && !nested {
1826 let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827 unreachable!();
1828 };
1829 writeln!(
1830 self.out,
1831 "{}if ({} == 0) {{",
1832 back::INDENT,
1833 local_invocation_index_name.as_ref().unwrap(),
1836 )?;
1837 self.write_workgroup_variables_initialization(
1838 func_ctx,
1839 module,
1840 module.entry_points[index as usize].stage,
1841 )?;
1842
1843 writeln!(self.out, "{}}}", back::INDENT)?;
1844 self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845 }
1846
1847 if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848 self.write_ep_arguments_initialization(module, func, index)?;
1849 }
1850
// Local variable declarations. Every local except a ray query is given an
// initializer: its explicit `init` expression or a default value.
1851 for (handle, local) in func.local_variables.iter() {
1853 write!(self.out, "{}", back::INDENT)?;
1855
1856 self.write_type(module, local.ty)?;
1859 write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860 if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862 self.write_array_size(module, base, size)?;
1863 }
1864
1865 let is_ray_query = match module.types[local.ty].inner {
1866 TypeInner::RayQuery { .. } => true,
1868 _ => {
1869 write!(self.out, " = ")?;
1870 if let Some(init) = local.init {
1872 self.write_expr(module, init, func_ctx)?;
1873 } else {
1874 self.write_default_init(module, local.ty)?;
1876 }
1877 false
1878 }
1879 };
1880 writeln!(self.out, ";")?;
// A ray query local is followed by a companion `uint` variable, named with
// `RAY_QUERY_TRACKER_VARIABLE_PREFIX` plus the local's name, set to 0.
1882 if is_ray_query {
1884 write!(self.out, "{}", back::INDENT)?;
1885 self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886 writeln!(
1887 self.out,
1888 " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889 self.names[&func_ctx.name_key(handle)]
1890 )?;
1891 }
1892 }
1893
// Blank line between the declarations and the statements.
1894 if !func.local_variables.is_empty() {
1895 writeln!(self.out)?;
1896 }
1897
1898 for sta in func.body.iter() {
1900 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902 }
1903
1904 writeln!(self.out, "}}")?;
1905
// For nested entry points, hand the header, both names and the collected
// argument names to `write_nested_function_outer`. The `unwrap` on
// `local_invocation_index_name` relies on the name having been set above,
// which the `nested` case requires.
1906 if nested {
1907 self.write_nested_function_outer(
1908 module,
1909 func_ctx,
1910 &header,
1911 name,
1912 need_workgroup_variables_initialization,
1913 &nested_name,
1914 ep.unwrap(),
1915 NestedEntryPointArgs {
1916 user_args: nested_wgsl_args,
1917 task_payload: nested_task_payload_name,
1918 local_invocation_index: local_invocation_index_name.unwrap(),
1920 },
1921 )?;
1922 }
1923
1924 self.named_expressions.clear();
1925
1926 Ok(())
1927 }
1928
1929 fn write_function_argument(
1930 &mut self,
1931 module: &Module,
1932 handle: Handle<crate::Function>,
1933 arg: &crate::FunctionArgument,
1934 index: usize,
1935 ) -> BackendResult {
1936 if let TypeInner::Image {
1939 class: crate::ImageClass::External,
1940 ..
1941 } = module.types[arg.ty].inner
1942 {
1943 return self.write_function_external_texture_argument(module, handle, index);
1944 }
1945
1946 let arg_ty = match module.types[arg.ty].inner {
1948 TypeInner::Pointer { base, .. } => {
1950 write!(self.out, "inout ")?;
1952 base
1953 }
1954 _ => arg.ty,
1955 };
1956 self.write_type(module, arg_ty)?;
1957
1958 let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960 write!(self.out, " {argument_name}")?;
1962 if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963 self.write_array_size(module, base, size)?;
1964 }
1965
1966 Ok(())
1967 }
1968
1969 fn write_function_external_texture_argument(
1970 &mut self,
1971 module: &Module,
1972 handle: Handle<crate::Function>,
1973 index: usize,
1974 ) -> BackendResult {
1975 let plane_names = [0, 1, 2].map(|i| {
1976 &self.names[&NameKey::ExternalTextureFunctionArgument(
1977 handle,
1978 index as u32,
1979 ExternalTextureNameKey::Plane(i),
1980 )]
1981 });
1982 let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983 handle,
1984 index as u32,
1985 ExternalTextureNameKey::Params,
1986 )];
1987 let params_ty_name =
1988 &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989 write!(
1990 self.out,
1991 "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992 plane_names[0], plane_names[1], plane_names[2],
1993 )?;
1994 Ok(())
1995 }
1996
1997 fn need_workgroup_variables_initialization(
1998 &mut self,
1999 func_ctx: &back::FunctionCtx,
2000 module: &Module,
2001 ) -> bool {
2002 self.options.zero_initialize_workgroup_memory
2003 && func_ctx.ty.is_compute_like_entry_point(module)
2004 && module.global_variables.iter().any(|(handle, var)| {
2005 !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006 })
2007 }
2008
2009 pub(super) fn write_workgroup_variables_initialization(
2010 &mut self,
2011 func_ctx: &back::FunctionCtx,
2012 module: &Module,
2013 stage: ShaderStage,
2014 ) -> BackendResult {
2015 let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016 let task_needs_zero =
2018 (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019 !func_ctx.info[handle].is_empty()
2020 && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021 });
2022
2023 for (handle, var) in vars {
2024 let name = &self.names[&NameKey::GlobalVariable(handle)];
2025 write!(self.out, "{}{} = ", back::Level(2), name)?;
2026 self.write_default_init(module, var.ty)?;
2027 writeln!(self.out, ";")?;
2028 }
2029 Ok(())
2030 }
2031
2032 fn write_switch(
2034 &mut self,
2035 module: &Module,
2036 func_ctx: &back::FunctionCtx<'_>,
2037 level: back::Level,
2038 selector: Handle<crate::Expression>,
2039 cases: &[crate::SwitchCase],
2040 ) -> BackendResult {
2041 let indent_level_1 = level.next();
2043 let indent_level_2 = indent_level_1.next();
2044
2045 if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
2047 writeln!(self.out, "{level}bool {variable} = false;",)?;
2048 };
2049
2050 let one_body = cases
2055 .iter()
2056 .rev()
2057 .skip(1)
2058 .all(|case| case.fall_through && case.body.is_empty());
2059 if one_body {
2060 writeln!(self.out, "{level}do {{")?;
2062 if let Some(case) = cases.last() {
2066 for sta in case.body.iter() {
2067 self.write_stmt(module, sta, func_ctx, indent_level_1)?;
2068 }
2069 }
2070 writeln!(self.out, "{level}}} while(false);")?;
2072 } else {
2073 write!(self.out, "{level}")?;
2075 write!(self.out, "switch(")?;
2076 self.write_expr(module, selector, func_ctx)?;
2077 writeln!(self.out, ") {{")?;
2078
2079 for (i, case) in cases.iter().enumerate() {
2080 match case.value {
2081 crate::SwitchValue::I32(value) => {
2082 write!(self.out, "{indent_level_1}case {value}:")?
2083 }
2084 crate::SwitchValue::U32(value) => {
2085 write!(self.out, "{indent_level_1}case {value}u:")?
2086 }
2087 crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
2088 }
2089
2090 let write_block_braces = !(case.fall_through && case.body.is_empty());
2097 if write_block_braces {
2098 writeln!(self.out, " {{")?;
2099 } else {
2100 writeln!(self.out)?;
2101 }
2102
2103 if case.fall_through && !case.body.is_empty() {
2121 let curr_len = i + 1;
2122 let end_case_idx = curr_len
2123 + cases
2124 .iter()
2125 .skip(curr_len)
2126 .position(|case| !case.fall_through)
2127 .unwrap();
2128 let indent_level_3 = indent_level_2.next();
2129 for case in &cases[i..=end_case_idx] {
2130 writeln!(self.out, "{indent_level_2}{{")?;
2131 let prev_len = self.named_expressions.len();
2132 for sta in case.body.iter() {
2133 self.write_stmt(module, sta, func_ctx, indent_level_3)?;
2134 }
2135 self.named_expressions.truncate(prev_len);
2137 writeln!(self.out, "{indent_level_2}}}")?;
2138 }
2139
2140 let last_case = &cases[end_case_idx];
2141 if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
2142 writeln!(self.out, "{indent_level_2}break;")?;
2143 }
2144 } else {
2145 for sta in case.body.iter() {
2146 self.write_stmt(module, sta, func_ctx, indent_level_2)?;
2147 }
2148 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
2149 writeln!(self.out, "{indent_level_2}break;")?;
2150 }
2151 }
2152
2153 if write_block_braces {
2154 writeln!(self.out, "{indent_level_1}}}")?;
2155 }
2156 }
2157
2158 writeln!(self.out, "{level}}}")?;
2159 }
2160
2161 use back::continue_forward::ExitControlFlow;
2163 let op = match self.continue_ctx.exit_switch() {
2164 ExitControlFlow::None => None,
2165 ExitControlFlow::Continue { variable } => Some(("continue", variable)),
2166 ExitControlFlow::Break { variable } => Some(("break", variable)),
2167 };
2168 if let Some((control_flow, variable)) = op {
2169 writeln!(self.out, "{level}if ({variable}) {{")?;
2170 writeln!(self.out, "{indent_level_1}{control_flow};")?;
2171 writeln!(self.out, "{level}}}")?;
2172 }
2173
2174 Ok(())
2175 }
2176
2177 fn write_index(
2178 &mut self,
2179 module: &Module,
2180 index: Index,
2181 func_ctx: &back::FunctionCtx<'_>,
2182 ) -> BackendResult {
2183 match index {
2184 Index::Static(index) => {
2185 write!(self.out, "{index}")?;
2186 }
2187 Index::Expression(index) => {
2188 self.write_expr(module, index, func_ctx)?;
2189 }
2190 }
2191 Ok(())
2192 }
2193
2194 fn write_stmt(
2199 &mut self,
2200 module: &Module,
2201 stmt: &crate::Statement,
2202 func_ctx: &back::FunctionCtx<'_>,
2203 level: back::Level,
2204 ) -> BackendResult {
2205 use crate::Statement;
2206
2207 match *stmt {
2208 Statement::Emit(ref range) => {
2209 for handle in range.clone() {
2210 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211 let expr_name = if ptr_class.is_some() {
2212 None
2216 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217 Some(self.namer.call(name))
2222 } else if self.need_bake_expressions.contains(&handle) {
2223 Some(Baked(handle).to_string())
2224 } else {
2225 None
2226 };
2227
2228 if let Some(name) = expr_name {
2229 write!(self.out, "{level}")?;
2230 self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231 }
2232 }
2233 }
2234 Statement::Block(ref block) => {
2236 write!(self.out, "{level}")?;
2237 writeln!(self.out, "{{")?;
2238 for sta in block.iter() {
2239 self.write_stmt(module, sta, func_ctx, level.next())?
2241 }
2242 writeln!(self.out, "{level}}}")?
2243 }
2244 Statement::If {
2246 condition,
2247 ref accept,
2248 ref reject,
2249 } => {
2250 write!(self.out, "{level}")?;
2251 write!(self.out, "if (")?;
2252 self.write_expr(module, condition, func_ctx)?;
2253 writeln!(self.out, ") {{")?;
2254
2255 let l2 = level.next();
2256 for sta in accept {
2257 self.write_stmt(module, sta, func_ctx, l2)?;
2259 }
2260
2261 if !reject.is_empty() {
2264 writeln!(self.out, "{level}}} else {{")?;
2265
2266 for sta in reject {
2267 self.write_stmt(module, sta, func_ctx, l2)?;
2269 }
2270 }
2271
2272 writeln!(self.out, "{level}}}")?
2273 }
2274 Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276 Statement::Return { value: None } => {
2277 writeln!(self.out, "{level}return;")?;
2278 }
2279 Statement::Return { value: Some(expr) } => {
2280 let base_ty_res = &func_ctx.info[expr].ty;
2281 let mut resolved = base_ty_res.inner_with(&module.types);
2282 if let TypeInner::Pointer { base, space: _ } = *resolved {
2283 resolved = &module.types[base].inner;
2284 }
2285
2286 if let TypeInner::Struct { .. } = *resolved {
2287 let ty = base_ty_res.handle().unwrap();
2289 let struct_name = &self.names[&NameKey::Type(ty)];
2290 let variable_name = self.namer.call(&struct_name.to_lowercase());
2291 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292 self.write_expr(module, expr, func_ctx)?;
2293 writeln!(self.out, ";")?;
2294
2295 let ep_output = match func_ctx.ty {
2297 back::FunctionType::Function(_) => None,
2298 back::FunctionType::EntryPoint(index) => self
2299 .entry_point_io
2300 .get(&(index as usize))
2301 .unwrap()
2302 .output
2303 .as_ref(),
2304 };
2305 let final_name = match ep_output {
2306 Some(ep_output) => {
2307 let final_name = self.namer.call(&variable_name);
2308 write!(
2309 self.out,
2310 "{}const {} {} = {{ ",
2311 level, ep_output.ty_name, final_name,
2312 )?;
2313 for (index, m) in ep_output.members.iter().enumerate() {
2314 if index != 0 {
2315 write!(self.out, ", ")?;
2316 }
2317 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318 write!(self.out, "{variable_name}.{member_name}")?;
2319 }
2320 writeln!(self.out, " }};")?;
2321 final_name
2322 }
2323 None => variable_name,
2324 };
2325 writeln!(self.out, "{level}return {final_name};")?;
2326 } else {
2327 write!(self.out, "{level}return ")?;
2328 self.write_expr(module, expr, func_ctx)?;
2329 writeln!(self.out, ";")?
2330 }
2331 }
2332 Statement::Store { pointer, value } => {
2333 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334 if ty_inner.is_atomic_pointer(&module.types) {
2335 let pointer_space = ty_inner.pointer_space().unwrap();
2336 let dummy = self.namer.call("dummy");
2337 write!(self.out, "{level}{{ ")?;
2338 if let TypeInner::Pointer { base, .. } = *ty_inner {
2339 self.write_value_type(module, &module.types[base].inner)?;
2340 }
2341 write!(self.out, " {dummy} = 0; ")?;
2342 match pointer_space {
2343 crate::AddressSpace::WorkGroup => {
2344 write!(self.out, "InterlockedExchange(")?;
2345 self.write_expr(module, pointer, func_ctx)?;
2346 }
2347 crate::AddressSpace::Storage { .. } => {
2348 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350 write!(self.out, "{var_name}.InterlockedExchange(")?;
2351 let chain = mem::take(&mut self.temp_access_chain);
2352 self.write_storage_address(module, &chain, func_ctx)?;
2353 self.temp_access_chain = chain;
2354 }
2355 _ => unreachable!(),
2356 }
2357 write!(self.out, ", ")?;
2358 self.write_expr(module, value, func_ctx)?;
2359 writeln!(self.out, ", {dummy}); }}")?;
2360 } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362 self.write_storage_store(
2363 module,
2364 var_handle,
2365 StoreValue::Expression(value),
2366 func_ctx,
2367 level,
2368 None,
2369 )?;
2370 } else {
2371 enum MatrixAccess {
2377 Direct {
2378 base: Handle<crate::Expression>,
2379 index: u32,
2380 },
2381 Struct {
2382 columns: crate::VectorSize,
2383 width: u8,
2384 base: Handle<crate::Expression>,
2385 },
2386 }
2387
2388 let get_members = |expr: Handle<crate::Expression>| {
2389 let resolved = func_ctx.resolve_type(expr, &module.types);
2390 match *resolved {
2391 TypeInner::Pointer { base, .. } => match module.types[base].inner {
2392 TypeInner::Struct { ref members, .. } => Some(members),
2393 _ => None,
2394 },
2395 _ => None,
2396 }
2397 };
2398
2399 write!(self.out, "{level}")?;
2400
2401 let matrix_access_on_lhs =
2402 find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2403 |(matrix_expr, vector, scalar)| match (
2404 func_ctx.resolve_type(matrix_expr, &module.types),
2405 &func_ctx.expressions[matrix_expr],
2406 ) {
2407 (
2408 &TypeInner::Pointer { base: ty, .. },
2409 &crate::Expression::AccessIndex { base, index },
2410 ) if matches!(
2411 module.types[ty].inner,
2412 TypeInner::Matrix {
2413 rows: crate::VectorSize::Bi,
2414 ..
2415 }
2416 ) && get_members(base)
2417 .map(|members| members[index as usize].binding.is_none())
2418 == Some(true) =>
2419 {
2420 Some((MatrixAccess::Direct { base, index }, vector, scalar))
2421 }
2422 _ => {
2423 if let Some(MatrixType {
2424 columns,
2425 rows: crate::VectorSize::Bi,
2426 width,
2427 }) = get_inner_matrix_of_struct_array_member(
2428 module,
2429 matrix_expr,
2430 func_ctx,
2431 true,
2432 ) {
2433 Some((
2434 MatrixAccess::Struct {
2435 columns,
2436 width,
2437 base: matrix_expr,
2438 },
2439 vector,
2440 scalar,
2441 ))
2442 } else {
2443 None
2444 }
2445 }
2446 },
2447 );
2448
2449 match matrix_access_on_lhs {
2450 Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2451 let base_ty_res = &func_ctx.info[base].ty;
2452 let resolved = base_ty_res.inner_with(&module.types);
2453 let ty = match *resolved {
2454 TypeInner::Pointer { base, .. } => base,
2455 _ => base_ty_res.handle().unwrap(),
2456 };
2457
2458 if let Some(Index::Static(vec_index)) = vector {
2459 self.write_expr(module, base, func_ctx)?;
2460 write!(
2461 self.out,
2462 ".{}_{}",
2463 &self.names[&NameKey::StructMember(ty, index)],
2464 vec_index
2465 )?;
2466
2467 if let Some(scalar_index) = scalar {
2468 write!(self.out, "[")?;
2469 self.write_index(module, scalar_index, func_ctx)?;
2470 write!(self.out, "]")?;
2471 }
2472
2473 write!(self.out, " = ")?;
2474 self.write_expr(module, value, func_ctx)?;
2475 writeln!(self.out, ";")?;
2476 } else {
2477 let access = WrappedStructMatrixAccess { ty, index };
2478 match (&vector, &scalar) {
2479 (&Some(_), &Some(_)) => {
2480 self.write_wrapped_struct_matrix_set_scalar_function_name(
2481 access,
2482 )?;
2483 }
2484 (&Some(_), &None) => {
2485 self.write_wrapped_struct_matrix_set_vec_function_name(
2486 access,
2487 )?;
2488 }
2489 (&None, _) => {
2490 self.write_wrapped_struct_matrix_set_function_name(access)?;
2491 }
2492 }
2493
2494 write!(self.out, "(")?;
2495 self.write_expr(module, base, func_ctx)?;
2496 write!(self.out, ", ")?;
2497 self.write_expr(module, value, func_ctx)?;
2498
2499 if let Some(Index::Expression(vec_index)) = vector {
2500 write!(self.out, ", ")?;
2501 self.write_expr(module, vec_index, func_ctx)?;
2502
2503 if let Some(scalar_index) = scalar {
2504 write!(self.out, ", ")?;
2505 self.write_index(module, scalar_index, func_ctx)?;
2506 }
2507 }
2508 writeln!(self.out, ");")?;
2509 }
2510 }
2511 Some((
2512 MatrixAccess::Struct {
2513 columns,
2514 width,
2515 base,
2516 },
2517 Some(Index::Expression(vec_index)),
2518 scalar,
2519 )) => {
2520 if scalar.is_some() {
2524 write!(
2525 self.out,
2526 "__set_el_of_mat{}x2_f{}",
2527 columns as u8,
2528 width * 8
2529 )?;
2530 } else {
2531 write!(
2532 self.out,
2533 "__set_col_of_mat{}x2_f{}",
2534 columns as u8,
2535 width * 8
2536 )?;
2537 }
2538 write!(self.out, "(")?;
2539 self.write_expr(module, base, func_ctx)?;
2540 write!(self.out, ", ")?;
2541 self.write_expr(module, vec_index, func_ctx)?;
2542
2543 if let Some(scalar_index) = scalar {
2544 write!(self.out, ", ")?;
2545 self.write_index(module, scalar_index, func_ctx)?;
2546 }
2547
2548 write!(self.out, ", ")?;
2549 self.write_expr(module, value, func_ctx)?;
2550
2551 writeln!(self.out, ");")?;
2552 }
2553 Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2554 | Some((MatrixAccess::Struct { .. }, None, _))
2555 | None => {
2556 self.write_expr(module, pointer, func_ctx)?;
2557 write!(self.out, " = ")?;
2558
2559 if let Some(MatrixType {
2564 columns,
2565 rows: crate::VectorSize::Bi,
2566 width,
2567 }) = get_inner_matrix_of_struct_array_member(
2568 module, pointer, func_ctx, false,
2569 ) {
2570 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2571 if let TypeInner::Pointer { base, .. } = *resolved {
2572 resolved = &module.types[base].inner;
2573 }
2574
2575 write!(self.out, "(__mat{}x2_f{}", columns as u8, width * 8)?;
2576 if let TypeInner::Array { base, size, .. } = *resolved {
2577 self.write_array_size(module, base, size)?;
2578 }
2579 write!(self.out, ")")?;
2580 }
2581
2582 self.write_expr(module, value, func_ctx)?;
2583 writeln!(self.out, ";")?
2584 }
2585 }
2586 }
2587 }
2588 Statement::Loop {
2589 ref body,
2590 ref continuing,
2591 break_if,
2592 } => {
2593 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2594 let gate_name = (!continuing.is_empty() || break_if.is_some())
2595 .then(|| self.namer.call("loop_init"));
2596
2597 if let Some((ref decl, _)) = force_loop_bound_statements {
2598 writeln!(self.out, "{decl}")?;
2599 }
2600 if let Some(ref gate_name) = gate_name {
2601 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2602 }
2603
2604 self.continue_ctx.enter_loop();
2605 writeln!(self.out, "{level}while(true) {{")?;
2606 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2607 writeln!(self.out, "{break_and_inc}")?;
2608 }
2609 let l2 = level.next();
2610 if let Some(gate_name) = gate_name {
2611 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2612 let l3 = l2.next();
2613 for sta in continuing.iter() {
2614 self.write_stmt(module, sta, func_ctx, l3)?;
2615 }
2616 if let Some(condition) = break_if {
2617 write!(self.out, "{l3}if (")?;
2618 self.write_expr(module, condition, func_ctx)?;
2619 writeln!(self.out, ") {{")?;
2620 writeln!(self.out, "{}break;", l3.next())?;
2621 writeln!(self.out, "{l3}}}")?;
2622 }
2623 writeln!(self.out, "{l2}}}")?;
2624 writeln!(self.out, "{l2}{gate_name} = false;")?;
2625 }
2626
2627 for sta in body.iter() {
2628 self.write_stmt(module, sta, func_ctx, l2)?;
2629 }
2630
2631 writeln!(self.out, "{level}}}")?;
2632 self.continue_ctx.exit_loop();
2633 }
2634 Statement::Break => writeln!(self.out, "{level}break;")?,
2635 Statement::Continue => {
2636 if let Some(variable) = self.continue_ctx.continue_encountered() {
2637 writeln!(self.out, "{level}{variable} = true;")?;
2638 writeln!(self.out, "{level}break;")?
2639 } else {
2640 writeln!(self.out, "{level}continue;")?
2641 }
2642 }
2643 Statement::ControlBarrier(barrier) => {
2644 self.write_control_barrier(barrier, level)?;
2645 }
2646 Statement::MemoryBarrier(barrier) => {
2647 self.write_memory_barrier(barrier, level)?;
2648 }
2649 Statement::ImageStore {
2650 image,
2651 coordinate,
2652 array_index,
2653 value,
2654 } => {
2655 write!(self.out, "{level}")?;
2656 self.write_expr(module, image, func_ctx)?;
2657
2658 write!(self.out, "[")?;
2659 if let Some(index) = array_index {
2660 write!(self.out, "int3(")?;
2662 self.write_expr(module, coordinate, func_ctx)?;
2663 write!(self.out, ", ")?;
2664 self.write_expr(module, index, func_ctx)?;
2665 write!(self.out, ")")?;
2666 } else {
2667 self.write_expr(module, coordinate, func_ctx)?;
2668 }
2669 write!(self.out, "]")?;
2670
2671 write!(self.out, " = ")?;
2672 self.write_expr(module, value, func_ctx)?;
2673 writeln!(self.out, ";")?;
2674 }
2675 Statement::Call {
2676 function,
2677 ref arguments,
2678 result,
2679 } => {
2680 write!(self.out, "{level}")?;
2681
2682 if let Some(expr) = result {
2683 write!(self.out, "const ")?;
2684 let name = Baked(expr).to_string();
2685 let expr_ty = &func_ctx.info[expr].ty;
2686 let ty_inner = match *expr_ty {
2687 proc::TypeResolution::Handle(handle) => {
2688 self.write_type(module, handle)?;
2689 &module.types[handle].inner
2690 }
2691 proc::TypeResolution::Value(ref value) => {
2692 self.write_value_type(module, value)?;
2693 value
2694 }
2695 };
2696 write!(self.out, " {name}")?;
2697 if let TypeInner::Array { base, size, .. } = *ty_inner {
2698 self.write_array_size(module, base, size)?;
2699 }
2700 write!(self.out, " = ")?;
2701 self.named_expressions.insert(expr, name);
2702 }
2703 let func_name = &self.names[&NameKey::Function(function)];
2704 write!(self.out, "{func_name}(")?;
2705 let mut any_args_written = false;
2706 let mut separator = || {
2707 if any_args_written {
2708 ", "
2709 } else {
2710 any_args_written = true;
2711 ""
2712 }
2713 };
2714 for argument in arguments {
2715 write!(self.out, "{}", separator())?;
2716 self.write_expr(module, *argument, func_ctx)?;
2717 }
2718 if let Some(&var) = self.function_task_payload_var.get(&function) {
2719 let name = &self.names[&NameKey::GlobalVariable(var)];
2720 write!(self.out, "{}{name}", separator())?;
2722 }
2723 writeln!(self.out, ");")?;
2724 }
2725 Statement::Atomic {
2726 pointer,
2727 ref fun,
2728 value,
2729 result,
2730 } => {
2731 write!(self.out, "{level}")?;
2732 let res_var_info = if let Some(res_handle) = result {
2733 let name = Baked(res_handle).to_string();
2734 match func_ctx.info[res_handle].ty {
2735 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2736 proc::TypeResolution::Value(ref value) => {
2737 self.write_value_type(module, value)?
2738 }
2739 };
2740 write!(self.out, " {name}; ")?;
2741 self.named_expressions.insert(res_handle, name.clone());
2742 Some((res_handle, name))
2743 } else {
2744 None
2745 };
2746 let pointer_space = func_ctx
2747 .resolve_type(pointer, &module.types)
2748 .pointer_space()
2749 .unwrap();
2750 let fun_str = fun.to_hlsl_suffix();
2751 let compare_expr = match *fun {
2752 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2753 _ => None,
2754 };
2755 match pointer_space {
2756 crate::AddressSpace::WorkGroup => {
2757 write!(self.out, "Interlocked{fun_str}(")?;
2758 self.write_expr(module, pointer, func_ctx)?;
2759 self.emit_hlsl_atomic_tail(
2760 module,
2761 func_ctx,
2762 fun,
2763 compare_expr,
2764 value,
2765 &res_var_info,
2766 )?;
2767 }
2768 crate::AddressSpace::Storage { .. } => {
2769 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2770 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2771 let width = match func_ctx.resolve_type(value, &module.types) {
2772 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2773 _ => "",
2774 };
2775 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2776 let chain = mem::take(&mut self.temp_access_chain);
2777 self.write_storage_address(module, &chain, func_ctx)?;
2778 self.temp_access_chain = chain;
2779 self.emit_hlsl_atomic_tail(
2780 module,
2781 func_ctx,
2782 fun,
2783 compare_expr,
2784 value,
2785 &res_var_info,
2786 )?;
2787 }
2788 ref other => {
2789 return Err(Error::Custom(format!(
2790 "invalid address space {other:?} for atomic statement"
2791 )))
2792 }
2793 }
2794 if let Some(cmp) = compare_expr {
2795 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2796 write!(
2797 self.out,
2798 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2799 )?;
2800 self.write_expr(module, cmp, func_ctx)?;
2801 writeln!(self.out, ");")?;
2802 }
2803 }
2804 }
2805 Statement::ImageAtomic {
2806 image,
2807 coordinate,
2808 array_index,
2809 fun,
2810 value,
2811 } => {
2812 write!(self.out, "{level}")?;
2813
2814 let fun_str = fun.to_hlsl_suffix();
2815 write!(self.out, "Interlocked{fun_str}(")?;
2816 self.write_expr(module, image, func_ctx)?;
2817 write!(self.out, "[")?;
2818 self.write_texture_coordinates(
2819 "int",
2820 coordinate,
2821 array_index,
2822 None,
2823 module,
2824 func_ctx,
2825 )?;
2826 write!(self.out, "],")?;
2827
2828 self.write_expr(module, value, func_ctx)?;
2829 writeln!(self.out, ");")?;
2830 }
2831 Statement::WorkGroupUniformLoad { pointer, result } => {
2832 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2833 write!(self.out, "{level}")?;
2834 let name = Baked(result).to_string();
2835 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2836
2837 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2838 }
2839 Statement::Switch {
2840 selector,
2841 ref cases,
2842 } => {
2843 self.write_switch(module, func_ctx, level, selector, cases)?;
2844 }
2845 Statement::RayQuery { query, ref fun } => {
2846 let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2858 else {
2859 unreachable!()
2860 };
2861
2862 let tracker_expr_name = format!(
2863 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2864 self.names[&func_ctx.name_key(query_var)]
2865 );
2866
2867 match *fun {
2868 RayQueryFunction::Initialize {
2869 acceleration_structure,
2870 descriptor,
2871 } => {
2872 self.write_initialize_function(
2873 module,
2874 level,
2875 query,
2876 acceleration_structure,
2877 descriptor,
2878 &tracker_expr_name,
2879 func_ctx,
2880 )?;
2881 }
2882 RayQueryFunction::Proceed { result } => {
2883 self.write_proceed(
2884 module,
2885 level,
2886 query,
2887 result,
2888 &tracker_expr_name,
2889 func_ctx,
2890 )?;
2891 }
2892 RayQueryFunction::GenerateIntersection { hit_t } => {
2893 self.write_generate_intersection(
2894 module,
2895 level,
2896 query,
2897 hit_t,
2898 &tracker_expr_name,
2899 func_ctx,
2900 )?;
2901 }
2902 RayQueryFunction::ConfirmIntersection => {
2903 self.write_confirm_intersection(
2904 module,
2905 level,
2906 query,
2907 &tracker_expr_name,
2908 func_ctx,
2909 )?;
2910 }
2911 RayQueryFunction::Terminate => {
2912 self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2913 }
2914 }
2915 }
2916 Statement::SubgroupBallot { result, predicate } => {
2917 write!(self.out, "{level}")?;
2918 let name = Baked(result).to_string();
2919 write!(self.out, "const uint4 {name} = ")?;
2920 self.named_expressions.insert(result, name);
2921
2922 write!(self.out, "WaveActiveBallot(")?;
2923 match predicate {
2924 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2925 None => write!(self.out, "true")?,
2926 }
2927 writeln!(self.out, ");")?;
2928 }
2929 Statement::SubgroupCollectiveOperation {
2930 op,
2931 collective_op,
2932 argument,
2933 result,
2934 } => {
2935 write!(self.out, "{level}")?;
2936 write!(self.out, "const ")?;
2937 let name = Baked(result).to_string();
2938 match func_ctx.info[result].ty {
2939 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2940 proc::TypeResolution::Value(ref value) => {
2941 self.write_value_type(module, value)?
2942 }
2943 };
2944 write!(self.out, " {name} = ")?;
2945 self.named_expressions.insert(result, name);
2946
2947 match (collective_op, op) {
2948 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2949 write!(self.out, "WaveActiveAllTrue(")?
2950 }
2951 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2952 write!(self.out, "WaveActiveAnyTrue(")?
2953 }
2954 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2955 write!(self.out, "WaveActiveSum(")?
2956 }
2957 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2958 write!(self.out, "WaveActiveProduct(")?
2959 }
2960 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2961 write!(self.out, "WaveActiveMax(")?
2962 }
2963 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2964 write!(self.out, "WaveActiveMin(")?
2965 }
2966 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2967 write!(self.out, "WaveActiveBitAnd(")?
2968 }
2969 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2970 write!(self.out, "WaveActiveBitOr(")?
2971 }
2972 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2973 write!(self.out, "WaveActiveBitXor(")?
2974 }
2975 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2976 write!(self.out, "WavePrefixSum(")?
2977 }
2978 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2979 write!(self.out, "WavePrefixProduct(")?
2980 }
2981 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2982 self.write_expr(module, argument, func_ctx)?;
2983 write!(self.out, " + WavePrefixSum(")?;
2984 }
2985 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2986 self.write_expr(module, argument, func_ctx)?;
2987 write!(self.out, " * WavePrefixProduct(")?;
2988 }
2989 _ => unimplemented!(),
2990 }
2991 self.write_expr(module, argument, func_ctx)?;
2992 writeln!(self.out, ");")?;
2993 }
2994 Statement::SubgroupGather {
2995 mode,
2996 argument,
2997 result,
2998 } => {
2999 write!(self.out, "{level}")?;
3000 write!(self.out, "const ")?;
3001 let name = Baked(result).to_string();
3002 match func_ctx.info[result].ty {
3003 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
3004 proc::TypeResolution::Value(ref value) => {
3005 self.write_value_type(module, value)?
3006 }
3007 };
3008 write!(self.out, " {name} = ")?;
3009 self.named_expressions.insert(result, name);
3010 match mode {
3011 crate::GatherMode::BroadcastFirst => {
3012 write!(self.out, "WaveReadLaneFirst(")?;
3013 self.write_expr(module, argument, func_ctx)?;
3014 }
3015 crate::GatherMode::QuadBroadcast(index) => {
3016 write!(self.out, "QuadReadLaneAt(")?;
3017 self.write_expr(module, argument, func_ctx)?;
3018 write!(self.out, ", ")?;
3019 self.write_expr(module, index, func_ctx)?;
3020 }
3021 crate::GatherMode::QuadSwap(direction) => {
3022 match direction {
3023 crate::Direction::X => {
3024 write!(self.out, "QuadReadAcrossX(")?;
3025 }
3026 crate::Direction::Y => {
3027 write!(self.out, "QuadReadAcrossY(")?;
3028 }
3029 crate::Direction::Diagonal => {
3030 write!(self.out, "QuadReadAcrossDiagonal(")?;
3031 }
3032 }
3033 self.write_expr(module, argument, func_ctx)?;
3034 }
3035 _ => {
3036 write!(self.out, "WaveReadLaneAt(")?;
3037 self.write_expr(module, argument, func_ctx)?;
3038 write!(self.out, ", ")?;
3039 match mode {
3040 crate::GatherMode::BroadcastFirst => unreachable!(),
3041 crate::GatherMode::Broadcast(index)
3042 | crate::GatherMode::Shuffle(index) => {
3043 self.write_expr(module, index, func_ctx)?;
3044 }
3045 crate::GatherMode::ShuffleDown(index) => {
3046 write!(self.out, "WaveGetLaneIndex() + ")?;
3047 self.write_expr(module, index, func_ctx)?;
3048 }
3049 crate::GatherMode::ShuffleUp(index) => {
3050 write!(self.out, "WaveGetLaneIndex() - ")?;
3051 self.write_expr(module, index, func_ctx)?;
3052 }
3053 crate::GatherMode::ShuffleXor(index) => {
3054 write!(self.out, "WaveGetLaneIndex() ^ ")?;
3055 self.write_expr(module, index, func_ctx)?;
3056 }
3057 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3058 crate::GatherMode::QuadSwap(_) => unreachable!(),
3059 }
3060 }
3061 }
3062 writeln!(self.out, ");")?;
3063 }
3064 Statement::CooperativeStore { .. } => unimplemented!(),
3065 Statement::RayPipelineFunction(_) => unreachable!(),
3066 }
3067
3068 Ok(())
3069 }
3070
3071 fn write_const_expression(
3072 &mut self,
3073 module: &Module,
3074 expr: Handle<crate::Expression>,
3075 arena: &crate::Arena<crate::Expression>,
3076 ) -> BackendResult {
3077 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3078 writer.write_const_expression(module, expr, arena)
3079 })
3080 }
3081
3082 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3083 match literal {
3084 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3085 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3086 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3087 crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3088 crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3089 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3090 crate::Literal::I32(value) if value == i32::MIN => {
3096 write!(self.out, "int({} - 1)", value + 1)?
3097 }
3098 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3102 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3103 crate::Literal::I64(value) if value == i64::MIN => {
3105 write!(self.out, "({}L - 1L)", value + 1)?;
3106 }
3107 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3108 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3109 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3110 return Err(Error::Custom(
3111 "Abstract types should not appear in IR presented to backends".into(),
3112 ));
3113 }
3114 }
3115 Ok(())
3116 }
3117
3118 fn write_possibly_const_expression<E>(
3119 &mut self,
3120 module: &Module,
3121 expr: Handle<crate::Expression>,
3122 expressions: &crate::Arena<crate::Expression>,
3123 write_expression: E,
3124 ) -> BackendResult
3125 where
3126 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3127 {
3128 use crate::Expression;
3129
3130 match expressions[expr] {
3131 Expression::Literal(literal) => {
3132 self.write_literal(literal)?;
3133 }
3134 Expression::Constant(handle) => {
3135 let constant = &module.constants[handle];
3136 if constant.name.is_some() {
3137 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3138 } else {
3139 self.write_const_expression(module, constant.init, &module.global_expressions)?;
3140 }
3141 }
3142 Expression::ZeroValue(ty) => {
3143 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3144 write!(self.out, "()")?;
3145 }
3146 Expression::Compose { ty, ref components } => {
3147 match module.types[ty].inner {
3148 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3149 self.write_wrapped_constructor_function_name(
3150 module,
3151 WrappedConstructor { ty },
3152 )?;
3153 }
3154 _ => {
3155 self.write_type(module, ty)?;
3156 }
3157 };
3158 write!(self.out, "(")?;
3159 for (index, component) in components.iter().enumerate() {
3160 if index != 0 {
3161 write!(self.out, ", ")?;
3162 }
3163 write_expression(self, *component)?;
3164 }
3165 write!(self.out, ")")?;
3166 }
3167 Expression::Splat { size, value } => {
3168 let number_of_components = match size {
3172 crate::VectorSize::Bi => "xx",
3173 crate::VectorSize::Tri => "xxx",
3174 crate::VectorSize::Quad => "xxxx",
3175 };
3176 write!(self.out, "(")?;
3177 write_expression(self, value)?;
3178 write!(self.out, ").{number_of_components}")?
3179 }
3180 _ => {
3181 return Err(Error::Override);
3182 }
3183 }
3184
3185 Ok(())
3186 }
3187
3188 pub(super) fn write_expr(
3193 &mut self,
3194 module: &Module,
3195 expr: Handle<crate::Expression>,
3196 func_ctx: &back::FunctionCtx<'_>,
3197 ) -> BackendResult {
3198 use crate::Expression;
3199
3200 let ff_input = if self.options.special_constants_binding.is_some() {
3202 func_ctx.is_fixed_function_input(expr, module)
3203 } else {
3204 None
3205 };
3206 let closing_bracket = match ff_input {
3207 Some(crate::BuiltIn::VertexIndex) => {
3208 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3209 ")"
3210 }
3211 Some(crate::BuiltIn::InstanceIndex) => {
3212 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3213 ")"
3214 }
3215 Some(crate::BuiltIn::NumWorkGroups) => {
3216 write!(
3220 self.out,
3221 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3222 )?;
3223 return Ok(());
3224 }
3225 _ => "",
3226 };
3227
3228 if let Some(name) = self.named_expressions.get(&expr) {
3229 write!(self.out, "{name}{closing_bracket}")?;
3230 return Ok(());
3231 }
3232
3233 let expression = &func_ctx.expressions[expr];
3234
3235 match *expression {
3236 Expression::Literal(_)
3237 | Expression::Constant(_)
3238 | Expression::ZeroValue(_)
3239 | Expression::Compose { .. }
3240 | Expression::Splat { .. } => {
3241 self.write_possibly_const_expression(
3242 module,
3243 expr,
3244 func_ctx.expressions,
3245 |writer, expr| writer.write_expr(module, expr, func_ctx),
3246 )?;
3247 }
3248 Expression::Override(_) => return Err(Error::Override),
3249 Expression::Binary {
3256 op:
3257 op @ crate::BinaryOperator::Add
3258 | op @ crate::BinaryOperator::Subtract
3259 | op @ crate::BinaryOperator::Multiply,
3260 left,
3261 right,
3262 } if matches!(
3263 func_ctx.resolve_type(expr, &module.types).scalar(),
3264 Some(Scalar::I32)
3265 ) =>
3266 {
3267 write!(self.out, "asint(asuint(",)?;
3268 self.write_expr(module, left, func_ctx)?;
3269 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3270 self.write_expr(module, right, func_ctx)?;
3271 write!(self.out, "))")?;
3272 }
3273 Expression::Binary {
3276 op: crate::BinaryOperator::Multiply,
3277 left,
3278 right,
3279 } if func_ctx.resolve_type(left, &module.types).is_matrix()
3280 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3281 {
3282 write!(self.out, "mul(")?;
3284 self.write_expr(module, right, func_ctx)?;
3285 write!(self.out, ", ")?;
3286 self.write_expr(module, left, func_ctx)?;
3287 write!(self.out, ")")?;
3288 }
3289
3290 Expression::Binary {
3302 op: crate::BinaryOperator::Divide,
3303 left,
3304 right,
3305 } if matches!(
3306 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3307 Some(ScalarKind::Sint | ScalarKind::Uint)
3308 ) =>
3309 {
3310 write!(self.out, "{DIV_FUNCTION}(")?;
3311 self.write_expr(module, left, func_ctx)?;
3312 write!(self.out, ", ")?;
3313 self.write_expr(module, right, func_ctx)?;
3314 write!(self.out, ")")?;
3315 }
3316
3317 Expression::Binary {
3318 op: crate::BinaryOperator::Modulo,
3319 left,
3320 right,
3321 } if matches!(
3322 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3323 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3324 ) =>
3325 {
3326 write!(self.out, "{MOD_FUNCTION}(")?;
3327 self.write_expr(module, left, func_ctx)?;
3328 write!(self.out, ", ")?;
3329 self.write_expr(module, right, func_ctx)?;
3330 write!(self.out, ")")?;
3331 }
3332
3333 Expression::Binary { op, left, right } => {
3334 write!(self.out, "(")?;
3335 self.write_expr(module, left, func_ctx)?;
3336 write!(self.out, " {} ", back::binary_operation_str(op))?;
3337 self.write_expr(module, right, func_ctx)?;
3338 write!(self.out, ")")?;
3339 }
3340 Expression::Access { base, index } => {
3341 if let Some(crate::AddressSpace::Storage { .. }) =
3342 func_ctx.resolve_type(expr, &module.types).pointer_space()
3343 {
3344 } else {
3346 if let Some(MatrixType {
3353 columns,
3354 rows: crate::VectorSize::Bi,
3355 width,
3356 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3357 .or_else(|| {
3358 get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3359 })
3360 {
3361 write!(
3362 self.out,
3363 "__get_col_of_mat{}x2_f{}(",
3364 columns as u8,
3365 width * 8
3366 )?;
3367 self.write_expr(module, base, func_ctx)?;
3368 write!(self.out, ", ")?;
3369 self.write_expr(module, index, func_ctx)?;
3370 write!(self.out, ")")?;
3371 return Ok(());
3372 }
3373
3374 let resolved = func_ctx.resolve_type(base, &module.types);
3375
3376 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3377 TypeInner::BindingArray { .. } => {
3378 let uniformity = &func_ctx.info[index].uniformity;
3379
3380 (true, uniformity.non_uniform_result.is_some())
3381 }
3382 _ => (false, false),
3383 };
3384
3385 self.write_expr(module, base, func_ctx)?;
3386
3387 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3388 module, func_ctx, base, resolved,
3389 );
3390
3391 if let Some(ref info) = array_sampler_info {
3392 write!(self.out, "{}[", info.sampler_heap_name)?;
3393 } else {
3394 write!(self.out, "[")?;
3395 }
3396
3397 let needs_bound_check = self.options.restrict_indexing
3398 && !indexing_binding_array
3399 && match resolved.pointer_space() {
3400 Some(
3401 crate::AddressSpace::Function
3402 | crate::AddressSpace::Private
3403 | crate::AddressSpace::WorkGroup
3404 | crate::AddressSpace::Immediate
3405 | crate::AddressSpace::TaskPayload
3406 | crate::AddressSpace::RayPayload
3407 | crate::AddressSpace::IncomingRayPayload,
3408 )
3409 | None => true,
3410 Some(crate::AddressSpace::Uniform) => {
3411 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3413 let bind_target = self
3414 .options
3415 .resolve_resource_binding(
3416 module.global_variables[var_handle]
3417 .binding
3418 .as_ref()
3419 .unwrap(),
3420 )
3421 .unwrap();
3422 bind_target.restrict_indexing
3423 }
3424 Some(
3425 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3426 ) => unreachable!(),
3427 };
3428 let restriction_needed = if needs_bound_check {
3430 index::access_needs_check(
3431 base,
3432 index::GuardedIndex::Expression(index),
3433 module,
3434 func_ctx.expressions,
3435 func_ctx.info,
3436 )
3437 } else {
3438 None
3439 };
3440 if let Some(limit) = restriction_needed {
3441 write!(self.out, "min(uint(")?;
3442 self.write_expr(module, index, func_ctx)?;
3443 write!(self.out, "), ")?;
3444 match limit {
3445 index::IndexableLength::Known(limit) => {
3446 write!(self.out, "{}u", limit - 1)?;
3447 }
3448 index::IndexableLength::Dynamic => unreachable!(),
3449 }
3450 write!(self.out, ")")?;
3451 } else {
3452 if non_uniform_qualifier {
3453 write!(self.out, "NonUniformResourceIndex(")?;
3454 }
3455 if let Some(ref info) = array_sampler_info {
3456 write!(
3457 self.out,
3458 "{}[{} + ",
3459 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3460 )?;
3461 }
3462 self.write_expr(module, index, func_ctx)?;
3463 if array_sampler_info.is_some() {
3464 write!(self.out, "]")?;
3465 }
3466 if non_uniform_qualifier {
3467 write!(self.out, ")")?;
3468 }
3469 }
3470
3471 write!(self.out, "]")?;
3472 }
3473 }
3474 Expression::AccessIndex { base, index } => {
3475 if let Some(crate::AddressSpace::Storage { .. }) =
3476 func_ctx.resolve_type(expr, &module.types).pointer_space()
3477 {
3478 } else {
3480 if let Some(MatrixType {
3484 rows: crate::VectorSize::Bi,
3485 ..
3486 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3487 .or_else(|| {
3488 get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3489 })
3490 {
3491 self.write_expr(module, base, func_ctx)?;
3492 write!(self.out, "._{index}")?;
3493 return Ok(());
3494 }
3495
3496 let base_ty_res = &func_ctx.info[base].ty;
3497 let mut resolved = base_ty_res.inner_with(&module.types);
3498 let base_ty_handle = match *resolved {
3499 TypeInner::Pointer { base, .. } => {
3500 resolved = &module.types[base].inner;
3501 Some(base)
3502 }
3503 _ => base_ty_res.handle(),
3504 };
3505
3506 if let TypeInner::Struct { ref members, .. } = *resolved {
3512 let member = &members[index as usize];
3513
3514 match module.types[member.ty].inner {
3515 TypeInner::Matrix {
3516 rows: crate::VectorSize::Bi,
3517 ..
3518 } if member.binding.is_none() => {
3519 let ty = base_ty_handle.unwrap();
3520 self.write_wrapped_struct_matrix_get_function_name(
3521 WrappedStructMatrixAccess { ty, index },
3522 )?;
3523 write!(self.out, "(")?;
3524 self.write_expr(module, base, func_ctx)?;
3525 write!(self.out, ")")?;
3526 return Ok(());
3527 }
3528 _ => {}
3529 }
3530 }
3531
3532 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3533 module, func_ctx, base, resolved,
3534 );
3535
3536 if let Some(ref info) = array_sampler_info {
3537 write!(
3538 self.out,
3539 "{}[{}",
3540 info.sampler_heap_name, info.sampler_index_buffer_name
3541 )?;
3542 }
3543
3544 self.write_expr(module, base, func_ctx)?;
3545
3546 match *resolved {
3547 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3553 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3555 }
3556 TypeInner::Matrix { .. }
3557 | TypeInner::Array { .. }
3558 | TypeInner::BindingArray { .. } => {
3559 if let Some(ref info) = array_sampler_info {
3560 write!(
3561 self.out,
3562 "[{} + {index}]",
3563 info.binding_array_base_index_name
3564 )?;
3565 } else {
3566 write!(self.out, "[{index}]")?;
3567 }
3568 }
3569 TypeInner::Struct { .. } => {
3570 let ty = base_ty_handle.unwrap();
3573
3574 write!(
3575 self.out,
3576 ".{}",
3577 &self.names[&NameKey::StructMember(ty, index)]
3578 )?
3579 }
3580 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3581 }
3582
3583 if array_sampler_info.is_some() {
3584 write!(self.out, "]")?;
3585 }
3586 }
3587 }
3588 Expression::FunctionArgument(pos) => {
3589 let ty = func_ctx.resolve_type(expr, &module.types);
3590
3591 if let TypeInner::Image {
3597 class: crate::ImageClass::External,
3598 ..
3599 } = *ty
3600 {
3601 let plane_names = [0, 1, 2].map(|i| {
3602 &self.names[&func_ctx
3603 .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3604 });
3605 let params_name = &self.names[&func_ctx
3606 .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3607 write!(
3608 self.out,
3609 "{}, {}, {}, {}",
3610 plane_names[0], plane_names[1], plane_names[2], params_name
3611 )?;
3612 } else {
3613 let key = func_ctx.argument_key(pos);
3614 let name = &self.names[&key];
3615 write!(self.out, "{name}")?;
3616 }
3617 }
3618 Expression::ImageSample {
3619 coordinate,
3620 image,
3621 sampler,
3622 clamp_to_edge: true,
3623 gather: None,
3624 array_index: None,
3625 offset: None,
3626 level: crate::SampleLevel::Zero,
3627 depth_ref: None,
3628 } => {
3629 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3630 self.write_expr(module, image, func_ctx)?;
3631 write!(self.out, ", ")?;
3632 self.write_expr(module, sampler, func_ctx)?;
3633 write!(self.out, ", ")?;
3634 self.write_expr(module, coordinate, func_ctx)?;
3635 write!(self.out, ")")?;
3636 }
3637 Expression::ImageSample {
3638 image,
3639 sampler,
3640 gather,
3641 coordinate,
3642 array_index,
3643 offset,
3644 level,
3645 depth_ref,
3646 clamp_to_edge,
3647 } => {
3648 if clamp_to_edge {
3649 return Err(Error::Custom(
3650 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3651 ));
3652 }
3653
3654 use crate::SampleLevel as Sl;
3655 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3656
3657 let (base_str, component_str) = match gather {
3658 Some(component) => ("Gather", COMPONENTS[component as usize]),
3659 None => ("Sample", ""),
3660 };
3661 let cmp_str = match depth_ref {
3662 Some(_) => "Cmp",
3663 None => "",
3664 };
3665 let level_str = match level {
3666 Sl::Zero if gather.is_none() => "LevelZero",
3667 Sl::Auto | Sl::Zero => "",
3668 Sl::Exact(_) => "Level",
3669 Sl::Bias(_) => "Bias",
3670 Sl::Gradient { .. } => "Grad",
3671 };
3672
3673 self.write_expr(module, image, func_ctx)?;
3674 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3675 self.write_expr(module, sampler, func_ctx)?;
3676 write!(self.out, ", ")?;
3677 self.write_texture_coordinates(
3678 "float",
3679 coordinate,
3680 array_index,
3681 None,
3682 module,
3683 func_ctx,
3684 )?;
3685
3686 if let Some(depth_ref) = depth_ref {
3687 write!(self.out, ", ")?;
3688 self.write_expr(module, depth_ref, func_ctx)?;
3689 }
3690
3691 match level {
3692 Sl::Auto | Sl::Zero => {}
3693 Sl::Exact(expr) => {
3694 write!(self.out, ", ")?;
3695 self.write_expr(module, expr, func_ctx)?;
3696 }
3697 Sl::Bias(expr) => {
3698 write!(self.out, ", ")?;
3699 self.write_expr(module, expr, func_ctx)?;
3700 }
3701 Sl::Gradient { x, y } => {
3702 write!(self.out, ", ")?;
3703 self.write_expr(module, x, func_ctx)?;
3704 write!(self.out, ", ")?;
3705 self.write_expr(module, y, func_ctx)?;
3706 }
3707 }
3708
3709 if let Some(offset) = offset {
3710 write!(self.out, ", ")?;
3711 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3713 write!(self.out, ")")?;
3714 }
3715
3716 write!(self.out, ")")?;
3717 }
3718 Expression::ImageQuery { image, query } => {
3719 if let TypeInner::Image {
3721 dim,
3722 arrayed,
3723 class,
3724 } = *func_ctx.resolve_type(image, &module.types)
3725 {
3726 let wrapped_image_query = WrappedImageQuery {
3727 dim,
3728 arrayed,
3729 class,
3730 query: query.into(),
3731 };
3732
3733 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3734 write!(self.out, "(")?;
3735 self.write_expr(module, image, func_ctx)?;
3737 if let crate::ImageQuery::Size { level: Some(level) } = query {
3738 write!(self.out, ", ")?;
3739 self.write_expr(module, level, func_ctx)?;
3740 }
3741 write!(self.out, ")")?;
3742 }
3743 }
3744 Expression::ImageLoad {
3745 image,
3746 coordinate,
3747 array_index,
3748 sample,
3749 level,
3750 } => self.write_image_load(
3751 &module,
3752 expr,
3753 func_ctx,
3754 image,
3755 coordinate,
3756 array_index,
3757 sample,
3758 level,
3759 )?,
3760 Expression::GlobalVariable(handle) => {
3761 let global_variable = &module.global_variables[handle];
3762 let ty = &module.types[global_variable.ty].inner;
3763
3764 let is_binding_array_of_samplers = match *ty {
3769 TypeInner::BindingArray { base, .. } => {
3770 let base_ty = &module.types[base].inner;
3771 matches!(*base_ty, TypeInner::Sampler { .. })
3772 }
3773 _ => false,
3774 };
3775
3776 let is_storage_space =
3777 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3778
3779 if let TypeInner::Image {
3787 class: crate::ImageClass::External,
3788 ..
3789 } = *ty
3790 {
3791 let plane_names = [0, 1, 2].map(|i| {
3792 &self.names[&NameKey::ExternalTextureGlobalVariable(
3793 handle,
3794 ExternalTextureNameKey::Plane(i),
3795 )]
3796 });
3797 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3798 handle,
3799 ExternalTextureNameKey::Params,
3800 )];
3801 write!(
3802 self.out,
3803 "{}, {}, {}, {}",
3804 plane_names[0], plane_names[1], plane_names[2], params_name
3805 )?;
3806 } else if !is_binding_array_of_samplers && !is_storage_space {
3807 let name = &self.names[&NameKey::GlobalVariable(handle)];
3808 write!(self.out, "{name}")?;
3809 }
3810 }
3811 Expression::LocalVariable(handle) => {
3812 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3813 }
3814 Expression::Load { pointer } => {
3815 match func_ctx
3816 .resolve_type(pointer, &module.types)
3817 .pointer_space()
3818 {
3819 Some(crate::AddressSpace::Storage { .. }) => {
3820 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3821 let result_ty = func_ctx.info[expr].ty.clone();
3822 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3823 }
3824 _ => {
3825 let mut close_paren = false;
3826
3827 if let Some(MatrixType {
3832 rows: crate::VectorSize::Bi,
3833 ..
3834 }) = get_inner_matrix_of_struct_array_member(
3835 module, pointer, func_ctx, false,
3836 )
3837 .or_else(|| {
3838 get_inner_matrix_of_global_uniform(module, pointer, func_ctx, false)
3839 }) {
3840 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3841 let ptr_tr = resolved.pointer_base_type();
3842 if let Some(ptr_ty) =
3843 ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3844 {
3845 resolved = ptr_ty;
3846 }
3847
3848 write!(self.out, "((")?;
3849 if let TypeInner::Array { base, size, .. } = *resolved {
3850 self.write_type(module, base)?;
3851 self.write_array_size(module, base, size)?;
3852 } else {
3853 self.write_value_type(module, resolved)?;
3854 }
3855 write!(self.out, ")")?;
3856 close_paren = true;
3857 }
3858
3859 self.write_expr(module, pointer, func_ctx)?;
3860
3861 if close_paren {
3862 write!(self.out, ")")?;
3863 }
3864 }
3865 }
3866 }
3867 Expression::Unary { op, expr } => {
3868 let op_str = match op {
3870 crate::UnaryOperator::Negate => {
3871 match func_ctx.resolve_type(expr, &module.types).scalar() {
3872 Some(Scalar::I32) => NEG_FUNCTION,
3873 _ => "-",
3874 }
3875 }
3876 crate::UnaryOperator::LogicalNot => "!",
3877 crate::UnaryOperator::BitwiseNot => "~",
3878 };
3879 write!(self.out, "{op_str}(")?;
3880 self.write_expr(module, expr, func_ctx)?;
3881 write!(self.out, ")")?;
3882 }
3883 Expression::As {
3884 expr,
3885 kind,
3886 convert,
3887 } => {
3888 let inner = func_ctx.resolve_type(expr, &module.types);
3889 if inner.scalar_kind() == Some(ScalarKind::Float)
3890 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3891 && convert.is_some()
3892 && matches!(convert, Some(4) | Some(8))
3893 {
3894 let fun_name = match (kind, convert) {
3898 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3899 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3900 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3901 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3902 _ => unreachable!(),
3903 };
3904 write!(self.out, "{fun_name}(")?;
3905 self.write_expr(module, expr, func_ctx)?;
3906 write!(self.out, ")")?;
3907 } else {
3908 let close_paren = match convert {
3909 Some(dst_width) => {
3910 let scalar = Scalar {
3911 kind,
3912 width: dst_width,
3913 };
3914 match *inner {
3915 TypeInner::Vector { size, .. } => {
3916 write!(
3917 self.out,
3918 "{}{}(",
3919 scalar.to_hlsl_str()?,
3920 common::vector_size_str(size)
3921 )?;
3922 }
3923 TypeInner::Scalar(_) => {
3924 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3925 }
3926 TypeInner::Matrix { columns, rows, .. } => {
3927 write!(
3928 self.out,
3929 "{}{}x{}(",
3930 scalar.to_hlsl_str()?,
3931 common::vector_size_str(columns),
3932 common::vector_size_str(rows)
3933 )?;
3934 }
3935 _ => {
3936 return Err(Error::Unimplemented(format!(
3937 "write_expr expression::as {inner:?}"
3938 )));
3939 }
3940 };
3941 true
3942 }
3943 None => {
3944 if inner.scalar_width() == Some(8) {
3945 false
3946 } else if inner.scalar_width() == Some(2) {
3947 let dst_scalar = Scalar { kind, width: 2 };
3950 match *inner {
3951 TypeInner::Vector { size, .. } => {
3952 write!(
3953 self.out,
3954 "{}{}(",
3955 dst_scalar.to_hlsl_str()?,
3956 common::vector_size_str(size)
3957 )?;
3958 }
3959 _ => {
3960 write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
3961 }
3962 };
3963 true
3964 } else {
3965 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3966 true
3967 }
3968 }
3969 };
3970 self.write_expr(module, expr, func_ctx)?;
3971 if close_paren {
3972 write!(self.out, ")")?;
3973 }
3974 }
3975 }
3976 Expression::Math {
3977 fun,
3978 arg,
3979 arg1,
3980 arg2,
3981 arg3,
3982 } => {
3983 use crate::MathFunction as Mf;
3984
3985 enum Function {
3986 Asincosh { is_sin: bool },
3987 Atanh,
3988 Pack2x16float,
3989 Pack2x16snorm,
3990 Pack2x16unorm,
3991 Pack4x8snorm,
3992 Pack4x8unorm,
3993 Pack4xI8,
3994 Pack4xU8,
3995 Pack4xI8Clamp,
3996 Pack4xU8Clamp,
3997 Unpack2x16float,
3998 Unpack2x16snorm,
3999 Unpack2x16unorm,
4000 Unpack4x8snorm,
4001 Unpack4x8unorm,
4002 Unpack4xI8,
4003 Unpack4xU8,
4004 Dot4I8Packed,
4005 Dot4U8Packed,
4006 QuantizeToF16,
4007 Regular(&'static str),
4008 MissingIntOverload(&'static str),
4009 MissingIntReturnType(&'static str),
4010 CountTrailingZeros,
4011 CountLeadingZeros,
4012 }
4013
4014 let fun = match fun {
4015 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
4017 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
4018 _ => Function::Regular("abs"),
4019 },
4020 Mf::Min => Function::Regular("min"),
4021 Mf::Max => Function::Regular("max"),
4022 Mf::Clamp => Function::Regular("clamp"),
4023 Mf::Saturate => Function::Regular("saturate"),
4024 Mf::Cos => Function::Regular("cos"),
4026 Mf::Cosh => Function::Regular("cosh"),
4027 Mf::Sin => Function::Regular("sin"),
4028 Mf::Sinh => Function::Regular("sinh"),
4029 Mf::Tan => Function::Regular("tan"),
4030 Mf::Tanh => Function::Regular("tanh"),
4031 Mf::Acos => Function::Regular("acos"),
4032 Mf::Asin => Function::Regular("asin"),
4033 Mf::Atan => Function::Regular("atan"),
4034 Mf::Atan2 => Function::Regular("atan2"),
4035 Mf::Asinh => Function::Asincosh { is_sin: true },
4036 Mf::Acosh => Function::Asincosh { is_sin: false },
4037 Mf::Atanh => Function::Atanh,
4038 Mf::Radians => Function::Regular("radians"),
4039 Mf::Degrees => Function::Regular("degrees"),
4040 Mf::Ceil => Function::Regular("ceil"),
4042 Mf::Floor => Function::Regular("floor"),
4043 Mf::Round => Function::Regular("round"),
4044 Mf::Fract => Function::Regular("frac"),
4045 Mf::Trunc => Function::Regular("trunc"),
4046 Mf::Modf => Function::Regular(MODF_FUNCTION),
4047 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4048 Mf::Ldexp => Function::Regular("ldexp"),
4049 Mf::Exp => Function::Regular("exp"),
4051 Mf::Exp2 => Function::Regular("exp2"),
4052 Mf::Log => Function::Regular("log"),
4053 Mf::Log2 => Function::Regular("log2"),
4054 Mf::Pow => Function::Regular("pow"),
4055 Mf::Dot => Function::Regular("dot"),
4057 Mf::Dot4I8Packed => Function::Dot4I8Packed,
4058 Mf::Dot4U8Packed => Function::Dot4U8Packed,
4059 Mf::Cross => Function::Regular("cross"),
4061 Mf::Distance => Function::Regular("distance"),
4062 Mf::Length => Function::Regular("length"),
4063 Mf::Normalize => Function::Regular("normalize"),
4064 Mf::FaceForward => Function::Regular("faceforward"),
4065 Mf::Reflect => Function::Regular("reflect"),
4066 Mf::Refract => Function::Regular("refract"),
4067 Mf::Sign => Function::Regular("sign"),
4069 Mf::Fma => Function::Regular("mad"),
4070 Mf::Mix => Function::Regular("lerp"),
4071 Mf::Step => Function::Regular("step"),
4072 Mf::SmoothStep => Function::Regular("smoothstep"),
4073 Mf::Sqrt => Function::Regular("sqrt"),
4074 Mf::InverseSqrt => Function::Regular("rsqrt"),
4075 Mf::Transpose => Function::Regular("transpose"),
4077 Mf::Determinant => Function::Regular("determinant"),
4078 Mf::QuantizeToF16 => Function::QuantizeToF16,
4079 Mf::CountTrailingZeros => Function::CountTrailingZeros,
4081 Mf::CountLeadingZeros => Function::CountLeadingZeros,
4082 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4083 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4084 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4085 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4086 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4087 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4088 Mf::Pack2x16float => Function::Pack2x16float,
4090 Mf::Pack2x16snorm => Function::Pack2x16snorm,
4091 Mf::Pack2x16unorm => Function::Pack2x16unorm,
4092 Mf::Pack4x8snorm => Function::Pack4x8snorm,
4093 Mf::Pack4x8unorm => Function::Pack4x8unorm,
4094 Mf::Pack4xI8 => Function::Pack4xI8,
4095 Mf::Pack4xU8 => Function::Pack4xU8,
4096 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4097 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4098 Mf::Unpack2x16float => Function::Unpack2x16float,
4100 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4101 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4102 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4103 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4104 Mf::Unpack4xI8 => Function::Unpack4xI8,
4105 Mf::Unpack4xU8 => Function::Unpack4xU8,
4106 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4107 };
4108
4109 match fun {
4110 Function::Asincosh { is_sin } => {
4111 write!(self.out, "log(")?;
4112 self.write_expr(module, arg, func_ctx)?;
4113 write!(self.out, " + sqrt(")?;
4114 self.write_expr(module, arg, func_ctx)?;
4115 write!(self.out, " * ")?;
4116 self.write_expr(module, arg, func_ctx)?;
4117 match is_sin {
4118 true => write!(self.out, " + 1.0))")?,
4119 false => write!(self.out, " - 1.0))")?,
4120 }
4121 }
4122 Function::Atanh => {
4123 write!(self.out, "0.5 * log((1.0 + ")?;
4124 self.write_expr(module, arg, func_ctx)?;
4125 write!(self.out, ") / (1.0 - ")?;
4126 self.write_expr(module, arg, func_ctx)?;
4127 write!(self.out, "))")?;
4128 }
4129 Function::Pack2x16float => {
4130 write!(self.out, "(f32tof16(")?;
4131 self.write_expr(module, arg, func_ctx)?;
4132 write!(self.out, "[0]) | f32tof16(")?;
4133 self.write_expr(module, arg, func_ctx)?;
4134 write!(self.out, "[1]) << 16)")?;
4135 }
4136 Function::Pack2x16snorm => {
4137 let scale = 32767;
4138
4139 write!(self.out, "uint((int(round(clamp(")?;
4140 self.write_expr(module, arg, func_ctx)?;
4141 write!(
4142 self.out,
4143 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4144 )?;
4145 self.write_expr(module, arg, func_ctx)?;
4146 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4147 }
4148 Function::Pack2x16unorm => {
4149 let scale = 65535;
4150
4151 write!(self.out, "(uint(round(clamp(")?;
4152 self.write_expr(module, arg, func_ctx)?;
4153 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4154 self.write_expr(module, arg, func_ctx)?;
4155 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4156 }
4157 Function::Pack4x8snorm => {
4158 let scale = 127;
4159
4160 write!(self.out, "uint((int(round(clamp(")?;
4161 self.write_expr(module, arg, func_ctx)?;
4162 write!(
4163 self.out,
4164 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4165 )?;
4166 self.write_expr(module, arg, func_ctx)?;
4167 write!(
4168 self.out,
4169 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4170 )?;
4171 self.write_expr(module, arg, func_ctx)?;
4172 write!(
4173 self.out,
4174 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4175 )?;
4176 self.write_expr(module, arg, func_ctx)?;
4177 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4178 }
4179 Function::Pack4x8unorm => {
4180 let scale = 255;
4181
4182 write!(self.out, "(uint(round(clamp(")?;
4183 self.write_expr(module, arg, func_ctx)?;
4184 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4185 self.write_expr(module, arg, func_ctx)?;
4186 write!(
4187 self.out,
4188 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4189 )?;
4190 self.write_expr(module, arg, func_ctx)?;
4191 write!(
4192 self.out,
4193 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4194 )?;
4195 self.write_expr(module, arg, func_ctx)?;
4196 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4197 }
4198 fun @ (Function::Pack4xI8
4199 | Function::Pack4xU8
4200 | Function::Pack4xI8Clamp
4201 | Function::Pack4xU8Clamp) => {
4202 let was_signed =
4203 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4204 let clamp_bounds = match fun {
4205 Function::Pack4xI8Clamp => Some(("-128", "127")),
4206 Function::Pack4xU8Clamp => Some(("0", "255")),
4207 _ => None,
4208 };
4209 if was_signed {
4210 write!(self.out, "uint(")?;
4211 }
4212 let write_arg = |this: &mut Self| -> BackendResult {
4213 if let Some((min, max)) = clamp_bounds {
4214 write!(this.out, "clamp(")?;
4215 this.write_expr(module, arg, func_ctx)?;
4216 write!(this.out, ", {min}, {max})")?;
4217 } else {
4218 this.write_expr(module, arg, func_ctx)?;
4219 }
4220 Ok(())
4221 };
4222 write!(self.out, "(")?;
4223 write_arg(self)?;
4224 write!(self.out, "[0] & 0xFF) | ((")?;
4225 write_arg(self)?;
4226 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4227 write_arg(self)?;
4228 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4229 write_arg(self)?;
4230 write!(self.out, "[3] & 0xFF) << 24)")?;
4231 if was_signed {
4232 write!(self.out, ")")?;
4233 }
4234 }
4235
4236 Function::Unpack2x16float => {
4237 write!(self.out, "float2(f16tof32(")?;
4238 self.write_expr(module, arg, func_ctx)?;
4239 write!(self.out, "), f16tof32((")?;
4240 self.write_expr(module, arg, func_ctx)?;
4241 write!(self.out, ") >> 16))")?;
4242 }
4243 Function::Unpack2x16snorm => {
4244 let scale = 32767;
4245
4246 write!(self.out, "(float2(int2(")?;
4247 self.write_expr(module, arg, func_ctx)?;
4248 write!(self.out, " << 16, ")?;
4249 self.write_expr(module, arg, func_ctx)?;
4250 write!(self.out, ") >> 16) / {scale}.0)")?;
4251 }
4252 Function::Unpack2x16unorm => {
4253 let scale = 65535;
4254
4255 write!(self.out, "(float2(")?;
4256 self.write_expr(module, arg, func_ctx)?;
4257 write!(self.out, " & 0xFFFF, ")?;
4258 self.write_expr(module, arg, func_ctx)?;
4259 write!(self.out, " >> 16) / {scale}.0)")?;
4260 }
4261 Function::Unpack4x8snorm => {
4262 let scale = 127;
4263
4264 write!(self.out, "(float4(int4(")?;
4265 self.write_expr(module, arg, func_ctx)?;
4266 write!(self.out, " << 24, ")?;
4267 self.write_expr(module, arg, func_ctx)?;
4268 write!(self.out, " << 16, ")?;
4269 self.write_expr(module, arg, func_ctx)?;
4270 write!(self.out, " << 8, ")?;
4271 self.write_expr(module, arg, func_ctx)?;
4272 write!(self.out, ") >> 24) / {scale}.0)")?;
4273 }
4274 Function::Unpack4x8unorm => {
4275 let scale = 255;
4276
4277 write!(self.out, "(float4(")?;
4278 self.write_expr(module, arg, func_ctx)?;
4279 write!(self.out, " & 0xFF, ")?;
4280 self.write_expr(module, arg, func_ctx)?;
4281 write!(self.out, " >> 8 & 0xFF, ")?;
4282 self.write_expr(module, arg, func_ctx)?;
4283 write!(self.out, " >> 16 & 0xFF, ")?;
4284 self.write_expr(module, arg, func_ctx)?;
4285 write!(self.out, " >> 24) / {scale}.0)")?;
4286 }
4287 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4288 write!(self.out, "(")?;
4289 if matches!(fun, Function::Unpack4xU8) {
4290 write!(self.out, "u")?;
4291 }
4292 write!(self.out, "int4(")?;
4293 self.write_expr(module, arg, func_ctx)?;
4294 write!(self.out, ", ")?;
4295 self.write_expr(module, arg, func_ctx)?;
4296 write!(self.out, " >> 8, ")?;
4297 self.write_expr(module, arg, func_ctx)?;
4298 write!(self.out, " >> 16, ")?;
4299 self.write_expr(module, arg, func_ctx)?;
4300 write!(self.out, " >> 24) << 24 >> 24)")?;
4301 }
4302 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4303 let arg1 = arg1.unwrap();
4304
4305 if self.options.shader_model >= ShaderModel::V6_4 {
4306 let function_name = match fun {
4308 Function::Dot4I8Packed => "dot4add_i8packed",
4309 Function::Dot4U8Packed => "dot4add_u8packed",
4310 _ => unreachable!(),
4311 };
4312 write!(self.out, "{function_name}(")?;
4313 self.write_expr(module, arg, func_ctx)?;
4314 write!(self.out, ", ")?;
4315 self.write_expr(module, arg1, func_ctx)?;
4316 write!(self.out, ", 0)")?;
4317 } else {
4318 write!(self.out, "dot(")?;
4320
4321 if matches!(fun, Function::Dot4U8Packed) {
4322 write!(self.out, "u")?;
4323 }
4324 write!(self.out, "int4(")?;
4325 self.write_expr(module, arg, func_ctx)?;
4326 write!(self.out, ", ")?;
4327 self.write_expr(module, arg, func_ctx)?;
4328 write!(self.out, " >> 8, ")?;
4329 self.write_expr(module, arg, func_ctx)?;
4330 write!(self.out, " >> 16, ")?;
4331 self.write_expr(module, arg, func_ctx)?;
4332 write!(self.out, " >> 24) << 24 >> 24, ")?;
4333
4334 if matches!(fun, Function::Dot4U8Packed) {
4335 write!(self.out, "u")?;
4336 }
4337 write!(self.out, "int4(")?;
4338 self.write_expr(module, arg1, func_ctx)?;
4339 write!(self.out, ", ")?;
4340 self.write_expr(module, arg1, func_ctx)?;
4341 write!(self.out, " >> 8, ")?;
4342 self.write_expr(module, arg1, func_ctx)?;
4343 write!(self.out, " >> 16, ")?;
4344 self.write_expr(module, arg1, func_ctx)?;
4345 write!(self.out, " >> 24) << 24 >> 24)")?;
4346 }
4347 }
4348 Function::QuantizeToF16 => {
4349 write!(self.out, "f16tof32(f32tof16(")?;
4350 self.write_expr(module, arg, func_ctx)?;
4351 write!(self.out, "))")?;
4352 }
4353 Function::Regular(fun_name) => {
4354 write!(self.out, "{fun_name}(")?;
4355 self.write_expr(module, arg, func_ctx)?;
4356 if let Some(arg) = arg1 {
4357 write!(self.out, ", ")?;
4358 self.write_expr(module, arg, func_ctx)?;
4359 }
4360 if let Some(arg) = arg2 {
4361 write!(self.out, ", ")?;
4362 self.write_expr(module, arg, func_ctx)?;
4363 }
4364 if let Some(arg) = arg3 {
4365 write!(self.out, ", ")?;
4366 self.write_expr(module, arg, func_ctx)?;
4367 }
4368 write!(self.out, ")")?
4369 }
4370 Function::MissingIntOverload(fun_name) => {
4373 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4374 if let Some(Scalar::I32) = scalar_kind {
4375 write!(self.out, "asint({fun_name}(asuint(")?;
4376 self.write_expr(module, arg, func_ctx)?;
4377 write!(self.out, ")))")?;
4378 } else {
4379 write!(self.out, "{fun_name}(")?;
4380 self.write_expr(module, arg, func_ctx)?;
4381 write!(self.out, ")")?;
4382 }
4383 }
4384 Function::MissingIntReturnType(fun_name) => {
4387 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4388 if let Some(Scalar::I32) = scalar_kind {
4389 write!(self.out, "asint({fun_name}(")?;
4390 self.write_expr(module, arg, func_ctx)?;
4391 write!(self.out, "))")?;
4392 } else {
4393 write!(self.out, "{fun_name}(")?;
4394 self.write_expr(module, arg, func_ctx)?;
4395 write!(self.out, ")")?;
4396 }
4397 }
4398 Function::CountTrailingZeros => {
4399 match *func_ctx.resolve_type(arg, &module.types) {
4400 TypeInner::Vector { size, scalar } => {
4401 let s = match size {
4402 crate::VectorSize::Bi => ".xx",
4403 crate::VectorSize::Tri => ".xxx",
4404 crate::VectorSize::Quad => ".xxxx",
4405 };
4406
4407 let scalar_width_bits = scalar.width * 8;
4408
4409 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4410 write!(
4411 self.out,
4412 "min(({scalar_width_bits}u){s}, firstbitlow("
4413 )?;
4414 self.write_expr(module, arg, func_ctx)?;
4415 write!(self.out, "))")?;
4416 } else {
4417 write!(
4419 self.out,
4420 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4421 )?;
4422 self.write_expr(module, arg, func_ctx)?;
4423 write!(self.out, ")))")?;
4424 }
4425 }
4426 TypeInner::Scalar(scalar) => {
4427 let scalar_width_bits = scalar.width * 8;
4428
4429 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4430 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4431 self.write_expr(module, arg, func_ctx)?;
4432 write!(self.out, "))")?;
4433 } else {
4434 write!(
4436 self.out,
4437 "asint(min({scalar_width_bits}u, firstbitlow("
4438 )?;
4439 self.write_expr(module, arg, func_ctx)?;
4440 write!(self.out, ")))")?;
4441 }
4442 }
4443 _ => unreachable!(),
4444 }
4445
4446 return Ok(());
4447 }
4448 Function::CountLeadingZeros => {
4449 match *func_ctx.resolve_type(arg, &module.types) {
4450 TypeInner::Vector { size, scalar } => {
4451 let s = match size {
4452 crate::VectorSize::Bi => ".xx",
4453 crate::VectorSize::Tri => ".xxx",
4454 crate::VectorSize::Quad => ".xxxx",
4455 };
4456
4457 let constant = scalar.width * 8 - 1;
4459
4460 if scalar.kind == ScalarKind::Uint {
4461 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4462 self.write_expr(module, arg, func_ctx)?;
4463 write!(self.out, "))")?;
4464 } else {
4465 let conversion_func = match scalar.width {
4466 4 => "asint",
4467 _ => "",
4468 };
4469 write!(self.out, "(")?;
4470 self.write_expr(module, arg, func_ctx)?;
4471 write!(
4472 self.out,
4473 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4474 )?;
4475 self.write_expr(module, arg, func_ctx)?;
4476 write!(self.out, ")))")?;
4477 }
4478 }
4479 TypeInner::Scalar(scalar) => {
4480 let constant = scalar.width * 8 - 1;
4482
4483 if let ScalarKind::Uint = scalar.kind {
4484 write!(self.out, "({constant}u - firstbithigh(")?;
4485 self.write_expr(module, arg, func_ctx)?;
4486 write!(self.out, "))")?;
4487 } else {
4488 let conversion_func = match scalar.width {
4489 4 => "asint",
4490 _ => "",
4491 };
4492 write!(self.out, "(")?;
4493 self.write_expr(module, arg, func_ctx)?;
4494 write!(
4495 self.out,
4496 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4497 )?;
4498 self.write_expr(module, arg, func_ctx)?;
4499 write!(self.out, ")))")?;
4500 }
4501 }
4502 _ => unreachable!(),
4503 }
4504
4505 return Ok(());
4506 }
4507 }
4508 }
4509 Expression::Swizzle {
4510 size,
4511 vector,
4512 pattern,
4513 } => {
4514 self.write_expr(module, vector, func_ctx)?;
4515 write!(self.out, ".")?;
4516 for &sc in pattern[..size as usize].iter() {
4517 self.out.write_char(back::COMPONENTS[sc as usize])?;
4518 }
4519 }
4520 Expression::ArrayLength(expr) => {
4521 let var_handle = match func_ctx.expressions[expr] {
4522 Expression::AccessIndex { base, index: _ } => {
4523 match func_ctx.expressions[base] {
4524 Expression::GlobalVariable(handle) => handle,
4525 _ => unreachable!(),
4526 }
4527 }
4528 Expression::GlobalVariable(handle) => handle,
4529 _ => unreachable!(),
4530 };
4531
4532 let var = &module.global_variables[var_handle];
4533 let (offset, stride) = match module.types[var.ty].inner {
4534 TypeInner::Array { stride, .. } => (0, stride),
4535 TypeInner::Struct { ref members, .. } => {
4536 let last = members.last().unwrap();
4537 let stride = match module.types[last.ty].inner {
4538 TypeInner::Array { stride, .. } => stride,
4539 _ => unreachable!(),
4540 };
4541 (last.offset, stride)
4542 }
4543 _ => unreachable!(),
4544 };
4545
4546 let storage_access = match var.space {
4547 crate::AddressSpace::Storage { access } => access,
4548 _ => crate::StorageAccess::default(),
4549 };
4550 let wrapped_array_length = WrappedArrayLength {
4551 writable: storage_access.contains(crate::StorageAccess::STORE),
4552 };
4553
4554 write!(self.out, "((")?;
4555 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4556 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4557 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4558 }
4559 Expression::Derivative { axis, ctrl, expr } => {
4560 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4561 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4562 let tail = match ctrl {
4563 Ctrl::Coarse => "coarse",
4564 Ctrl::Fine => "fine",
4565 Ctrl::None => unreachable!(),
4566 };
4567 write!(self.out, "abs(ddx_{tail}(")?;
4568 self.write_expr(module, expr, func_ctx)?;
4569 write!(self.out, ")) + abs(ddy_{tail}(")?;
4570 self.write_expr(module, expr, func_ctx)?;
4571 write!(self.out, "))")?
4572 } else {
4573 let fun_str = match (axis, ctrl) {
4574 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4575 (Axis::X, Ctrl::Fine) => "ddx_fine",
4576 (Axis::X, Ctrl::None) => "ddx",
4577 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4578 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4579 (Axis::Y, Ctrl::None) => "ddy",
4580 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4581 (Axis::Width, Ctrl::None) => "fwidth",
4582 };
4583 write!(self.out, "{fun_str}(")?;
4584 self.write_expr(module, expr, func_ctx)?;
4585 write!(self.out, ")")?
4586 }
4587 }
4588 Expression::Relational { fun, argument } => {
4589 use crate::RelationalFunction as Rf;
4590
4591 let fun_str = match fun {
4592 Rf::All => "all",
4593 Rf::Any => "any",
4594 Rf::IsNan => "isnan",
4595 Rf::IsInf => "isinf",
4596 };
4597 write!(self.out, "{fun_str}(")?;
4598 self.write_expr(module, argument, func_ctx)?;
4599 write!(self.out, ")")?
4600 }
4601 Expression::Select {
4602 condition,
4603 accept,
4604 reject,
4605 } => {
4606 write!(self.out, "(")?;
4607 self.write_expr(module, condition, func_ctx)?;
4608 write!(self.out, " ? ")?;
4609 self.write_expr(module, accept, func_ctx)?;
4610 write!(self.out, " : ")?;
4611 self.write_expr(module, reject, func_ctx)?;
4612 write!(self.out, ")")?
4613 }
4614 Expression::RayQueryGetIntersection { query, committed } => {
4615 let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4617 unreachable!()
4618 };
4619
4620 let tracker_expr_name = format!(
4621 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4622 self.names[&func_ctx.name_key(query_var)]
4623 );
4624
4625 if committed {
4626 write!(self.out, "GetCommittedIntersection(")?;
4627 self.write_expr(module, query, func_ctx)?;
4628 write!(self.out, ", {tracker_expr_name})")?;
4629 } else {
4630 write!(self.out, "GetCandidateIntersection(")?;
4631 self.write_expr(module, query, func_ctx)?;
4632 write!(self.out, ", {tracker_expr_name})")?;
4633 }
4634 }
4635 Expression::RayQueryVertexPositions { .. }
4637 | Expression::CooperativeLoad { .. }
4638 | Expression::CooperativeMultiplyAdd { .. } => {
4639 unreachable!()
4640 }
4641 Expression::CallResult(_)
4643 | Expression::AtomicResult { .. }
4644 | Expression::WorkGroupUniformLoadResult { .. }
4645 | Expression::RayQueryProceedResult
4646 | Expression::SubgroupBallotResult
4647 | Expression::SubgroupOperationResult { .. } => {}
4648 }
4649
4650 if !closing_bracket.is_empty() {
4651 write!(self.out, "{closing_bracket}")?;
4652 }
4653 Ok(())
4654 }
4655
4656 #[allow(clippy::too_many_arguments)]
4657 fn write_image_load(
4658 &mut self,
4659 module: &&Module,
4660 expr: Handle<crate::Expression>,
4661 func_ctx: &back::FunctionCtx,
4662 image: Handle<crate::Expression>,
4663 coordinate: Handle<crate::Expression>,
4664 array_index: Option<Handle<crate::Expression>>,
4665 sample: Option<Handle<crate::Expression>>,
4666 level: Option<Handle<crate::Expression>>,
4667 ) -> Result<(), Error> {
4668 let mut wrapping_type = None;
4669 match *func_ctx.resolve_type(image, &module.types) {
4670 TypeInner::Image {
4671 class: crate::ImageClass::External,
4672 ..
4673 } => {
4674 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4675 self.write_expr(module, image, func_ctx)?;
4676 write!(self.out, ", ")?;
4677 self.write_expr(module, coordinate, func_ctx)?;
4678 write!(self.out, ")")?;
4679 return Ok(());
4680 }
4681 TypeInner::Image {
4682 class: crate::ImageClass::Storage { format, .. },
4683 ..
4684 } => {
4685 if format.single_component() {
4686 wrapping_type = Some(Scalar::from(format));
4687 }
4688 }
4689 _ => {}
4690 }
4691 if let Some(scalar) = wrapping_type {
4692 write!(
4693 self.out,
4694 "{}{}(",
4695 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4696 scalar.to_hlsl_str()?
4697 )?;
4698 }
4699 self.write_expr(module, image, func_ctx)?;
4701 write!(self.out, ".Load(")?;
4702
4703 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4704
4705 if let Some(sample) = sample {
4706 write!(self.out, ", ")?;
4707 self.write_expr(module, sample, func_ctx)?;
4708 }
4709
4710 write!(self.out, ")")?;
4712
4713 if wrapping_type.is_some() {
4714 write!(self.out, ")")?;
4715 }
4716
4717 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4719 write!(self.out, ".x")?;
4720 }
4721 Ok(())
4722 }
4723
4724 fn sampler_binding_array_info_from_expression(
4727 &mut self,
4728 module: &Module,
4729 func_ctx: &back::FunctionCtx<'_>,
4730 base: Handle<crate::Expression>,
4731 resolved: &TypeInner,
4732 ) -> Option<BindingArraySamplerInfo> {
4733 if let TypeInner::BindingArray {
4734 base: base_ty_handle,
4735 ..
4736 } = *resolved
4737 {
4738 let base_ty = &module.types[base_ty_handle].inner;
4739 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4740 let base = &func_ctx.expressions[base];
4741
4742 if let crate::Expression::GlobalVariable(handle) = *base {
4743 let variable = &module.global_variables[handle];
4744
4745 let sampler_heap_name = match comparison {
4746 true => COMPARISON_SAMPLER_HEAP_VAR,
4747 false => SAMPLER_HEAP_VAR,
4748 };
4749
4750 return Some(BindingArraySamplerInfo {
4751 sampler_heap_name,
4752 sampler_index_buffer_name: self
4753 .wrapped
4754 .sampler_index_buffers
4755 .get(&super::SamplerIndexBufferKey {
4756 group: variable.binding.unwrap().group,
4757 })
4758 .unwrap()
4759 .clone(),
4760 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4761 .clone(),
4762 });
4763 }
4764 }
4765 }
4766
4767 None
4768 }
4769
4770 fn write_named_expr(
4771 &mut self,
4772 module: &Module,
4773 handle: Handle<crate::Expression>,
4774 name: String,
4775 expr: Handle<crate::Expression>,
4778 func_ctx: &back::FunctionCtx,
4779 ) -> BackendResult {
4780 if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4781 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4782 if ty_inner.is_atomic_pointer(&module.types) {
4783 let pointer_space = ty_inner.pointer_space().unwrap();
4784 self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4785 write!(self.out, " {name}; ")?;
4786 match pointer_space {
4787 crate::AddressSpace::WorkGroup => {
4788 write!(self.out, "InterlockedOr(")?;
4789 self.write_expr(module, pointer, func_ctx)?;
4790 }
4791 crate::AddressSpace::Storage { .. } => {
4792 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4793 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4794 write!(self.out, "{var_name}.InterlockedOr(")?;
4795 let chain = mem::take(&mut self.temp_access_chain);
4796 self.write_storage_address(module, &chain, func_ctx)?;
4797 self.temp_access_chain = chain;
4798 }
4799 _ => unreachable!(),
4800 }
4801 writeln!(self.out, ", 0, {name});")?;
4802 self.named_expressions.insert(expr, name);
4803 return Ok(());
4804 }
4805 }
4806 match func_ctx.info[expr].ty {
4807 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4808 TypeInner::Struct { .. } => {
4809 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4810 write!(self.out, "{ty_name}")?;
4811 }
4812 _ => {
4813 self.write_type(module, ty_handle)?;
4814 }
4815 },
4816 proc::TypeResolution::Value(ref inner) => {
4817 self.write_value_type(module, inner)?;
4818 }
4819 }
4820
4821 let resolved = func_ctx.resolve_type(expr, &module.types);
4822
4823 write!(self.out, " {name}")?;
4824 if let TypeInner::Array { base, size, .. } = *resolved {
4826 self.write_array_size(module, base, size)?;
4827 }
4828 write!(self.out, " = ")?;
4829 self.write_expr(module, handle, func_ctx)?;
4830 writeln!(self.out, ";")?;
4831 self.named_expressions.insert(expr, name);
4832
4833 Ok(())
4834 }
4835
4836 pub(super) fn write_default_init(
4838 &mut self,
4839 module: &Module,
4840 ty: Handle<crate::Type>,
4841 ) -> BackendResult {
4842 write!(self.out, "(")?;
4843 self.write_type(module, ty)?;
4844 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4845 self.write_array_size(module, base, size)?;
4846 }
4847 write!(self.out, ")0")?;
4848 Ok(())
4849 }
4850
4851 pub(super) fn write_control_barrier(
4852 &mut self,
4853 barrier: crate::Barrier,
4854 level: back::Level,
4855 ) -> BackendResult {
4856 if barrier.contains(crate::Barrier::STORAGE) {
4857 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4858 }
4859 if barrier.contains(crate::Barrier::WORK_GROUP) {
4860 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4861 }
4862 if barrier.contains(crate::Barrier::SUB_GROUP) {
4863 }
4865 if barrier.contains(crate::Barrier::TEXTURE) {
4866 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4867 }
4868 Ok(())
4869 }
4870
4871 fn write_memory_barrier(
4872 &mut self,
4873 barrier: crate::Barrier,
4874 level: back::Level,
4875 ) -> BackendResult {
4876 if barrier.contains(crate::Barrier::STORAGE) {
4877 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4878 }
4879 if barrier.contains(crate::Barrier::WORK_GROUP) {
4880 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4881 }
4882 if barrier.contains(crate::Barrier::SUB_GROUP) {
4883 }
4885 if barrier.contains(crate::Barrier::TEXTURE) {
4886 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4887 }
4888 Ok(())
4889 }
4890
4891 fn emit_hlsl_atomic_tail(
4893 &mut self,
4894 module: &Module,
4895 func_ctx: &back::FunctionCtx<'_>,
4896 fun: &crate::AtomicFunction,
4897 compare_expr: Option<Handle<crate::Expression>>,
4898 value: Handle<crate::Expression>,
4899 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4900 ) -> BackendResult {
4901 if let Some(cmp) = compare_expr {
4902 write!(self.out, ", ")?;
4903 self.write_expr(module, cmp, func_ctx)?;
4904 }
4905 write!(self.out, ", ")?;
4906 if let crate::AtomicFunction::Subtract = *fun {
4907 write!(self.out, "-")?;
4909 }
4910 self.write_expr(module, value, func_ctx)?;
4911 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4912 write!(self.out, ", ")?;
4913 if compare_expr.is_some() {
4914 write!(self.out, "{res_name}.old_value")?;
4915 } else {
4916 write!(self.out, "{res_name}")?;
4917 }
4918 }
4919 writeln!(self.out, ");")?;
4920 Ok(())
4921 }
4922}
4923
/// Shape and scalar width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the helper functions in this file.
pub(super) struct MatrixType {
    /// The `columns` field of the matrix type.
    pub(super) columns: crate::VectorSize,
    /// The `rows` field of the matrix type.
    pub(super) rows: crate::VectorSize,
    /// The `width` of the matrix type's scalar.
    pub(super) width: crate::Bytes,
}
4929
4930pub(super) fn get_inner_matrix_data(
4931 module: &Module,
4932 handle: Handle<crate::Type>,
4933) -> Option<MatrixType> {
4934 match module.types[handle].inner {
4935 TypeInner::Matrix {
4936 columns,
4937 rows,
4938 scalar,
4939 } => Some(MatrixType {
4940 columns,
4941 rows,
4942 width: scalar.width,
4943 }),
4944 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4945 _ => None,
4946 }
4947}
4948
4949fn find_matrix_in_access_chain(
4953 module: &Module,
4954 base: Handle<crate::Expression>,
4955 func_ctx: &back::FunctionCtx<'_>,
4956) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4957 let mut current_base = base;
4958 let mut vector = None;
4959 let mut scalar = None;
4960 loop {
4961 let resolved_tr = func_ctx
4962 .resolve_type(current_base, &module.types)
4963 .pointer_base_type();
4964 let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4965
4966 match *resolved {
4967 TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4968 TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4969 _ => return None,
4970 }
4971
4972 let index;
4973 (current_base, index) = match func_ctx.expressions[current_base] {
4974 crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4975 crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4976 _ => return None,
4977 };
4978
4979 match *resolved {
4980 TypeInner::Scalar(_) => scalar = Some(index),
4981 TypeInner::Vector { .. } => vector = Some(index),
4982 _ => unreachable!(),
4983 }
4984 }
4985}
4986
4987pub(super) fn get_inner_matrix_of_struct_array_member(
4992 module: &Module,
4993 base: Handle<crate::Expression>,
4994 func_ctx: &back::FunctionCtx<'_>,
4995 direct: bool,
4996) -> Option<MatrixType> {
4997 let mut mat_data = None;
4998 let mut array_base = None;
4999
5000 let mut current_base = base;
5001 loop {
5002 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5003 if let TypeInner::Pointer { base, .. } = *resolved {
5004 resolved = &module.types[base].inner;
5005 };
5006
5007 match *resolved {
5008 TypeInner::Matrix {
5009 columns,
5010 rows,
5011 scalar,
5012 } => {
5013 mat_data = Some(MatrixType {
5014 columns,
5015 rows,
5016 width: scalar.width,
5017 })
5018 }
5019 TypeInner::Array { base, .. } => {
5020 array_base = Some(base);
5021 }
5022 TypeInner::Struct { .. } => {
5023 if let Some(array_base) = array_base {
5024 if direct {
5025 return mat_data;
5026 } else {
5027 return get_inner_matrix_data(module, array_base);
5028 }
5029 }
5030
5031 break;
5032 }
5033 _ => break,
5034 }
5035
5036 current_base = match func_ctx.expressions[current_base] {
5037 crate::Expression::Access { base, .. } => base,
5038 crate::Expression::AccessIndex { base, .. } => base,
5039 _ => break,
5040 };
5041 }
5042 None
5043}
5044
5045fn get_inner_matrix_of_global_uniform(
5050 module: &Module,
5051 base: Handle<crate::Expression>,
5052 func_ctx: &back::FunctionCtx<'_>,
5053 direct: bool,
5054) -> Option<MatrixType> {
5055 let mut mat_data = None;
5056 let mut array_base = None;
5057
5058 let mut current_base = base;
5059 loop {
5060 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5061 if let TypeInner::Pointer { base, .. } = *resolved {
5062 resolved = &module.types[base].inner;
5063 };
5064
5065 match *resolved {
5066 TypeInner::Matrix {
5067 columns,
5068 rows,
5069 scalar,
5070 } => {
5071 mat_data = Some(MatrixType {
5072 columns,
5073 rows,
5074 width: scalar.width,
5075 })
5076 }
5077 TypeInner::Array { base, .. } => {
5078 if !direct {
5079 array_base = Some(base);
5080 }
5081 }
5082 _ => break,
5083 }
5084
5085 current_base = match func_ctx.expressions[current_base] {
5086 crate::Expression::Access { base, .. } => base,
5087 crate::Expression::AccessIndex { base, .. } => base,
5088 crate::Expression::GlobalVariable(handle)
5089 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5090 {
5091 return mat_data.or_else(|| {
5092 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5093 })
5094 }
5095 _ => break,
5096 };
5097 }
5098 None
5099}