1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13 back::{self, Baked},
14 common::{
15 self,
16 wgsl::{address_space_str, ToWgsl, TryToWgsl},
17 },
18 proc::{self, NameKey},
19 valid, Handle, Module, ShaderStage, TypeInner,
20};
21
/// Result type returned by the writer's methods: success carries no value,
/// and failures (formatting errors or unsupported constructs) are reported as
/// this backend's `Error`.
type BackendResult = Result<(), Error>;
24
/// A WGSL attribute queued for output by `Writer::write_attributes`.
///
/// Each variant is written as the WGSL attribute noted on it, followed by a
/// single space.
enum Attribute {
    /// `@binding(N)`.
    Binding(u32),
    /// `@builtin(name)`; writing fails if the builtin has no WGSL spelling.
    BuiltIn(crate::BuiltIn),
    /// `@group(N)`.
    Group(u32),
    /// `@invariant`.
    Invariant,
    /// `@interpolate(...)`; omitted entirely unless a non-perspective
    /// interpolation or a non-center sampling is given.
    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
    /// `@location(N)`.
    Location(u32),
    /// `@blend_src(N)`, used with dual-source blending.
    BlendSrc(u32),
    /// `@vertex`, `@fragment`, `@compute` or `@task`. Mesh entry points are
    /// written with `MeshStage` instead.
    Stage(ShaderStage),
    /// `@workgroup_size(x, y, z)`.
    WorkGroupSize([u32; 3]),
    /// `@mesh(name)`, where `name` is the mesh output variable.
    MeshStage(String),
    /// `@payload(name)`, where `name` is the task payload variable.
    TaskPayload(String),
    /// `@per_primitive`.
    PerPrimitive,
}
40
/// Whether an expression's WGSL spelling denotes an ordinary value or a
/// reference (a memory location such as a bare variable name).
///
/// Naga IR expressions like `LocalVariable` evaluate to pointers, while the
/// bare WGSL variable name is a reference. `write_expr_with_indirection`
/// compares the form the context requests with the expression's plain form
/// and inserts `&` or `*` to bridge the two.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// An ordinary value (possibly a pointer value). Where a reference is
    /// requested, it is written as `(*expr)`.
    Ordinary,

    /// A reference. Where an ordinary value is requested, it is written as
    /// `(&expr)` to obtain a pointer.
    Reference,
}
69
bitflags::bitflags! {
    /// Options that change how the WGSL writer formats its output.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct WriterFlags: u32 {
        /// Write an explicit type on every `let` binding the writer emits
        /// (`let name: T = ...`) instead of relying on WGSL type inference.
        const EXPLICIT_TYPES = 0x1;
    }
}
79
/// Writes a Naga IR `Module` as WGSL source text to `out`.
///
/// The writer may be reused: `write` calls `reset`, which clears all state
/// left over from a previous module.
pub struct Writer<W> {
    /// Destination for the generated WGSL text.
    out: W,
    /// Output options, e.g. `WriterFlags::EXPLICIT_TYPES`.
    flags: WriterFlags,
    /// WGSL identifier for each named IR entity (types, struct members,
    /// constants, overrides, functions, arguments, locals, entry points),
    /// populated by `namer.reset` at the start of each module.
    names: crate::FastHashMap<NameKey, String>,
    /// Identifier generator; seeded per module by `reset` and also used for
    /// expression names supplied by the front end.
    namer: proc::Namer,
    /// Expressions of the current function that are already bound to a `let`
    /// name; later uses write the name instead of the whole expression.
    /// Cleared at the end of every function.
    named_expressions: crate::NamedExpressions,
    /// Polyfill functions required by the generated code; their source is
    /// appended after all entry points.
    required_polyfills: crate::FastIndexSet<InversePolyfill>,
}
88
89impl<W: Write> Writer<W> {
90 pub fn new(out: W, flags: WriterFlags) -> Self {
91 Writer {
92 out,
93 flags,
94 names: crate::FastHashMap::default(),
95 namer: proc::Namer::default(),
96 named_expressions: crate::NamedExpressions::default(),
97 required_polyfills: crate::FastIndexSet::default(),
98 }
99 }
100
101 fn reset(&mut self, module: &Module) {
102 self.names.clear();
103 self.namer.reset(
104 module,
105 &crate::keywords::wgsl::RESERVED_SET,
106 proc::CaseInsensitiveKeywordSet::empty(),
108 &["__", "_naga"],
109 &mut self.names,
110 );
111 self.named_expressions.clear();
112 self.required_polyfills.clear();
113 }
114
115 fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
131 module
132 .special_types
133 .predeclared_types
134 .values()
135 .any(|t| *t == ty)
136 || Some(ty) == module.special_types.external_texture_params
137 || Some(ty) == module.special_types.external_texture_transfer_function
138 }
139
140 pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
141 self.reset(module);
142
143 self.write_enable_declarations(module)?;
145
146 for (handle, ty) in module.types.iter() {
148 if let TypeInner::Struct { ref members, .. } = ty.inner {
149 {
150 if !self.is_builtin_wgsl_struct(module, handle) {
151 self.write_struct(module, handle, members)?;
152 writeln!(self.out)?;
153 }
154 }
155 }
156 }
157
158 let mut constants = module
160 .constants
161 .iter()
162 .filter(|&(_, c)| c.name.is_some())
163 .peekable();
164 while let Some((handle, _)) = constants.next() {
165 self.write_global_constant(module, handle)?;
166 if constants.peek().is_none() {
168 writeln!(self.out)?;
169 }
170 }
171
172 let mut overrides = module.overrides.iter().peekable();
174 while let Some((handle, _)) = overrides.next() {
175 self.write_override(module, handle)?;
176 if overrides.peek().is_none() {
178 writeln!(self.out)?;
179 }
180 }
181
182 for (ty, global) in module.global_variables.iter() {
184 self.write_global(module, global, ty)?;
185 }
186
187 if !module.global_variables.is_empty() {
188 writeln!(self.out)?;
190 }
191
192 for (handle, function) in module.functions.iter() {
194 let fun_info = &info[handle];
195
196 let func_ctx = back::FunctionCtx {
197 ty: back::FunctionType::Function(handle),
198 info: fun_info,
199 expressions: &function.expressions,
200 named_expressions: &function.named_expressions,
201 };
202
203 self.write_function(module, function, &func_ctx)?;
205
206 writeln!(self.out)?;
207 }
208
209 for (index, ep) in module.entry_points.iter().enumerate() {
211 let attributes = match ep.stage {
212 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
213 ShaderStage::Compute => vec![
214 Attribute::Stage(ShaderStage::Compute),
215 Attribute::WorkGroupSize(ep.workgroup_size),
216 ],
217 ShaderStage::Mesh => {
218 let mesh_output_name = module.global_variables
219 [ep.mesh_info.as_ref().unwrap().output_variable]
220 .name
221 .clone()
222 .unwrap();
223 let mut mesh_attrs = vec![
224 Attribute::MeshStage(mesh_output_name),
225 Attribute::WorkGroupSize(ep.workgroup_size),
226 ];
227 if ep.task_payload.is_some() {
228 let payload_name = module.global_variables[ep.task_payload.unwrap()]
229 .name
230 .clone()
231 .unwrap();
232 mesh_attrs.push(Attribute::TaskPayload(payload_name));
233 }
234 mesh_attrs
235 }
236 ShaderStage::Task => {
237 let payload_name = module.global_variables[ep.task_payload.unwrap()]
238 .name
239 .clone()
240 .unwrap();
241 vec![
242 Attribute::Stage(ShaderStage::Task),
243 Attribute::TaskPayload(payload_name),
244 Attribute::WorkGroupSize(ep.workgroup_size),
245 ]
246 }
247 };
248 self.write_attributes(&attributes)?;
249 writeln!(self.out)?;
251
252 let func_ctx = back::FunctionCtx {
253 ty: back::FunctionType::EntryPoint(index as u16),
254 info: info.get_entry_point(index),
255 expressions: &ep.function.expressions,
256 named_expressions: &ep.function.named_expressions,
257 };
258 self.write_function(module, &ep.function, &func_ctx)?;
259
260 if index < module.entry_points.len() - 1 {
261 writeln!(self.out)?;
262 }
263 }
264
265 for polyfill in &self.required_polyfills {
267 writeln!(self.out)?;
268 write!(self.out, "{}", polyfill.source)?;
269 writeln!(self.out)?;
270 }
271
272 Ok(())
273 }
274
275 fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
278 let mut needs_f16 = false;
279 let mut needs_dual_source_blending = false;
280 let mut needs_clip_distances = false;
281 let mut needs_mesh_shaders = false;
282 let mut needs_cooperative_matrix = false;
283
284 for (_, ty) in module.types.iter() {
286 match ty.inner {
287 TypeInner::Scalar(scalar)
288 | TypeInner::Vector { scalar, .. }
289 | TypeInner::Matrix { scalar, .. } => {
290 needs_f16 |= scalar == crate::Scalar::F16;
291 }
292 TypeInner::Struct { ref members, .. } => {
293 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
294 match *binding {
295 crate::Binding::Location {
296 blend_src: Some(_), ..
297 } => {
298 needs_dual_source_blending = true;
299 }
300 crate::Binding::BuiltIn(crate::BuiltIn::ClipDistance) => {
301 needs_clip_distances = true;
302 }
303 crate::Binding::Location {
304 per_primitive: true,
305 ..
306 } => {
307 needs_mesh_shaders = true;
308 }
309 crate::Binding::BuiltIn(
310 crate::BuiltIn::MeshTaskSize
311 | crate::BuiltIn::CullPrimitive
312 | crate::BuiltIn::PointIndex
313 | crate::BuiltIn::LineIndices
314 | crate::BuiltIn::TriangleIndices
315 | crate::BuiltIn::VertexCount
316 | crate::BuiltIn::Vertices
317 | crate::BuiltIn::PrimitiveCount
318 | crate::BuiltIn::Primitives,
319 ) => {
320 needs_mesh_shaders = true;
321 }
322 _ => {}
323 }
324 }
325 }
326 TypeInner::CooperativeMatrix { .. } => {
327 needs_cooperative_matrix = true;
328 }
329 _ => {}
330 }
331 }
332
333 if module
334 .entry_points
335 .iter()
336 .any(|ep| matches!(ep.stage, ShaderStage::Mesh | ShaderStage::Task))
337 {
338 needs_mesh_shaders = true;
339 }
340
341 if module
342 .global_variables
343 .iter()
344 .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
345 {
346 needs_mesh_shaders = true;
347 }
348
349 let mut any_written = false;
351 if needs_f16 {
352 writeln!(self.out, "enable f16;")?;
353 any_written = true;
354 }
355 if needs_dual_source_blending {
356 writeln!(self.out, "enable dual_source_blending;")?;
357 any_written = true;
358 }
359 if needs_clip_distances {
360 writeln!(self.out, "enable clip_distances;")?;
361 any_written = true;
362 }
363 if needs_mesh_shaders {
364 writeln!(self.out, "enable wgpu_mesh_shader;")?;
365 any_written = true;
366 }
367 if needs_cooperative_matrix {
368 writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
369 any_written = true;
370 }
371 if any_written {
372 writeln!(self.out)?;
374 }
375
376 Ok(())
377 }
378
379 fn write_function(
385 &mut self,
386 module: &Module,
387 func: &crate::Function,
388 func_ctx: &back::FunctionCtx<'_>,
389 ) -> BackendResult {
390 let func_name = match func_ctx.ty {
391 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
392 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
393 };
394
395 write!(self.out, "fn {func_name}(")?;
397
398 for (index, arg) in func.arguments.iter().enumerate() {
400 if let Some(ref binding) = arg.binding {
402 self.write_attributes(&map_binding_to_attribute(binding))?;
403 }
404 let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
406
407 write!(self.out, "{argument_name}: ")?;
408 self.write_type(module, arg.ty)?;
410 if index < func.arguments.len() - 1 {
411 write!(self.out, ", ")?;
413 }
414 }
415
416 write!(self.out, ")")?;
417
418 if let Some(ref result) = func.result {
420 write!(self.out, " -> ")?;
421 if let Some(ref binding) = result.binding {
422 self.write_attributes(&map_binding_to_attribute(binding))?;
423 }
424 self.write_type(module, result.ty)?;
425 }
426
427 write!(self.out, " {{")?;
428 writeln!(self.out)?;
429
430 for (handle, local) in func.local_variables.iter() {
432 write!(self.out, "{}", back::INDENT)?;
434
435 write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
438
439 self.write_type(module, local.ty)?;
441
442 if let Some(init) = local.init {
444 write!(self.out, " = ")?;
447
448 self.write_expr(module, init, func_ctx)?;
451 }
452
453 writeln!(self.out, ";")?
455 }
456
457 if !func.local_variables.is_empty() {
458 writeln!(self.out)?;
459 }
460
461 for sta in func.body.iter() {
463 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
465 }
466
467 writeln!(self.out, "}}")?;
468
469 self.named_expressions.clear();
470
471 Ok(())
472 }
473
474 fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
476 for attribute in attributes {
477 match *attribute {
478 Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
479 Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
480 Attribute::BuiltIn(builtin_attrib) => {
481 let builtin = builtin_attrib.to_wgsl_if_implemented()?;
482 write!(self.out, "@builtin({builtin}) ")?;
483 }
484 Attribute::Stage(shader_stage) => {
485 let stage_str = match shader_stage {
486 ShaderStage::Vertex => "vertex",
487 ShaderStage::Fragment => "fragment",
488 ShaderStage::Compute => "compute",
489 ShaderStage::Task => "task",
490 ShaderStage::Mesh => unreachable!(),
492 };
493
494 write!(self.out, "@{stage_str} ")?;
495 }
496 Attribute::WorkGroupSize(size) => {
497 write!(
498 self.out,
499 "@workgroup_size({}, {}, {}) ",
500 size[0], size[1], size[2]
501 )?;
502 }
503 Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
504 Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
505 Attribute::Invariant => write!(self.out, "@invariant ")?,
506 Attribute::Interpolate(interpolation, sampling) => {
507 if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
508 let interpolation = interpolation
509 .unwrap_or(crate::Interpolation::Perspective)
510 .to_wgsl();
511 let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
512 write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
513 } else if interpolation.is_some()
514 && interpolation != Some(crate::Interpolation::Perspective)
515 {
516 let interpolation = interpolation
517 .unwrap_or(crate::Interpolation::Perspective)
518 .to_wgsl();
519 write!(self.out, "@interpolate({interpolation}) ")?;
520 }
521 }
522 Attribute::MeshStage(ref name) => {
523 write!(self.out, "@mesh({name}) ")?;
524 }
525 Attribute::TaskPayload(ref payload_name) => {
526 write!(self.out, "@payload({payload_name}) ")?;
527 }
528 Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
529 };
530 }
531 Ok(())
532 }
533
534 fn write_struct(
545 &mut self,
546 module: &Module,
547 handle: Handle<crate::Type>,
548 members: &[crate::StructMember],
549 ) -> BackendResult {
550 write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
551 write!(self.out, " {{")?;
552 writeln!(self.out)?;
553 for (index, member) in members.iter().enumerate() {
554 write!(self.out, "{}", back::INDENT)?;
556 if let Some(ref binding) = member.binding {
557 self.write_attributes(&map_binding_to_attribute(binding))?;
558 }
559 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
561 write!(self.out, "{member_name}: ")?;
562 self.write_type(module, member.ty)?;
563 write!(self.out, ",")?;
564 writeln!(self.out)?;
565 }
566
567 writeln!(self.out, "}}")?;
568
569 Ok(())
570 }
571
572 fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
573 let type_context = WriterTypeContext {
577 module,
578 names: &self.names,
579 };
580 type_context.write_type(ty, &mut self.out)?;
581
582 Ok(())
583 }
584
585 fn write_type_resolution(
586 &mut self,
587 module: &Module,
588 resolution: &proc::TypeResolution,
589 ) -> BackendResult {
590 let type_context = WriterTypeContext {
594 module,
595 names: &self.names,
596 };
597 type_context.write_type_resolution(resolution, &mut self.out)?;
598
599 Ok(())
600 }
601
602 fn write_stmt(
607 &mut self,
608 module: &Module,
609 stmt: &crate::Statement,
610 func_ctx: &back::FunctionCtx<'_>,
611 level: back::Level,
612 ) -> BackendResult {
613 use crate::{Expression, Statement};
614
615 match *stmt {
616 Statement::Emit(ref range) => {
617 for handle in range.clone() {
618 let info = &func_ctx.info[handle];
619 let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
620 Some(self.namer.call(name))
625 } else {
626 let expr = &func_ctx.expressions[handle];
627 let min_ref_count = expr.bake_ref_count();
628 let required_baking_expr = match *expr {
630 Expression::ImageLoad { .. }
631 | Expression::ImageQuery { .. }
632 | Expression::ImageSample { .. } => true,
633 _ => false,
634 };
635 if min_ref_count <= info.ref_count || required_baking_expr {
636 Some(Baked(handle).to_string())
637 } else {
638 None
639 }
640 };
641
642 if let Some(name) = expr_name {
643 write!(self.out, "{level}")?;
644 self.start_named_expr(module, handle, func_ctx, &name)?;
645 self.write_expr(module, handle, func_ctx)?;
646 self.named_expressions.insert(handle, name);
647 writeln!(self.out, ";")?;
648 }
649 }
650 }
651 Statement::If {
653 condition,
654 ref accept,
655 ref reject,
656 } => {
657 write!(self.out, "{level}")?;
658 write!(self.out, "if ")?;
659 self.write_expr(module, condition, func_ctx)?;
660 writeln!(self.out, " {{")?;
661
662 let l2 = level.next();
663 for sta in accept {
664 self.write_stmt(module, sta, func_ctx, l2)?;
666 }
667
668 if !reject.is_empty() {
671 writeln!(self.out, "{level}}} else {{")?;
672
673 for sta in reject {
674 self.write_stmt(module, sta, func_ctx, l2)?;
676 }
677 }
678
679 writeln!(self.out, "{level}}}")?
680 }
681 Statement::Return { value } => {
682 write!(self.out, "{level}")?;
683 write!(self.out, "return")?;
684 if let Some(return_value) = value {
685 write!(self.out, " ")?;
687 self.write_expr(module, return_value, func_ctx)?;
688 }
689 writeln!(self.out, ";")?;
690 }
691 Statement::Kill => {
693 write!(self.out, "{level}")?;
694 writeln!(self.out, "discard;")?
695 }
696 Statement::Store { pointer, value } => {
697 write!(self.out, "{level}")?;
698
699 let is_atomic_pointer = func_ctx
700 .resolve_type(pointer, &module.types)
701 .is_atomic_pointer(&module.types);
702
703 if is_atomic_pointer {
704 write!(self.out, "atomicStore(")?;
705 self.write_expr(module, pointer, func_ctx)?;
706 write!(self.out, ", ")?;
707 self.write_expr(module, value, func_ctx)?;
708 write!(self.out, ")")?;
709 } else {
710 self.write_expr_with_indirection(
711 module,
712 pointer,
713 func_ctx,
714 Indirection::Reference,
715 )?;
716 write!(self.out, " = ")?;
717 self.write_expr(module, value, func_ctx)?;
718 }
719 writeln!(self.out, ";")?
720 }
721 Statement::Call {
722 function,
723 ref arguments,
724 result,
725 } => {
726 write!(self.out, "{level}")?;
727 if let Some(expr) = result {
728 let name = Baked(expr).to_string();
729 self.start_named_expr(module, expr, func_ctx, &name)?;
730 self.named_expressions.insert(expr, name);
731 }
732 let func_name = &self.names[&NameKey::Function(function)];
733 write!(self.out, "{func_name}(")?;
734 for (index, &argument) in arguments.iter().enumerate() {
735 if index != 0 {
736 write!(self.out, ", ")?;
737 }
738 self.write_expr(module, argument, func_ctx)?;
739 }
740 writeln!(self.out, ");")?
741 }
742 Statement::Atomic {
743 pointer,
744 ref fun,
745 value,
746 result,
747 } => {
748 write!(self.out, "{level}")?;
749 if let Some(result) = result {
750 let res_name = Baked(result).to_string();
751 self.start_named_expr(module, result, func_ctx, &res_name)?;
752 self.named_expressions.insert(result, res_name);
753 }
754
755 let fun_str = fun.to_wgsl();
756 write!(self.out, "atomic{fun_str}(")?;
757 self.write_expr(module, pointer, func_ctx)?;
758 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
759 write!(self.out, ", ")?;
760 self.write_expr(module, cmp, func_ctx)?;
761 }
762 write!(self.out, ", ")?;
763 self.write_expr(module, value, func_ctx)?;
764 writeln!(self.out, ");")?
765 }
766 Statement::ImageAtomic {
767 image,
768 coordinate,
769 array_index,
770 ref fun,
771 value,
772 } => {
773 write!(self.out, "{level}")?;
774 let fun_str = fun.to_wgsl();
775 write!(self.out, "textureAtomic{fun_str}(")?;
776 self.write_expr(module, image, func_ctx)?;
777 write!(self.out, ", ")?;
778 self.write_expr(module, coordinate, func_ctx)?;
779 if let Some(array_index_expr) = array_index {
780 write!(self.out, ", ")?;
781 self.write_expr(module, array_index_expr, func_ctx)?;
782 }
783 write!(self.out, ", ")?;
784 self.write_expr(module, value, func_ctx)?;
785 writeln!(self.out, ");")?;
786 }
787 Statement::WorkGroupUniformLoad { pointer, result } => {
788 write!(self.out, "{level}")?;
789 let res_name = Baked(result).to_string();
791 self.start_named_expr(module, result, func_ctx, &res_name)?;
792 self.named_expressions.insert(result, res_name);
793 write!(self.out, "workgroupUniformLoad(")?;
794 self.write_expr(module, pointer, func_ctx)?;
795 writeln!(self.out, ");")?;
796 }
797 Statement::ImageStore {
798 image,
799 coordinate,
800 array_index,
801 value,
802 } => {
803 write!(self.out, "{level}")?;
804 write!(self.out, "textureStore(")?;
805 self.write_expr(module, image, func_ctx)?;
806 write!(self.out, ", ")?;
807 self.write_expr(module, coordinate, func_ctx)?;
808 if let Some(array_index_expr) = array_index {
809 write!(self.out, ", ")?;
810 self.write_expr(module, array_index_expr, func_ctx)?;
811 }
812 write!(self.out, ", ")?;
813 self.write_expr(module, value, func_ctx)?;
814 writeln!(self.out, ");")?;
815 }
816 Statement::Block(ref block) => {
818 write!(self.out, "{level}")?;
819 writeln!(self.out, "{{")?;
820 for sta in block.iter() {
821 self.write_stmt(module, sta, func_ctx, level.next())?
823 }
824 writeln!(self.out, "{level}}}")?
825 }
826 Statement::Switch {
827 selector,
828 ref cases,
829 } => {
830 write!(self.out, "{level}")?;
832 write!(self.out, "switch ")?;
833 self.write_expr(module, selector, func_ctx)?;
834 writeln!(self.out, " {{")?;
835
836 let l2 = level.next();
837 let mut new_case = true;
838 for case in cases {
839 if case.fall_through && !case.body.is_empty() {
840 return Err(Error::Unimplemented(
842 "fall-through switch case block".into(),
843 ));
844 }
845
846 match case.value {
847 crate::SwitchValue::I32(value) => {
848 if new_case {
849 write!(self.out, "{l2}case ")?;
850 }
851 write!(self.out, "{value}")?;
852 }
853 crate::SwitchValue::U32(value) => {
854 if new_case {
855 write!(self.out, "{l2}case ")?;
856 }
857 write!(self.out, "{value}u")?;
858 }
859 crate::SwitchValue::Default => {
860 if new_case {
861 if case.fall_through {
862 write!(self.out, "{l2}case ")?;
863 } else {
864 write!(self.out, "{l2}")?;
865 }
866 }
867 write!(self.out, "default")?;
868 }
869 }
870
871 new_case = !case.fall_through;
872
873 if case.fall_through {
874 write!(self.out, ", ")?;
875 } else {
876 writeln!(self.out, ": {{")?;
877 }
878
879 for sta in case.body.iter() {
880 self.write_stmt(module, sta, func_ctx, l2.next())?;
881 }
882
883 if !case.fall_through {
884 writeln!(self.out, "{l2}}}")?;
885 }
886 }
887
888 writeln!(self.out, "{level}}}")?
889 }
890 Statement::Loop {
891 ref body,
892 ref continuing,
893 break_if,
894 } => {
895 write!(self.out, "{level}")?;
896 writeln!(self.out, "loop {{")?;
897
898 let l2 = level.next();
899 for sta in body.iter() {
900 self.write_stmt(module, sta, func_ctx, l2)?;
901 }
902
903 if !continuing.is_empty() || break_if.is_some() {
908 writeln!(self.out, "{l2}continuing {{")?;
909 for sta in continuing.iter() {
910 self.write_stmt(module, sta, func_ctx, l2.next())?;
911 }
912
913 if let Some(condition) = break_if {
916 write!(self.out, "{}break if ", l2.next())?;
918 self.write_expr(module, condition, func_ctx)?;
919 writeln!(self.out, ";")?;
921 }
922
923 writeln!(self.out, "{l2}}}")?;
924 }
925
926 writeln!(self.out, "{level}}}")?
927 }
928 Statement::Break => {
929 writeln!(self.out, "{level}break;")?;
930 }
931 Statement::Continue => {
932 writeln!(self.out, "{level}continue;")?;
933 }
934 Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
935 if barrier.contains(crate::Barrier::STORAGE) {
936 writeln!(self.out, "{level}storageBarrier();")?;
937 }
938
939 if barrier.contains(crate::Barrier::WORK_GROUP) {
940 writeln!(self.out, "{level}workgroupBarrier();")?;
941 }
942
943 if barrier.contains(crate::Barrier::SUB_GROUP) {
944 writeln!(self.out, "{level}subgroupBarrier();")?;
945 }
946
947 if barrier.contains(crate::Barrier::TEXTURE) {
948 writeln!(self.out, "{level}textureBarrier();")?;
949 }
950 }
951 Statement::RayQuery { .. } => unreachable!(),
952 Statement::SubgroupBallot { result, predicate } => {
953 write!(self.out, "{level}")?;
954 let res_name = Baked(result).to_string();
955 self.start_named_expr(module, result, func_ctx, &res_name)?;
956 self.named_expressions.insert(result, res_name);
957
958 write!(self.out, "subgroupBallot(")?;
959 if let Some(predicate) = predicate {
960 self.write_expr(module, predicate, func_ctx)?;
961 }
962 writeln!(self.out, ");")?;
963 }
964 Statement::SubgroupCollectiveOperation {
965 op,
966 collective_op,
967 argument,
968 result,
969 } => {
970 write!(self.out, "{level}")?;
971 let res_name = Baked(result).to_string();
972 self.start_named_expr(module, result, func_ctx, &res_name)?;
973 self.named_expressions.insert(result, res_name);
974
975 match (collective_op, op) {
976 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
977 write!(self.out, "subgroupAll(")?
978 }
979 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
980 write!(self.out, "subgroupAny(")?
981 }
982 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
983 write!(self.out, "subgroupAdd(")?
984 }
985 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
986 write!(self.out, "subgroupMul(")?
987 }
988 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
989 write!(self.out, "subgroupMax(")?
990 }
991 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
992 write!(self.out, "subgroupMin(")?
993 }
994 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
995 write!(self.out, "subgroupAnd(")?
996 }
997 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
998 write!(self.out, "subgroupOr(")?
999 }
1000 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1001 write!(self.out, "subgroupXor(")?
1002 }
1003 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1004 write!(self.out, "subgroupExclusiveAdd(")?
1005 }
1006 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1007 write!(self.out, "subgroupExclusiveMul(")?
1008 }
1009 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1010 write!(self.out, "subgroupInclusiveAdd(")?
1011 }
1012 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1013 write!(self.out, "subgroupInclusiveMul(")?
1014 }
1015 _ => unimplemented!(),
1016 }
1017 self.write_expr(module, argument, func_ctx)?;
1018 writeln!(self.out, ");")?;
1019 }
1020 Statement::SubgroupGather {
1021 mode,
1022 argument,
1023 result,
1024 } => {
1025 write!(self.out, "{level}")?;
1026 let res_name = Baked(result).to_string();
1027 self.start_named_expr(module, result, func_ctx, &res_name)?;
1028 self.named_expressions.insert(result, res_name);
1029
1030 match mode {
1031 crate::GatherMode::BroadcastFirst => {
1032 write!(self.out, "subgroupBroadcastFirst(")?;
1033 }
1034 crate::GatherMode::Broadcast(_) => {
1035 write!(self.out, "subgroupBroadcast(")?;
1036 }
1037 crate::GatherMode::Shuffle(_) => {
1038 write!(self.out, "subgroupShuffle(")?;
1039 }
1040 crate::GatherMode::ShuffleDown(_) => {
1041 write!(self.out, "subgroupShuffleDown(")?;
1042 }
1043 crate::GatherMode::ShuffleUp(_) => {
1044 write!(self.out, "subgroupShuffleUp(")?;
1045 }
1046 crate::GatherMode::ShuffleXor(_) => {
1047 write!(self.out, "subgroupShuffleXor(")?;
1048 }
1049 crate::GatherMode::QuadBroadcast(_) => {
1050 write!(self.out, "quadBroadcast(")?;
1051 }
1052 crate::GatherMode::QuadSwap(direction) => match direction {
1053 crate::Direction::X => {
1054 write!(self.out, "quadSwapX(")?;
1055 }
1056 crate::Direction::Y => {
1057 write!(self.out, "quadSwapY(")?;
1058 }
1059 crate::Direction::Diagonal => {
1060 write!(self.out, "quadSwapDiagonal(")?;
1061 }
1062 },
1063 }
1064 self.write_expr(module, argument, func_ctx)?;
1065 match mode {
1066 crate::GatherMode::BroadcastFirst => {}
1067 crate::GatherMode::Broadcast(index)
1068 | crate::GatherMode::Shuffle(index)
1069 | crate::GatherMode::ShuffleDown(index)
1070 | crate::GatherMode::ShuffleUp(index)
1071 | crate::GatherMode::ShuffleXor(index)
1072 | crate::GatherMode::QuadBroadcast(index) => {
1073 write!(self.out, ", ")?;
1074 self.write_expr(module, index, func_ctx)?;
1075 }
1076 crate::GatherMode::QuadSwap(_) => {}
1077 }
1078 writeln!(self.out, ");")?;
1079 }
1080 Statement::CooperativeStore { target, ref data } => {
1081 let suffix = if data.row_major { "T" } else { "" };
1082 write!(self.out, "{level}coopStore{suffix}(")?;
1083 self.write_expr(module, target, func_ctx)?;
1084 write!(self.out, ", ")?;
1085 self.write_expr(module, data.pointer, func_ctx)?;
1086 write!(self.out, ", ")?;
1087 self.write_expr(module, data.stride, func_ctx)?;
1088 writeln!(self.out, ");")?
1089 }
1090 }
1091
1092 Ok(())
1093 }
1094
1095 fn plain_form_indirection(
1118 &self,
1119 expr: Handle<crate::Expression>,
1120 module: &Module,
1121 func_ctx: &back::FunctionCtx<'_>,
1122 ) -> Indirection {
1123 use crate::Expression as Ex;
1124
1125 if self.named_expressions.contains_key(&expr) {
1129 return Indirection::Ordinary;
1130 }
1131
1132 match func_ctx.expressions[expr] {
1133 Ex::LocalVariable(_) => Indirection::Reference,
1134 Ex::GlobalVariable(handle) => {
1135 let global = &module.global_variables[handle];
1136 match global.space {
1137 crate::AddressSpace::Handle => Indirection::Ordinary,
1138 _ => Indirection::Reference,
1139 }
1140 }
1141 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1142 let base_ty = func_ctx.resolve_type(base, &module.types);
1143 match *base_ty {
1144 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1145 Indirection::Reference
1146 }
1147 _ => Indirection::Ordinary,
1148 }
1149 }
1150 _ => Indirection::Ordinary,
1151 }
1152 }
1153
1154 fn start_named_expr(
1155 &mut self,
1156 module: &Module,
1157 handle: Handle<crate::Expression>,
1158 func_ctx: &back::FunctionCtx,
1159 name: &str,
1160 ) -> BackendResult {
1161 write!(self.out, "let {name}")?;
1163 if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1164 write!(self.out, ": ")?;
1165 self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1167 }
1168
1169 write!(self.out, " = ")?;
1170 Ok(())
1171 }
1172
1173 fn write_expr(
1177 &mut self,
1178 module: &Module,
1179 expr: Handle<crate::Expression>,
1180 func_ctx: &back::FunctionCtx<'_>,
1181 ) -> BackendResult {
1182 self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1183 }
1184
1185 fn write_expr_with_indirection(
1198 &mut self,
1199 module: &Module,
1200 expr: Handle<crate::Expression>,
1201 func_ctx: &back::FunctionCtx<'_>,
1202 requested: Indirection,
1203 ) -> BackendResult {
1204 let plain = self.plain_form_indirection(expr, module, func_ctx);
1207 log::trace!(
1208 "expression {:?}={:?} is {:?}, expected {:?}",
1209 expr,
1210 func_ctx.expressions[expr],
1211 plain,
1212 requested,
1213 );
1214 match (requested, plain) {
1215 (Indirection::Ordinary, Indirection::Reference) => {
1216 write!(self.out, "(&")?;
1217 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1218 write!(self.out, ")")?;
1219 }
1220 (Indirection::Reference, Indirection::Ordinary) => {
1221 write!(self.out, "(*")?;
1222 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1223 write!(self.out, ")")?;
1224 }
1225 (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1226 }
1227
1228 Ok(())
1229 }
1230
1231 fn write_const_expression(
1232 &mut self,
1233 module: &Module,
1234 expr: Handle<crate::Expression>,
1235 arena: &crate::Arena<crate::Expression>,
1236 ) -> BackendResult {
1237 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1238 writer.write_const_expression(module, expr, arena)
1239 })
1240 }
1241
1242 fn write_possibly_const_expression<E>(
1243 &mut self,
1244 module: &Module,
1245 expr: Handle<crate::Expression>,
1246 expressions: &crate::Arena<crate::Expression>,
1247 write_expression: E,
1248 ) -> BackendResult
1249 where
1250 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1251 {
1252 use crate::Expression;
1253
1254 match expressions[expr] {
1255 Expression::Literal(literal) => match literal {
1256 crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1257 crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1258 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1259 crate::Literal::I32(value) => {
1260 if value == i32::MIN {
1264 write!(self.out, "i32({value})")?;
1265 } else {
1266 write!(self.out, "{value}i")?;
1267 }
1268 }
1269 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1270 crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1271 crate::Literal::I64(value) => {
1272 if value == i64::MIN {
1277 write!(self.out, "i64({} - 1)", value + 1)?;
1278 } else {
1279 write!(self.out, "{value}li")?;
1280 }
1281 }
1282 crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1283 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1284 return Err(Error::Custom(
1285 "Abstract types should not appear in IR presented to backends".into(),
1286 ));
1287 }
1288 },
1289 Expression::Constant(handle) => {
1290 let constant = &module.constants[handle];
1291 if constant.name.is_some() {
1292 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1293 } else {
1294 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1295 }
1296 }
1297 Expression::ZeroValue(ty) => {
1298 self.write_type(module, ty)?;
1299 write!(self.out, "()")?;
1300 }
1301 Expression::Compose { ty, ref components } => {
1302 self.write_type(module, ty)?;
1303 write!(self.out, "(")?;
1304 for (index, component) in components.iter().enumerate() {
1305 if index != 0 {
1306 write!(self.out, ", ")?;
1307 }
1308 write_expression(self, *component)?;
1309 }
1310 write!(self.out, ")")?
1311 }
1312 Expression::Splat { size, value } => {
1313 let size = common::vector_size_str(size);
1314 write!(self.out, "vec{size}(")?;
1315 write_expression(self, value)?;
1316 write!(self.out, ")")?;
1317 }
1318 Expression::Override(handle) => {
1319 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1320 }
1321 _ => unreachable!(),
1322 }
1323
1324 Ok(())
1325 }
1326
1327 fn write_expr_plain_form(
1335 &mut self,
1336 module: &Module,
1337 expr: Handle<crate::Expression>,
1338 func_ctx: &back::FunctionCtx<'_>,
1339 indirection: Indirection,
1340 ) -> BackendResult {
1341 use crate::Expression;
1342
1343 if let Some(name) = self.named_expressions.get(&expr) {
1344 write!(self.out, "{name}")?;
1345 return Ok(());
1346 }
1347
1348 let expression = &func_ctx.expressions[expr];
1349
1350 match *expression {
1359 Expression::Literal(_)
1360 | Expression::Constant(_)
1361 | Expression::ZeroValue(_)
1362 | Expression::Compose { .. }
1363 | Expression::Splat { .. } => {
1364 self.write_possibly_const_expression(
1365 module,
1366 expr,
1367 func_ctx.expressions,
1368 |writer, expr| writer.write_expr(module, expr, func_ctx),
1369 )?;
1370 }
1371 Expression::Override(handle) => {
1372 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1373 }
1374 Expression::FunctionArgument(pos) => {
1375 let name_key = func_ctx.argument_key(pos);
1376 let name = &self.names[&name_key];
1377 write!(self.out, "{name}")?;
1378 }
1379 Expression::Binary { op, left, right } => {
1380 write!(self.out, "(")?;
1381 self.write_expr(module, left, func_ctx)?;
1382 write!(self.out, " {} ", back::binary_operation_str(op))?;
1383 self.write_expr(module, right, func_ctx)?;
1384 write!(self.out, ")")?;
1385 }
1386 Expression::Access { base, index } => {
1387 self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
1388 write!(self.out, "[")?;
1389 self.write_expr(module, index, func_ctx)?;
1390 write!(self.out, "]")?
1391 }
1392 Expression::AccessIndex { base, index } => {
1393 let base_ty_res = &func_ctx.info[base].ty;
1394 let mut resolved = base_ty_res.inner_with(&module.types);
1395
1396 self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
1397
1398 let base_ty_handle = match *resolved {
1399 TypeInner::Pointer { base, space: _ } => {
1400 resolved = &module.types[base].inner;
1401 Some(base)
1402 }
1403 _ => base_ty_res.handle(),
1404 };
1405
1406 match *resolved {
1407 TypeInner::Vector { .. } => {
1408 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
1410 }
1411 TypeInner::Matrix { .. }
1412 | TypeInner::Array { .. }
1413 | TypeInner::BindingArray { .. }
1414 | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
1415 TypeInner::Struct { .. } => {
1416 let ty = base_ty_handle.unwrap();
1419
1420 write!(
1421 self.out,
1422 ".{}",
1423 &self.names[&NameKey::StructMember(ty, index)]
1424 )?
1425 }
1426 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
1427 }
1428 }
1429 Expression::ImageSample {
1430 image,
1431 sampler,
1432 gather: None,
1433 coordinate,
1434 array_index,
1435 offset,
1436 level,
1437 depth_ref,
1438 clamp_to_edge,
1439 } => {
1440 use crate::SampleLevel as Sl;
1441
1442 let suffix_cmp = match depth_ref {
1443 Some(_) => "Compare",
1444 None => "",
1445 };
1446 let suffix_level = match level {
1447 Sl::Auto => "",
1448 Sl::Zero if clamp_to_edge => "BaseClampToEdge",
1449 Sl::Zero | Sl::Exact(_) => "Level",
1450 Sl::Bias(_) => "Bias",
1451 Sl::Gradient { .. } => "Grad",
1452 };
1453
1454 write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
1455 self.write_expr(module, image, func_ctx)?;
1456 write!(self.out, ", ")?;
1457 self.write_expr(module, sampler, func_ctx)?;
1458 write!(self.out, ", ")?;
1459 self.write_expr(module, coordinate, func_ctx)?;
1460
1461 if let Some(array_index) = array_index {
1462 write!(self.out, ", ")?;
1463 self.write_expr(module, array_index, func_ctx)?;
1464 }
1465
1466 if let Some(depth_ref) = depth_ref {
1467 write!(self.out, ", ")?;
1468 self.write_expr(module, depth_ref, func_ctx)?;
1469 }
1470
1471 match level {
1472 Sl::Auto => {}
1473 Sl::Zero => {
1474 if depth_ref.is_none() && !clamp_to_edge {
1476 write!(self.out, ", 0.0")?;
1477 }
1478 }
1479 Sl::Exact(expr) => {
1480 write!(self.out, ", ")?;
1481 self.write_expr(module, expr, func_ctx)?;
1482 }
1483 Sl::Bias(expr) => {
1484 write!(self.out, ", ")?;
1485 self.write_expr(module, expr, func_ctx)?;
1486 }
1487 Sl::Gradient { x, y } => {
1488 write!(self.out, ", ")?;
1489 self.write_expr(module, x, func_ctx)?;
1490 write!(self.out, ", ")?;
1491 self.write_expr(module, y, func_ctx)?;
1492 }
1493 }
1494
1495 if let Some(offset) = offset {
1496 write!(self.out, ", ")?;
1497 self.write_const_expression(module, offset, func_ctx.expressions)?;
1498 }
1499
1500 write!(self.out, ")")?;
1501 }
1502
1503 Expression::ImageSample {
1504 image,
1505 sampler,
1506 gather: Some(component),
1507 coordinate,
1508 array_index,
1509 offset,
1510 level: _,
1511 depth_ref,
1512 clamp_to_edge: _,
1513 } => {
1514 let suffix_cmp = match depth_ref {
1515 Some(_) => "Compare",
1516 None => "",
1517 };
1518
1519 write!(self.out, "textureGather{suffix_cmp}(")?;
1520 match *func_ctx.resolve_type(image, &module.types) {
1521 TypeInner::Image {
1522 class: crate::ImageClass::Depth { multi: _ },
1523 ..
1524 } => {}
1525 _ => {
1526 write!(self.out, "{}, ", component as u8)?;
1527 }
1528 }
1529 self.write_expr(module, image, func_ctx)?;
1530 write!(self.out, ", ")?;
1531 self.write_expr(module, sampler, func_ctx)?;
1532 write!(self.out, ", ")?;
1533 self.write_expr(module, coordinate, func_ctx)?;
1534
1535 if let Some(array_index) = array_index {
1536 write!(self.out, ", ")?;
1537 self.write_expr(module, array_index, func_ctx)?;
1538 }
1539
1540 if let Some(depth_ref) = depth_ref {
1541 write!(self.out, ", ")?;
1542 self.write_expr(module, depth_ref, func_ctx)?;
1543 }
1544
1545 if let Some(offset) = offset {
1546 write!(self.out, ", ")?;
1547 self.write_const_expression(module, offset, func_ctx.expressions)?;
1548 }
1549
1550 write!(self.out, ")")?;
1551 }
1552 Expression::ImageQuery { image, query } => {
1553 use crate::ImageQuery as Iq;
1554
1555 let texture_function = match query {
1556 Iq::Size { .. } => "textureDimensions",
1557 Iq::NumLevels => "textureNumLevels",
1558 Iq::NumLayers => "textureNumLayers",
1559 Iq::NumSamples => "textureNumSamples",
1560 };
1561
1562 write!(self.out, "{texture_function}(")?;
1563 self.write_expr(module, image, func_ctx)?;
1564 if let Iq::Size { level: Some(level) } = query {
1565 write!(self.out, ", ")?;
1566 self.write_expr(module, level, func_ctx)?;
1567 };
1568 write!(self.out, ")")?;
1569 }
1570
1571 Expression::ImageLoad {
1572 image,
1573 coordinate,
1574 array_index,
1575 sample,
1576 level,
1577 } => {
1578 write!(self.out, "textureLoad(")?;
1579 self.write_expr(module, image, func_ctx)?;
1580 write!(self.out, ", ")?;
1581 self.write_expr(module, coordinate, func_ctx)?;
1582 if let Some(array_index) = array_index {
1583 write!(self.out, ", ")?;
1584 self.write_expr(module, array_index, func_ctx)?;
1585 }
1586 if let Some(index) = sample.or(level) {
1587 write!(self.out, ", ")?;
1588 self.write_expr(module, index, func_ctx)?;
1589 }
1590 write!(self.out, ")")?;
1591 }
1592 Expression::GlobalVariable(handle) => {
1593 let name = &self.names[&NameKey::GlobalVariable(handle)];
1594 write!(self.out, "{name}")?;
1595 }
1596
1597 Expression::As {
1598 expr,
1599 kind,
1600 convert,
1601 } => {
1602 let inner = func_ctx.resolve_type(expr, &module.types);
1603 match *inner {
1604 TypeInner::Matrix {
1605 columns,
1606 rows,
1607 scalar,
1608 } => {
1609 let scalar = crate::Scalar {
1610 kind,
1611 width: convert.unwrap_or(scalar.width),
1612 };
1613 let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1614 write!(
1615 self.out,
1616 "mat{}x{}<{}>",
1617 common::vector_size_str(columns),
1618 common::vector_size_str(rows),
1619 scalar_kind_str
1620 )?;
1621 }
1622 TypeInner::Vector {
1623 size,
1624 scalar: crate::Scalar { width, .. },
1625 } => {
1626 let scalar = crate::Scalar {
1627 kind,
1628 width: convert.unwrap_or(width),
1629 };
1630 let vector_size_str = common::vector_size_str(size);
1631 let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1632 if convert.is_some() {
1633 write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
1634 } else {
1635 write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
1636 }
1637 }
1638 TypeInner::Scalar(crate::Scalar { width, .. }) => {
1639 let scalar = crate::Scalar {
1640 kind,
1641 width: convert.unwrap_or(width),
1642 };
1643 let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1644 if convert.is_some() {
1645 write!(self.out, "{scalar_kind_str}")?
1646 } else {
1647 write!(self.out, "bitcast<{scalar_kind_str}>")?
1648 }
1649 }
1650 _ => {
1651 return Err(Error::Unimplemented(format!(
1652 "write_expr expression::as {inner:?}"
1653 )));
1654 }
1655 };
1656 write!(self.out, "(")?;
1657 self.write_expr(module, expr, func_ctx)?;
1658 write!(self.out, ")")?;
1659 }
1660 Expression::Load { pointer } => {
1661 let is_atomic_pointer = func_ctx
1662 .resolve_type(pointer, &module.types)
1663 .is_atomic_pointer(&module.types);
1664
1665 if is_atomic_pointer {
1666 write!(self.out, "atomicLoad(")?;
1667 self.write_expr(module, pointer, func_ctx)?;
1668 write!(self.out, ")")?;
1669 } else {
1670 self.write_expr_with_indirection(
1671 module,
1672 pointer,
1673 func_ctx,
1674 Indirection::Reference,
1675 )?;
1676 }
1677 }
1678 Expression::LocalVariable(handle) => {
1679 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
1680 }
1681 Expression::ArrayLength(expr) => {
1682 write!(self.out, "arrayLength(")?;
1683 self.write_expr(module, expr, func_ctx)?;
1684 write!(self.out, ")")?;
1685 }
1686
1687 Expression::Math {
1688 fun,
1689 arg,
1690 arg1,
1691 arg2,
1692 arg3,
1693 } => {
1694 use crate::MathFunction as Mf;
1695
1696 enum Function {
1697 Regular(&'static str),
1698 InversePolyfill(InversePolyfill),
1699 }
1700
1701 let function = match fun.try_to_wgsl() {
1702 Some(name) => Function::Regular(name),
1703 None => match fun {
1704 Mf::Inverse => {
1705 let ty = func_ctx.resolve_type(arg, &module.types);
1706 let Some(overload) = InversePolyfill::find_overload(ty) else {
1707 return Err(Error::unsupported("math function", fun));
1708 };
1709
1710 Function::InversePolyfill(overload)
1711 }
1712 _ => return Err(Error::unsupported("math function", fun)),
1713 },
1714 };
1715
1716 match function {
1717 Function::Regular(fun_name) => {
1718 write!(self.out, "{fun_name}(")?;
1719 self.write_expr(module, arg, func_ctx)?;
1720 for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
1721 write!(self.out, ", ")?;
1722 self.write_expr(module, arg, func_ctx)?;
1723 }
1724 write!(self.out, ")")?
1725 }
1726 Function::InversePolyfill(inverse) => {
1727 write!(self.out, "{}(", inverse.fun_name)?;
1728 self.write_expr(module, arg, func_ctx)?;
1729 write!(self.out, ")")?;
1730 self.required_polyfills.insert(inverse);
1731 }
1732 }
1733 }
1734
1735 Expression::Swizzle {
1736 size,
1737 vector,
1738 pattern,
1739 } => {
1740 self.write_expr(module, vector, func_ctx)?;
1741 write!(self.out, ".")?;
1742 for &sc in pattern[..size as usize].iter() {
1743 self.out.write_char(back::COMPONENTS[sc as usize])?;
1744 }
1745 }
1746 Expression::Unary { op, expr } => {
1747 let unary = match op {
1748 crate::UnaryOperator::Negate => "-",
1749 crate::UnaryOperator::LogicalNot => "!",
1750 crate::UnaryOperator::BitwiseNot => "~",
1751 };
1752
1753 write!(self.out, "{unary}(")?;
1754 self.write_expr(module, expr, func_ctx)?;
1755
1756 write!(self.out, ")")?
1757 }
1758
1759 Expression::Select {
1760 condition,
1761 accept,
1762 reject,
1763 } => {
1764 write!(self.out, "select(")?;
1765 self.write_expr(module, reject, func_ctx)?;
1766 write!(self.out, ", ")?;
1767 self.write_expr(module, accept, func_ctx)?;
1768 write!(self.out, ", ")?;
1769 self.write_expr(module, condition, func_ctx)?;
1770 write!(self.out, ")")?
1771 }
1772 Expression::Derivative { axis, ctrl, expr } => {
1773 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1774 let op = match (axis, ctrl) {
1775 (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
1776 (Axis::X, Ctrl::Fine) => "dpdxFine",
1777 (Axis::X, Ctrl::None) => "dpdx",
1778 (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
1779 (Axis::Y, Ctrl::Fine) => "dpdyFine",
1780 (Axis::Y, Ctrl::None) => "dpdy",
1781 (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
1782 (Axis::Width, Ctrl::Fine) => "fwidthFine",
1783 (Axis::Width, Ctrl::None) => "fwidth",
1784 };
1785 write!(self.out, "{op}(")?;
1786 self.write_expr(module, expr, func_ctx)?;
1787 write!(self.out, ")")?
1788 }
1789 Expression::Relational { fun, argument } => {
1790 use crate::RelationalFunction as Rf;
1791
1792 let fun_name = match fun {
1793 Rf::All => "all",
1794 Rf::Any => "any",
1795 _ => return Err(Error::UnsupportedRelationalFunction(fun)),
1796 };
1797 write!(self.out, "{fun_name}(")?;
1798
1799 self.write_expr(module, argument, func_ctx)?;
1800
1801 write!(self.out, ")")?
1802 }
1803 Expression::RayQueryGetIntersection { .. }
1805 | Expression::RayQueryVertexPositions { .. } => unreachable!(),
1806 Expression::CallResult(_)
1808 | Expression::AtomicResult { .. }
1809 | Expression::RayQueryProceedResult
1810 | Expression::SubgroupBallotResult
1811 | Expression::SubgroupOperationResult { .. }
1812 | Expression::WorkGroupUniformLoadResult { .. } => {}
1813 Expression::CooperativeLoad {
1814 columns,
1815 rows,
1816 role,
1817 ref data,
1818 } => {
1819 let suffix = if data.row_major { "T" } else { "" };
1820 let scalar = func_ctx.info[data.pointer]
1821 .ty
1822 .inner_with(&module.types)
1823 .pointer_base_type()
1824 .unwrap()
1825 .inner_with(&module.types)
1826 .scalar()
1827 .unwrap();
1828 write!(
1829 self.out,
1830 "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
1831 columns as u32,
1832 rows as u32,
1833 scalar.try_to_wgsl().unwrap(),
1834 role,
1835 )?;
1836 self.write_expr(module, data.pointer, func_ctx)?;
1837 write!(self.out, ", ")?;
1838 self.write_expr(module, data.stride, func_ctx)?;
1839 write!(self.out, ")")?;
1840 }
1841 Expression::CooperativeMultiplyAdd { a, b, c } => {
1842 write!(self.out, "coopMultiplyAdd(")?;
1843 self.write_expr(module, a, func_ctx)?;
1844 write!(self.out, ", ")?;
1845 self.write_expr(module, b, func_ctx)?;
1846 write!(self.out, ", ")?;
1847 self.write_expr(module, c, func_ctx)?;
1848 write!(self.out, ")")?;
1849 }
1850 }
1851
1852 Ok(())
1853 }
1854
1855 fn write_global(
1859 &mut self,
1860 module: &Module,
1861 global: &crate::GlobalVariable,
1862 handle: Handle<crate::GlobalVariable>,
1863 ) -> BackendResult {
1864 if let Some(ref binding) = global.binding {
1866 self.write_attributes(&[
1867 Attribute::Group(binding.group),
1868 Attribute::Binding(binding.binding),
1869 ])?;
1870 writeln!(self.out)?;
1871 }
1872
1873 write!(self.out, "var")?;
1875 let (address, maybe_access) = address_space_str(global.space);
1876 if let Some(space) = address {
1877 write!(self.out, "<{space}")?;
1878 if let Some(access) = maybe_access {
1879 write!(self.out, ", {access}")?;
1880 }
1881 write!(self.out, ">")?;
1882 }
1883 write!(
1884 self.out,
1885 " {}: ",
1886 &self.names[&NameKey::GlobalVariable(handle)]
1887 )?;
1888
1889 self.write_type(module, global.ty)?;
1891
1892 if let Some(init) = global.init {
1894 write!(self.out, " = ")?;
1895 self.write_const_expression(module, init, &module.global_expressions)?;
1896 }
1897
1898 writeln!(self.out, ";")?;
1900
1901 Ok(())
1902 }
1903
1904 fn write_global_constant(
1909 &mut self,
1910 module: &Module,
1911 handle: Handle<crate::Constant>,
1912 ) -> BackendResult {
1913 let name = &self.names[&NameKey::Constant(handle)];
1914 write!(self.out, "const {name}: ")?;
1916 self.write_type(module, module.constants[handle].ty)?;
1917 write!(self.out, " = ")?;
1918 let init = module.constants[handle].init;
1919 self.write_const_expression(module, init, &module.global_expressions)?;
1920 writeln!(self.out, ";")?;
1921
1922 Ok(())
1923 }
1924
1925 fn write_override(
1930 &mut self,
1931 module: &Module,
1932 handle: Handle<crate::Override>,
1933 ) -> BackendResult {
1934 let override_ = &module.overrides[handle];
1935 let name = &self.names[&NameKey::Override(handle)];
1936
1937 if let Some(id) = override_.id {
1939 write!(self.out, "@id({id}) ")?;
1940 }
1941
1942 write!(self.out, "override {name}: ")?;
1944 self.write_type(module, override_.ty)?;
1945
1946 if let Some(init) = override_.init {
1948 write!(self.out, " = ")?;
1949 self.write_const_expression(module, init, &module.global_expressions)?;
1950 }
1951
1952 writeln!(self.out, ";")?;
1953
1954 Ok(())
1955 }
1956
1957 #[allow(clippy::missing_const_for_fn)]
1959 pub fn finish(self) -> W {
1960 self.out
1961 }
1962}
1963
1964struct WriterTypeContext<'m> {
1965 module: &'m Module,
1966 names: &'m crate::FastHashMap<NameKey, String>,
1967}
1968
1969impl TypeContext for WriterTypeContext<'_> {
1970 fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
1971 &self.module.types[handle]
1972 }
1973
1974 fn type_name(&self, handle: Handle<crate::Type>) -> &str {
1975 self.names[&NameKey::Type(handle)].as_str()
1976 }
1977
1978 fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
1979 unreachable!("the WGSL back end should always provide type handles");
1980 }
1981
1982 fn write_override<W: Write>(
1983 &self,
1984 handle: Handle<crate::Override>,
1985 out: &mut W,
1986 ) -> core::fmt::Result {
1987 write!(out, "{}", self.names[&NameKey::Override(handle)])
1988 }
1989
1990 fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
1991 unreachable!("backends should only be passed validated modules");
1992 }
1993
1994 fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
1995 unreachable!("backends should only be passed validated modules");
1996 }
1997}
1998
1999fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2000 match *binding {
2001 crate::Binding::BuiltIn(built_in) => {
2002 if let crate::BuiltIn::Position { invariant: true } = built_in {
2003 vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2004 } else {
2005 vec![Attribute::BuiltIn(built_in)]
2006 }
2007 }
2008 crate::Binding::Location {
2009 location,
2010 interpolation,
2011 sampling,
2012 blend_src: None,
2013 per_primitive,
2014 } => {
2015 let mut attrs = vec![
2016 Attribute::Location(location),
2017 Attribute::Interpolate(interpolation, sampling),
2018 ];
2019 if per_primitive {
2020 attrs.push(Attribute::PerPrimitive);
2021 }
2022 attrs
2023 }
2024 crate::Binding::Location {
2025 location,
2026 interpolation,
2027 sampling,
2028 blend_src: Some(blend_src),
2029 per_primitive,
2030 } => {
2031 let mut attrs = vec![
2032 Attribute::Location(location),
2033 Attribute::BlendSrc(blend_src),
2034 Attribute::Interpolate(interpolation, sampling),
2035 ];
2036 if per_primitive {
2037 attrs.push(Attribute::PerPrimitive);
2038 }
2039 attrs
2040 }
2041 }
2042}