1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13 back::{self, Baked},
14 common::{
15 self,
16 wgsl::{address_space_str, ToWgsl, TryToWgsl},
17 },
18 proc::{self, NameKey},
19 valid, Handle, Module, ShaderStage, TypeInner,
20};
21
/// Result of a writer operation that produces no value; failures are reported as [`Error`].
type BackendResult = Result<(), Error>;
24
/// A WGSL attribute the writer can place before a declaration, argument, struct member,
/// function result or entry point.
///
/// Each variant is rendered by `Writer::write_attributes` as the corresponding `@...`
/// attribute followed by a single space.
enum Attribute {
    /// `@binding(n)`.
    Binding(u32),
    /// `@builtin(name)`; fails to write if the built-in has no WGSL spelling.
    BuiltIn(crate::BuiltIn),
    /// `@group(n)`.
    Group(u32),
    /// `@invariant`.
    Invariant,
    /// `@interpolate(type[, sampling])`. Nothing is written when sampling is absent or
    /// `Center` and interpolation is absent or `Perspective`.
    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
    /// `@location(n)`.
    Location(u32),
    /// `@blend_src(n)`.
    BlendSrc(u32),
    /// Shader stage attribute such as `@vertex` or `@compute`. Must not be used with
    /// `ShaderStage::Mesh`; mesh entry points use [`Attribute::MeshStage`] instead.
    Stage(ShaderStage),
    /// `@workgroup_size(x, y, z)`.
    WorkGroupSize([u32; 3]),
    /// `@mesh(name)`, where `name` identifies the mesh output global variable.
    MeshStage(String),
    /// `@payload(name)`, where `name` identifies the task payload global variable.
    TaskPayload(String),
    /// `@per_primitive`.
    PerPrimitive,
    /// `@incoming_payload(name)`, where `name` identifies the incoming ray payload global.
    IncomingRayPayload(String),
}
41
/// Whether an expression's WGSL text denotes an ordinary value or a WGSL reference.
///
/// `Writer::plain_form_indirection` classifies an expression's plain form, and
/// `Writer::write_expr_with_indirection` adds `&` or `*` when the caller requests the
/// other kind.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// The text is an ordinary value. This value may itself be a pointer, and dereferencing
    /// it requires `*`.
    Ordinary,

    /// The text is a WGSL reference. Examples are the name of a local variable, the name
    /// of a global outside the handle address space, or an access through a pointer. Taking
    /// its address requires `&`.
    Reference,
}
70
bitflags::bitflags! {
    /// Options controlling the WGSL writer's output.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct WriterFlags: u32 {
        /// Write an explicit type annotation (`let name: T = ...`) on the `let` bindings the
        /// writer introduces for named and baked expressions.
        const EXPLICIT_TYPES = 0x1;
    }
}
80
/// Writes a naga [`Module`] as WGSL source text into a [`core::fmt::Write`] sink.
///
/// Create one with `Writer::new` and call `Writer::write` for each module; per-module
/// state is reset at the start of every `write`.
#[expect(missing_debug_implementations, reason = "would be way too verbose?")]
pub struct Writer<W> {
    /// Destination for the generated WGSL text.
    out: W,
    /// Output options.
    flags: WriterFlags,
    /// Identifiers for module items (types, functions, globals, arguments, locals, ...),
    /// filled by `namer.reset` for the current module.
    names: crate::FastHashMap<NameKey, String>,
    /// Produces collision-free identifiers, including names for named expressions.
    namer: proc::Namer,
    /// Expressions already bound to a `let`, mapped to the bound name. Cleared after each
    /// function is written.
    named_expressions: crate::NamedExpressions,
    /// Polyfill functions whose source is appended after all entry points. Populated
    /// outside this view (presumably while writing expressions).
    required_polyfills: crate::FastIndexSet<InversePolyfill>,
}
90
91impl<W: Write> Writer<W> {
92 pub fn new(out: W, flags: WriterFlags) -> Self {
93 Writer {
94 out,
95 flags,
96 names: crate::FastHashMap::default(),
97 namer: proc::Namer::default(),
98 named_expressions: crate::NamedExpressions::default(),
99 required_polyfills: crate::FastIndexSet::default(),
100 }
101 }
102
103 fn reset(&mut self, module: &Module) {
104 self.names.clear();
105 self.namer.reset(
106 module,
107 &crate::keywords::wgsl::RESERVED_SET,
108 &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET,
109 proc::CaseInsensitiveKeywordSet::empty(),
111 &["__", "_naga"],
112 &mut self.names,
113 );
114 self.named_expressions.clear();
115 self.required_polyfills.clear();
116 }
117
118 fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
134 module
135 .special_types
136 .predeclared_types
137 .values()
138 .any(|t| *t == ty)
139 || Some(ty) == module.special_types.external_texture_params
140 || Some(ty) == module.special_types.external_texture_transfer_function
141 }
142
143 pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
144 self.reset(module);
145
146 self.write_enable_declarations(module)?;
148
149 for (handle, ty) in module.types.iter() {
151 if let TypeInner::Struct { ref members, .. } = ty.inner {
152 {
153 if !self.is_builtin_wgsl_struct(module, handle) {
154 self.write_struct(module, handle, members)?;
155 writeln!(self.out)?;
156 }
157 }
158 }
159 }
160
161 let mut constants = module
163 .constants
164 .iter()
165 .filter(|&(_, c)| c.name.is_some())
166 .peekable();
167 while let Some((handle, _)) = constants.next() {
168 self.write_global_constant(module, handle)?;
169 if constants.peek().is_none() {
171 writeln!(self.out)?;
172 }
173 }
174
175 let mut overrides = module.overrides.iter().peekable();
177 while let Some((handle, _)) = overrides.next() {
178 self.write_override(module, handle)?;
179 if overrides.peek().is_none() {
181 writeln!(self.out)?;
182 }
183 }
184
185 for (ty, global) in module.global_variables.iter() {
187 self.write_global(module, global, ty)?;
188 }
189
190 if !module.global_variables.is_empty() {
191 writeln!(self.out)?;
193 }
194
195 for (handle, function) in module.functions.iter() {
197 let fun_info = &info[handle];
198
199 let func_ctx = back::FunctionCtx {
200 ty: back::FunctionType::Function(handle),
201 info: fun_info,
202 expressions: &function.expressions,
203 named_expressions: &function.named_expressions,
204 };
205
206 self.write_function(module, function, &func_ctx)?;
208
209 writeln!(self.out)?;
210 }
211
212 for (index, ep) in module.entry_points.iter().enumerate() {
214 let attributes = match ep.stage {
215 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
216 ShaderStage::Compute => vec![
217 Attribute::Stage(ShaderStage::Compute),
218 Attribute::WorkGroupSize(ep.workgroup_size),
219 ],
220 ShaderStage::Mesh => {
221 let mesh_output_name = module.global_variables
222 [ep.mesh_info.as_ref().unwrap().output_variable]
223 .name
224 .clone()
225 .unwrap();
226 let mut mesh_attrs = vec![
227 Attribute::MeshStage(mesh_output_name),
228 Attribute::WorkGroupSize(ep.workgroup_size),
229 ];
230 if let Some(task_payload) = ep.task_payload {
231 let payload_name =
232 module.global_variables[task_payload].name.clone().unwrap();
233 mesh_attrs.push(Attribute::TaskPayload(payload_name));
234 }
235 mesh_attrs
236 }
237 ShaderStage::Task => {
238 let payload_name = module.global_variables[ep.task_payload.unwrap()]
239 .name
240 .clone()
241 .unwrap();
242 vec![
243 Attribute::Stage(ShaderStage::Task),
244 Attribute::TaskPayload(payload_name),
245 Attribute::WorkGroupSize(ep.workgroup_size),
246 ]
247 }
248 ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
249 ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
250 let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
251 .name
252 .clone()
253 .unwrap();
254 vec![
255 Attribute::Stage(ep.stage),
256 Attribute::IncomingRayPayload(payload_name),
257 ]
258 }
259 };
260 self.write_attributes(&attributes)?;
261 writeln!(self.out)?;
263
264 let func_ctx = back::FunctionCtx {
265 ty: back::FunctionType::EntryPoint(index as u16),
266 info: info.get_entry_point(index),
267 expressions: &ep.function.expressions,
268 named_expressions: &ep.function.named_expressions,
269 };
270 self.write_function(module, &ep.function, &func_ctx)?;
271
272 if index < module.entry_points.len() - 1 {
273 writeln!(self.out)?;
274 }
275 }
276
277 for polyfill in &self.required_polyfills {
279 writeln!(self.out)?;
280 write!(self.out, "{}", polyfill.source)?;
281 writeln!(self.out)?;
282 }
283
284 Ok(())
285 }
286
287 fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
290 #[derive(Default)]
291 struct RequiredEnabled {
292 f16: bool,
293 int16: bool,
294 dual_source_blending: bool,
295 clip_distances: bool,
296 mesh_shaders: bool,
297 primitive_index: bool,
298 cooperative_matrix: bool,
299 draw_index: bool,
300 ray_tracing_pipeline: bool,
301 per_vertex: bool,
302 binding_array: bool,
303 }
304 let mut needed = RequiredEnabled {
305 mesh_shaders: module.uses_mesh_shaders(),
306 ..Default::default()
307 };
308
309 let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding
310 {
311 crate::Binding::Location {
312 blend_src: Some(_), ..
313 } => {
314 needed.dual_source_blending = true;
315 }
316 crate::Binding::BuiltIn(crate::BuiltIn::ClipDistances) => {
317 needed.clip_distances = true;
318 }
319 crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => {
320 needed.primitive_index = true;
321 }
322 crate::Binding::Location {
323 per_primitive: true,
324 ..
325 } => {
326 needed.mesh_shaders = true;
327 }
328 crate::Binding::Location {
329 interpolation: Some(crate::Interpolation::PerVertex),
330 ..
331 } => {
332 needed.per_vertex = true;
333 }
334 crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
335 crate::Binding::BuiltIn(
336 crate::BuiltIn::RayInvocationId
337 | crate::BuiltIn::NumRayInvocations
338 | crate::BuiltIn::InstanceCustomData
339 | crate::BuiltIn::GeometryIndex
340 | crate::BuiltIn::WorldRayOrigin
341 | crate::BuiltIn::WorldRayDirection
342 | crate::BuiltIn::ObjectRayOrigin
343 | crate::BuiltIn::ObjectRayDirection
344 | crate::BuiltIn::RayTmin
345 | crate::BuiltIn::RayTCurrentMax
346 | crate::BuiltIn::ObjectToWorld
347 | crate::BuiltIn::WorldToObject,
348 ) => {
349 needed.ray_tracing_pipeline = true;
350 }
351 _ => {}
352 };
353
354 for (_, ty) in module.types.iter() {
356 match ty.inner {
357 TypeInner::Scalar(scalar)
358 | TypeInner::Vector { scalar, .. }
359 | TypeInner::Matrix { scalar, .. } => {
360 needed.f16 |= scalar == crate::Scalar::F16;
361 needed.int16 |= scalar == crate::Scalar::I16 || scalar == crate::Scalar::U16;
362 }
363 TypeInner::Struct { ref members, .. } => {
364 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
365 check_binding(binding, &mut needed);
366 }
367 }
368 TypeInner::CooperativeMatrix { .. } => {
369 needed.cooperative_matrix = true;
370 }
371 TypeInner::AccelerationStructure { .. } => {
372 needed.ray_tracing_pipeline = true;
373 }
374 TypeInner::BindingArray { .. } => {
375 needed.binding_array = true;
376 }
377 _ => {}
378 }
379 }
380
381 for ep in &module.entry_points {
382 if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) {
383 check_binding(res, &mut needed);
384 }
385 for arg in ep
386 .function
387 .arguments
388 .iter()
389 .filter_map(|a| a.binding.as_ref())
390 {
391 check_binding(arg, &mut needed);
392 }
393 }
394
395 if module.global_variables.iter().any(|gv| {
396 gv.1.space == crate::AddressSpace::IncomingRayPayload
397 || gv.1.space == crate::AddressSpace::RayPayload
398 }) {
399 needed.ray_tracing_pipeline = true;
400 }
401
402 if module.entry_points.iter().any(|ep| {
403 matches!(
404 ep.stage,
405 ShaderStage::RayGeneration
406 | ShaderStage::AnyHit
407 | ShaderStage::ClosestHit
408 | ShaderStage::Miss
409 )
410 }) {
411 needed.ray_tracing_pipeline = true;
412 }
413
414 if module.global_variables.iter().any(|gv| {
415 gv.1.space == crate::AddressSpace::IncomingRayPayload
416 || gv.1.space == crate::AddressSpace::RayPayload
417 }) {
418 needed.ray_tracing_pipeline = true;
419 }
420
421 if module.entry_points.iter().any(|ep| {
422 matches!(
423 ep.stage,
424 ShaderStage::RayGeneration
425 | ShaderStage::AnyHit
426 | ShaderStage::ClosestHit
427 | ShaderStage::Miss
428 )
429 }) {
430 needed.ray_tracing_pipeline = true;
431 }
432
433 let mut any_written = false;
435 if needed.f16 {
436 writeln!(self.out, "enable f16;")?;
437 any_written = true;
438 }
439 if needed.int16 {
440 writeln!(self.out, "enable wgpu_int16;")?;
441 any_written = true;
442 }
443 if needed.dual_source_blending {
444 writeln!(self.out, "enable dual_source_blending;")?;
445 any_written = true;
446 }
447 if needed.clip_distances {
448 writeln!(self.out, "enable clip_distances;")?;
449 any_written = true;
450 }
451 if module.uses_mesh_shaders() {
452 writeln!(self.out, "enable wgpu_mesh_shader;")?;
453 any_written = true;
454 }
455 if needed.binding_array {
456 writeln!(self.out, "enable wgpu_binding_array;")?;
457 any_written = true;
458 }
459 if needed.draw_index {
460 writeln!(self.out, "enable draw_index;")?;
461 any_written = true;
462 }
463 if needed.primitive_index {
464 writeln!(self.out, "enable primitive_index;")?;
465 any_written = true;
466 }
467 if needed.cooperative_matrix {
468 writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
469 any_written = true;
470 }
471 if needed.ray_tracing_pipeline {
472 writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
473 any_written = true;
474 }
475 if needed.per_vertex {
476 writeln!(self.out, "enable wgpu_per_vertex;")?;
477 any_written = true;
478 }
479 if any_written {
480 writeln!(self.out)?;
482 }
483
484 Ok(())
485 }
486
487 fn write_function(
493 &mut self,
494 module: &Module,
495 func: &crate::Function,
496 func_ctx: &back::FunctionCtx<'_>,
497 ) -> BackendResult {
498 let func_name = match func_ctx.ty {
499 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
500 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
501 };
502
503 write!(self.out, "fn {func_name}(")?;
505
506 for (index, arg) in func.arguments.iter().enumerate() {
508 if let Some(ref binding) = arg.binding {
510 self.write_attributes(&map_binding_to_attribute(binding))?;
511 }
512 let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
514
515 write!(self.out, "{argument_name}: ")?;
516 self.write_type(module, arg.ty)?;
518 if index < func.arguments.len() - 1 {
519 write!(self.out, ", ")?;
521 }
522 }
523
524 write!(self.out, ")")?;
525
526 if let Some(ref result) = func.result {
528 write!(self.out, " -> ")?;
529 if let Some(ref binding) = result.binding {
530 self.write_attributes(&map_binding_to_attribute(binding))?;
531 }
532 self.write_type(module, result.ty)?;
533 }
534
535 write!(self.out, " {{")?;
536 writeln!(self.out)?;
537
538 for (handle, local) in func.local_variables.iter() {
540 write!(self.out, "{}", back::INDENT)?;
542
543 write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
546
547 self.write_type(module, local.ty)?;
549
550 if let Some(init) = local.init {
552 write!(self.out, " = ")?;
555
556 self.write_expr(module, init, func_ctx)?;
559 }
560
561 writeln!(self.out, ";")?
563 }
564
565 if !func.local_variables.is_empty() {
566 writeln!(self.out)?;
567 }
568
569 for sta in func.body.iter() {
571 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
573 }
574
575 writeln!(self.out, "}}")?;
576
577 self.named_expressions.clear();
578
579 Ok(())
580 }
581
582 fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
584 for attribute in attributes {
585 match *attribute {
586 Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
587 Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
588 Attribute::BuiltIn(builtin_attrib) => {
589 let builtin = builtin_attrib.to_wgsl_if_implemented()?;
590 write!(self.out, "@builtin({builtin}) ")?;
591 }
592 Attribute::Stage(shader_stage) => {
593 let stage_str = match shader_stage {
594 ShaderStage::Vertex => "vertex",
595 ShaderStage::Fragment => "fragment",
596 ShaderStage::Compute => "compute",
597 ShaderStage::Task => "task",
598 ShaderStage::Mesh => unreachable!(),
600 ShaderStage::RayGeneration => "ray_generation",
601 ShaderStage::AnyHit => "any_hit",
602 ShaderStage::ClosestHit => "closest_hit",
603 ShaderStage::Miss => "miss",
604 };
605
606 write!(self.out, "@{stage_str} ")?;
607 }
608 Attribute::WorkGroupSize(size) => {
609 write!(
610 self.out,
611 "@workgroup_size({}, {}, {}) ",
612 size[0], size[1], size[2]
613 )?;
614 }
615 Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
616 Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
617 Attribute::Invariant => write!(self.out, "@invariant ")?,
618 Attribute::Interpolate(interpolation, sampling) => {
619 if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
620 let interpolation = interpolation
621 .unwrap_or(crate::Interpolation::Perspective)
622 .to_wgsl();
623 let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
624 write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
625 } else if interpolation.is_some()
626 && interpolation != Some(crate::Interpolation::Perspective)
627 {
628 let interpolation = interpolation
629 .unwrap_or(crate::Interpolation::Perspective)
630 .to_wgsl();
631 write!(self.out, "@interpolate({interpolation}) ")?;
632 }
633 }
634 Attribute::MeshStage(ref name) => {
635 write!(self.out, "@mesh({name}) ")?;
636 }
637 Attribute::TaskPayload(ref payload_name) => {
638 write!(self.out, "@payload({payload_name}) ")?;
639 }
640 Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
641 Attribute::IncomingRayPayload(ref payload_name) => {
642 write!(self.out, "@incoming_payload({payload_name}) ")?;
643 }
644 };
645 }
646 Ok(())
647 }
648
649 fn write_struct(
660 &mut self,
661 module: &Module,
662 handle: Handle<crate::Type>,
663 members: &[crate::StructMember],
664 ) -> BackendResult {
665 write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
666 write!(self.out, " {{")?;
667 writeln!(self.out)?;
668 for (index, member) in members.iter().enumerate() {
669 write!(self.out, "{}", back::INDENT)?;
671 if let Some(ref binding) = member.binding {
672 self.write_attributes(&map_binding_to_attribute(binding))?;
673 }
674 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
676 write!(self.out, "{member_name}: ")?;
677 self.write_type(module, member.ty)?;
678 write!(self.out, ",")?;
679 writeln!(self.out)?;
680 }
681
682 writeln!(self.out, "}}")?;
683
684 Ok(())
685 }
686
687 fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
688 let type_context = WriterTypeContext {
692 module,
693 names: &self.names,
694 };
695 type_context.write_type(ty, &mut self.out)?;
696
697 Ok(())
698 }
699
700 fn write_type_resolution(
701 &mut self,
702 module: &Module,
703 resolution: &proc::TypeResolution,
704 ) -> BackendResult {
705 let type_context = WriterTypeContext {
709 module,
710 names: &self.names,
711 };
712 type_context.write_type_resolution(resolution, &mut self.out)?;
713
714 Ok(())
715 }
716
717 fn write_stmt(
722 &mut self,
723 module: &Module,
724 stmt: &crate::Statement,
725 func_ctx: &back::FunctionCtx<'_>,
726 level: back::Level,
727 ) -> BackendResult {
728 use crate::{Expression, Statement};
729
730 match *stmt {
731 Statement::Emit(ref range) => {
732 for handle in range.clone() {
733 let info = &func_ctx.info[handle];
734 let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
735 Some(self.namer.call(name))
740 } else {
741 let expr = &func_ctx.expressions[handle];
742 let min_ref_count = expr.bake_ref_count();
743 let required_baking_expr = match *expr {
745 Expression::ImageLoad { .. }
746 | Expression::ImageQuery { .. }
747 | Expression::ImageSample { .. } => true,
748 _ => false,
749 };
750 if min_ref_count <= info.ref_count || required_baking_expr {
751 Some(Baked(handle).to_string())
752 } else {
753 None
754 }
755 };
756
757 if let Some(name) = expr_name {
758 write!(self.out, "{level}")?;
759 self.start_named_expr(module, handle, func_ctx, &name)?;
760 self.write_expr(module, handle, func_ctx)?;
761 self.named_expressions.insert(handle, name);
762 writeln!(self.out, ";")?;
763 }
764 }
765 }
766 Statement::If {
768 condition,
769 ref accept,
770 ref reject,
771 } => {
772 write!(self.out, "{level}")?;
773 write!(self.out, "if ")?;
774 self.write_expr(module, condition, func_ctx)?;
775 writeln!(self.out, " {{")?;
776
777 let l2 = level.next();
778 for sta in accept {
779 self.write_stmt(module, sta, func_ctx, l2)?;
781 }
782
783 if !reject.is_empty() {
786 writeln!(self.out, "{level}}} else {{")?;
787
788 for sta in reject {
789 self.write_stmt(module, sta, func_ctx, l2)?;
791 }
792 }
793
794 writeln!(self.out, "{level}}}")?
795 }
796 Statement::Return { value } => {
797 write!(self.out, "{level}")?;
798 write!(self.out, "return")?;
799 if let Some(return_value) = value {
800 write!(self.out, " ")?;
802 self.write_expr(module, return_value, func_ctx)?;
803 }
804 writeln!(self.out, ";")?;
805 }
806 Statement::Kill => {
808 write!(self.out, "{level}")?;
809 writeln!(self.out, "discard;")?
810 }
811 Statement::Store { pointer, value } => {
812 write!(self.out, "{level}")?;
813
814 let is_atomic_pointer = func_ctx
815 .resolve_type(pointer, &module.types)
816 .is_atomic_pointer(&module.types);
817
818 if is_atomic_pointer {
819 write!(self.out, "atomicStore(")?;
820 self.write_expr(module, pointer, func_ctx)?;
821 write!(self.out, ", ")?;
822 self.write_expr(module, value, func_ctx)?;
823 write!(self.out, ")")?;
824 } else {
825 self.write_expr_with_indirection(
826 module,
827 pointer,
828 func_ctx,
829 Indirection::Reference,
830 )?;
831 write!(self.out, " = ")?;
832 self.write_expr(module, value, func_ctx)?;
833 }
834 writeln!(self.out, ";")?
835 }
836 Statement::Call {
837 function,
838 ref arguments,
839 result,
840 } => {
841 write!(self.out, "{level}")?;
842 if let Some(expr) = result {
843 let name = Baked(expr).to_string();
844 self.start_named_expr(module, expr, func_ctx, &name)?;
845 self.named_expressions.insert(expr, name);
846 }
847 let func_name = &self.names[&NameKey::Function(function)];
848 write!(self.out, "{func_name}(")?;
849 for (index, &argument) in arguments.iter().enumerate() {
850 if index != 0 {
851 write!(self.out, ", ")?;
852 }
853 self.write_expr(module, argument, func_ctx)?;
854 }
855 writeln!(self.out, ");")?
856 }
857 Statement::Atomic {
858 pointer,
859 ref fun,
860 value,
861 result,
862 } => {
863 write!(self.out, "{level}")?;
864 if let Some(result) = result {
865 let res_name = Baked(result).to_string();
866 self.start_named_expr(module, result, func_ctx, &res_name)?;
867 self.named_expressions.insert(result, res_name);
868 }
869
870 let fun_str = fun.to_wgsl();
871 write!(self.out, "atomic{fun_str}(")?;
872 self.write_expr(module, pointer, func_ctx)?;
873 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
874 write!(self.out, ", ")?;
875 self.write_expr(module, cmp, func_ctx)?;
876 }
877 write!(self.out, ", ")?;
878 self.write_expr(module, value, func_ctx)?;
879 writeln!(self.out, ");")?
880 }
881 Statement::ImageAtomic {
882 image,
883 coordinate,
884 array_index,
885 ref fun,
886 value,
887 } => {
888 write!(self.out, "{level}")?;
889 let fun_str = fun.to_wgsl();
890 write!(self.out, "textureAtomic{fun_str}(")?;
891 self.write_expr(module, image, func_ctx)?;
892 write!(self.out, ", ")?;
893 self.write_expr(module, coordinate, func_ctx)?;
894 if let Some(array_index_expr) = array_index {
895 write!(self.out, ", ")?;
896 self.write_expr(module, array_index_expr, func_ctx)?;
897 }
898 write!(self.out, ", ")?;
899 self.write_expr(module, value, func_ctx)?;
900 writeln!(self.out, ");")?;
901 }
902 Statement::WorkGroupUniformLoad { pointer, result } => {
903 write!(self.out, "{level}")?;
904 let res_name = Baked(result).to_string();
906 self.start_named_expr(module, result, func_ctx, &res_name)?;
907 self.named_expressions.insert(result, res_name);
908 write!(self.out, "workgroupUniformLoad(")?;
909 self.write_expr(module, pointer, func_ctx)?;
910 writeln!(self.out, ");")?;
911 }
912 Statement::ImageStore {
913 image,
914 coordinate,
915 array_index,
916 value,
917 } => {
918 write!(self.out, "{level}")?;
919 write!(self.out, "textureStore(")?;
920 self.write_expr(module, image, func_ctx)?;
921 write!(self.out, ", ")?;
922 self.write_expr(module, coordinate, func_ctx)?;
923 if let Some(array_index_expr) = array_index {
924 write!(self.out, ", ")?;
925 self.write_expr(module, array_index_expr, func_ctx)?;
926 }
927 write!(self.out, ", ")?;
928 self.write_expr(module, value, func_ctx)?;
929 writeln!(self.out, ");")?;
930 }
931 Statement::Block(ref block) => {
933 write!(self.out, "{level}")?;
934 writeln!(self.out, "{{")?;
935 for sta in block.iter() {
936 self.write_stmt(module, sta, func_ctx, level.next())?
938 }
939 writeln!(self.out, "{level}}}")?
940 }
941 Statement::Switch {
942 selector,
943 ref cases,
944 } => {
945 write!(self.out, "{level}")?;
947 write!(self.out, "switch ")?;
948 self.write_expr(module, selector, func_ctx)?;
949 writeln!(self.out, " {{")?;
950
951 let l2 = level.next();
952 let mut new_case = true;
953 for case in cases {
954 if case.fall_through && !case.body.is_empty() {
955 return Err(Error::Unimplemented(
957 "fall-through switch case block".into(),
958 ));
959 }
960
961 match case.value {
962 crate::SwitchValue::I32(value) => {
963 if new_case {
964 write!(self.out, "{l2}case ")?;
965 }
966 write!(self.out, "{value}")?;
967 }
968 crate::SwitchValue::U32(value) => {
969 if new_case {
970 write!(self.out, "{l2}case ")?;
971 }
972 write!(self.out, "{value}u")?;
973 }
974 crate::SwitchValue::Default => {
975 if new_case {
976 if case.fall_through {
977 write!(self.out, "{l2}case ")?;
978 } else {
979 write!(self.out, "{l2}")?;
980 }
981 }
982 write!(self.out, "default")?;
983 }
984 }
985
986 new_case = !case.fall_through;
987
988 if case.fall_through {
989 write!(self.out, ", ")?;
990 } else {
991 writeln!(self.out, ": {{")?;
992 }
993
994 for sta in case.body.iter() {
995 self.write_stmt(module, sta, func_ctx, l2.next())?;
996 }
997
998 if !case.fall_through {
999 writeln!(self.out, "{l2}}}")?;
1000 }
1001 }
1002
1003 writeln!(self.out, "{level}}}")?
1004 }
1005 Statement::Loop {
1006 ref body,
1007 ref continuing,
1008 break_if,
1009 } => {
1010 write!(self.out, "{level}")?;
1011 writeln!(self.out, "loop {{")?;
1012
1013 let l2 = level.next();
1014 for sta in body.iter() {
1015 self.write_stmt(module, sta, func_ctx, l2)?;
1016 }
1017
1018 if !continuing.is_empty() || break_if.is_some() {
1023 writeln!(self.out, "{l2}continuing {{")?;
1024 for sta in continuing.iter() {
1025 self.write_stmt(module, sta, func_ctx, l2.next())?;
1026 }
1027
1028 if let Some(condition) = break_if {
1031 write!(self.out, "{}break if ", l2.next())?;
1033 self.write_expr(module, condition, func_ctx)?;
1034 writeln!(self.out, ";")?;
1036 }
1037
1038 writeln!(self.out, "{l2}}}")?;
1039 }
1040
1041 writeln!(self.out, "{level}}}")?
1042 }
1043 Statement::Break => {
1044 writeln!(self.out, "{level}break;")?;
1045 }
1046 Statement::Continue => {
1047 writeln!(self.out, "{level}continue;")?;
1048 }
1049 Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
1050 if barrier.contains(crate::Barrier::STORAGE) {
1051 writeln!(self.out, "{level}storageBarrier();")?;
1052 }
1053
1054 if barrier.contains(crate::Barrier::WORK_GROUP) {
1055 writeln!(self.out, "{level}workgroupBarrier();")?;
1056 }
1057
1058 if barrier.contains(crate::Barrier::SUB_GROUP) {
1059 writeln!(self.out, "{level}subgroupBarrier();")?;
1060 }
1061
1062 if barrier.contains(crate::Barrier::TEXTURE) {
1063 writeln!(self.out, "{level}textureBarrier();")?;
1064 }
1065 }
1066 Statement::RayQuery { .. } => unreachable!(),
1067 Statement::SubgroupBallot { result, predicate } => {
1068 write!(self.out, "{level}")?;
1069 let res_name = Baked(result).to_string();
1070 self.start_named_expr(module, result, func_ctx, &res_name)?;
1071 self.named_expressions.insert(result, res_name);
1072
1073 write!(self.out, "subgroupBallot(")?;
1074 if let Some(predicate) = predicate {
1075 self.write_expr(module, predicate, func_ctx)?;
1076 }
1077 writeln!(self.out, ");")?;
1078 }
1079 Statement::SubgroupCollectiveOperation {
1080 op,
1081 collective_op,
1082 argument,
1083 result,
1084 } => {
1085 write!(self.out, "{level}")?;
1086 let res_name = Baked(result).to_string();
1087 self.start_named_expr(module, result, func_ctx, &res_name)?;
1088 self.named_expressions.insert(result, res_name);
1089
1090 match (collective_op, op) {
1091 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
1092 write!(self.out, "subgroupAll(")?
1093 }
1094 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
1095 write!(self.out, "subgroupAny(")?
1096 }
1097 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
1098 write!(self.out, "subgroupAdd(")?
1099 }
1100 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
1101 write!(self.out, "subgroupMul(")?
1102 }
1103 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
1104 write!(self.out, "subgroupMax(")?
1105 }
1106 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
1107 write!(self.out, "subgroupMin(")?
1108 }
1109 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
1110 write!(self.out, "subgroupAnd(")?
1111 }
1112 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
1113 write!(self.out, "subgroupOr(")?
1114 }
1115 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1116 write!(self.out, "subgroupXor(")?
1117 }
1118 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1119 write!(self.out, "subgroupExclusiveAdd(")?
1120 }
1121 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1122 write!(self.out, "subgroupExclusiveMul(")?
1123 }
1124 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1125 write!(self.out, "subgroupInclusiveAdd(")?
1126 }
1127 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1128 write!(self.out, "subgroupInclusiveMul(")?
1129 }
1130 _ => unimplemented!(),
1131 }
1132 self.write_expr(module, argument, func_ctx)?;
1133 writeln!(self.out, ");")?;
1134 }
1135 Statement::SubgroupGather {
1136 mode,
1137 argument,
1138 result,
1139 } => {
1140 write!(self.out, "{level}")?;
1141 let res_name = Baked(result).to_string();
1142 self.start_named_expr(module, result, func_ctx, &res_name)?;
1143 self.named_expressions.insert(result, res_name);
1144
1145 match mode {
1146 crate::GatherMode::BroadcastFirst => {
1147 write!(self.out, "subgroupBroadcastFirst(")?;
1148 }
1149 crate::GatherMode::Broadcast(_) => {
1150 write!(self.out, "subgroupBroadcast(")?;
1151 }
1152 crate::GatherMode::Shuffle(_) => {
1153 write!(self.out, "subgroupShuffle(")?;
1154 }
1155 crate::GatherMode::ShuffleDown(_) => {
1156 write!(self.out, "subgroupShuffleDown(")?;
1157 }
1158 crate::GatherMode::ShuffleUp(_) => {
1159 write!(self.out, "subgroupShuffleUp(")?;
1160 }
1161 crate::GatherMode::ShuffleXor(_) => {
1162 write!(self.out, "subgroupShuffleXor(")?;
1163 }
1164 crate::GatherMode::QuadBroadcast(_) => {
1165 write!(self.out, "quadBroadcast(")?;
1166 }
1167 crate::GatherMode::QuadSwap(direction) => match direction {
1168 crate::Direction::X => {
1169 write!(self.out, "quadSwapX(")?;
1170 }
1171 crate::Direction::Y => {
1172 write!(self.out, "quadSwapY(")?;
1173 }
1174 crate::Direction::Diagonal => {
1175 write!(self.out, "quadSwapDiagonal(")?;
1176 }
1177 },
1178 }
1179 self.write_expr(module, argument, func_ctx)?;
1180 match mode {
1181 crate::GatherMode::BroadcastFirst => {}
1182 crate::GatherMode::Broadcast(index)
1183 | crate::GatherMode::Shuffle(index)
1184 | crate::GatherMode::ShuffleDown(index)
1185 | crate::GatherMode::ShuffleUp(index)
1186 | crate::GatherMode::ShuffleXor(index)
1187 | crate::GatherMode::QuadBroadcast(index) => {
1188 write!(self.out, ", ")?;
1189 self.write_expr(module, index, func_ctx)?;
1190 }
1191 crate::GatherMode::QuadSwap(_) => {}
1192 }
1193 writeln!(self.out, ");")?;
1194 }
1195 Statement::CooperativeStore { target, ref data } => {
1196 let suffix = if data.row_major { "T" } else { "" };
1197 write!(self.out, "{level}coopStore{suffix}(")?;
1198 self.write_expr(module, target, func_ctx)?;
1199 write!(self.out, ", ")?;
1200 self.write_expr(module, data.pointer, func_ctx)?;
1201 write!(self.out, ", ")?;
1202 self.write_expr(module, data.stride, func_ctx)?;
1203 writeln!(self.out, ");")?
1204 }
1205 Statement::RayPipelineFunction(fun) => match fun {
1206 crate::RayPipelineFunction::TraceRay {
1207 acceleration_structure,
1208 descriptor,
1209 payload,
1210 } => {
1211 write!(self.out, "{level}traceRay(")?;
1212 self.write_expr(module, acceleration_structure, func_ctx)?;
1213 write!(self.out, ", ")?;
1214 self.write_expr(module, descriptor, func_ctx)?;
1215 write!(self.out, ", ")?;
1216 self.write_expr(module, payload, func_ctx)?;
1217 writeln!(self.out, ");")?
1218 }
1219 },
1220 }
1221
1222 Ok(())
1223 }
1224
1225 fn plain_form_indirection(
1248 &self,
1249 expr: Handle<crate::Expression>,
1250 module: &Module,
1251 func_ctx: &back::FunctionCtx<'_>,
1252 ) -> Indirection {
1253 use crate::Expression as Ex;
1254
1255 if self.named_expressions.contains_key(&expr) {
1259 return Indirection::Ordinary;
1260 }
1261
1262 match func_ctx.expressions[expr] {
1263 Ex::LocalVariable(_) => Indirection::Reference,
1264 Ex::GlobalVariable(handle) => {
1265 let global = &module.global_variables[handle];
1266 match global.space {
1267 crate::AddressSpace::Handle => Indirection::Ordinary,
1268 _ => Indirection::Reference,
1269 }
1270 }
1271 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1272 let base_ty = func_ctx.resolve_type(base, &module.types);
1273 match *base_ty {
1274 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1275 Indirection::Reference
1276 }
1277 _ => Indirection::Ordinary,
1278 }
1279 }
1280 _ => Indirection::Ordinary,
1281 }
1282 }
1283
1284 fn start_named_expr(
1285 &mut self,
1286 module: &Module,
1287 handle: Handle<crate::Expression>,
1288 func_ctx: &back::FunctionCtx,
1289 name: &str,
1290 ) -> BackendResult {
1291 write!(self.out, "let {name}")?;
1293 if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1294 write!(self.out, ": ")?;
1295 self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1297 }
1298
1299 write!(self.out, " = ")?;
1300 Ok(())
1301 }
1302
1303 fn write_expr(
1307 &mut self,
1308 module: &Module,
1309 expr: Handle<crate::Expression>,
1310 func_ctx: &back::FunctionCtx<'_>,
1311 ) -> BackendResult {
1312 self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1313 }
1314
1315 fn write_expr_with_indirection(
1328 &mut self,
1329 module: &Module,
1330 expr: Handle<crate::Expression>,
1331 func_ctx: &back::FunctionCtx<'_>,
1332 requested: Indirection,
1333 ) -> BackendResult {
1334 let plain = self.plain_form_indirection(expr, module, func_ctx);
1337 log::trace!(
1338 "expression {:?}={:?} is {:?}, expected {:?}",
1339 expr,
1340 func_ctx.expressions[expr],
1341 plain,
1342 requested,
1343 );
1344 match (requested, plain) {
1345 (Indirection::Ordinary, Indirection::Reference) => {
1346 write!(self.out, "(&")?;
1347 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1348 write!(self.out, ")")?;
1349 }
1350 (Indirection::Reference, Indirection::Ordinary) => {
1351 write!(self.out, "(*")?;
1352 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1353 write!(self.out, ")")?;
1354 }
1355 (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1356 }
1357
1358 Ok(())
1359 }
1360
1361 fn write_const_expression(
1362 &mut self,
1363 module: &Module,
1364 expr: Handle<crate::Expression>,
1365 arena: &crate::Arena<crate::Expression>,
1366 ) -> BackendResult {
1367 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1368 writer.write_const_expression(module, expr, arena)
1369 })
1370 }
1371
1372 fn write_possibly_const_expression<E>(
1373 &mut self,
1374 module: &Module,
1375 expr: Handle<crate::Expression>,
1376 expressions: &crate::Arena<crate::Expression>,
1377 write_expression: E,
1378 ) -> BackendResult
1379 where
1380 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1381 {
1382 use crate::Expression;
1383
1384 match expressions[expr] {
1385 Expression::Literal(literal) => match literal {
1386 crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1387 crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1388 crate::Literal::U16(value) => write!(self.out, "u16({value})")?,
1389 crate::Literal::I16(value) => write!(self.out, "i16({value})")?,
1390 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1391 crate::Literal::I32(value) => {
1392 if value == i32::MIN {
1396 write!(self.out, "i32({value})")?;
1397 } else {
1398 write!(self.out, "{value}i")?;
1399 }
1400 }
1401 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1402 crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1403 crate::Literal::I64(value) => {
1404 if value == i64::MIN {
1409 write!(self.out, "i64({} - 1)", value + 1)?;
1410 } else {
1411 write!(self.out, "{value}li")?;
1412 }
1413 }
1414 crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1415 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1416 return Err(Error::Custom(
1417 "Abstract types should not appear in IR presented to backends".into(),
1418 ));
1419 }
1420 },
1421 Expression::Constant(handle) => {
1422 let constant = &module.constants[handle];
1423 if constant.name.is_some() {
1424 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1425 } else {
1426 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1427 }
1428 }
1429 Expression::ZeroValue(ty) => {
1430 self.write_type(module, ty)?;
1431 write!(self.out, "()")?;
1432 }
1433 Expression::Compose { ty, ref components } => {
1434 self.write_type(module, ty)?;
1435 write!(self.out, "(")?;
1436 for (index, component) in components.iter().enumerate() {
1437 if index != 0 {
1438 write!(self.out, ", ")?;
1439 }
1440 write_expression(self, *component)?;
1441 }
1442 write!(self.out, ")")?
1443 }
1444 Expression::Splat { size, value } => {
1445 let size = common::vector_size_str(size);
1446 write!(self.out, "vec{size}(")?;
1447 write_expression(self, value)?;
1448 write!(self.out, ")")?;
1449 }
1450 Expression::Override(handle) => {
1451 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1452 }
1453 _ => unreachable!(),
1454 }
1455
1456 Ok(())
1457 }
1458
    /// Write the "plain form" of `expr`: its natural WGSL spelling, without
    /// any `&`/`*` adjustment of its own indirection.
    ///
    /// `indirection` is the plain-form indirection of `expr` itself, as passed
    /// in by `write_expr_with_indirection`. It is forwarded to the base of
    /// `Access`/`AccessIndex` expressions so that the base is written in the
    /// same reference-or-value form. Expressions that already have a `let`
    /// name are written by that name.
    ///
    /// # Errors
    /// Returns an error for shapes this backend cannot express. These include:
    /// - indexing an unsupported type;
    /// - an `As` conversion on an unsupported type;
    /// - unsupported math or relational functions.
    ///
    /// It also returns an error when writing to the output fails.
    fn write_expr_plain_form(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
        indirection: Indirection,
    ) -> BackendResult {
        use crate::Expression;

        // Expressions that were bound with `let` are referred to by name.
        if let Some(name) = self.named_expressions.get(&expr) {
            write!(self.out, "{name}")?;
            return Ok(());
        }

        let expression = &func_ctx.expressions[expr];

        match *expression {
            // These kinds share their formatting with the constant-expression
            // path. Sub-expressions are written as ordinary runtime
            // expressions of this function.
            Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)
            | Expression::Compose { .. }
            | Expression::Splat { .. } => {
                self.write_possibly_const_expression(
                    module,
                    expr,
                    func_ctx.expressions,
                    |writer, expr| writer.write_expr(module, expr, func_ctx),
                )?;
            }
            Expression::Override(handle) => {
                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
            }
            Expression::FunctionArgument(pos) => {
                let name_key = func_ctx.argument_key(pos);
                let name = &self.names[&name_key];
                write!(self.out, "{name}")?;
            }
            Expression::Binary { op, left, right } => {
                // Always parenthesized, so surrounding precedence never matters.
                write!(self.out, "(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, " {} ", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Access { base, index } => {
                // The base is written with the same indirection as this access.
                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
                write!(self.out, "[")?;
                self.write_expr(module, index, func_ctx)?;
                write!(self.out, "]")?
            }
            Expression::AccessIndex { base, index } => {
                let base_ty_res = &func_ctx.info[base].ty;
                let mut resolved = base_ty_res.inner_with(&module.types);

                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

                // Look through a pointer to its pointee type, remembering the
                // pointee's handle so struct member names can be found below.
                let base_ty_handle = match *resolved {
                    TypeInner::Pointer { base, space: _ } => {
                        resolved = &module.types[base].inner;
                        Some(base)
                    }
                    _ => base_ty_res.handle(),
                };

                match *resolved {
                    TypeInner::Vector { .. } => {
                        // Vector components are written as a one-letter swizzle.
                        write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                    }
                    TypeInner::Matrix { .. }
                    | TypeInner::Array { .. }
                    | TypeInner::BindingArray { .. }
                    | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                    TypeInner::Struct { .. } => {
                        // NOTE(review): assumes a struct-typed base always
                        // resolves to a type handle; the `unwrap` panics otherwise.
                        let ty = base_ty_handle.unwrap();

                        write!(
                            self.out,
                            ".{}",
                            &self.names[&NameKey::StructMember(ty, index)]
                        )?
                    }
                    ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
                }
            }
            Expression::ImageSample {
                image,
                sampler,
                gather: None,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => {
                use crate::SampleLevel as Sl;

                // Builtin name: `textureSample`, then `Compare` for depth
                // comparisons, then a suffix chosen by the level-of-detail mode.
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };
                let suffix_level = match level {
                    Sl::Auto => "",
                    Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                    Sl::Zero | Sl::Exact(_) => "Level",
                    Sl::Bias(_) => "Bias",
                    Sl::Gradient { .. } => "Grad",
                };

                write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                match level {
                    Sl::Auto => {}
                    Sl::Zero => {
                        // Compare and `BaseClampToEdge` sampling imply level 0,
                        // so an explicit `0.0` is only written for plain
                        // `textureSampleLevel`.
                        if depth_ref.is_none() && !clamp_to_edge {
                            write!(self.out, ", 0.0")?;
                        }
                    }
                    Sl::Exact(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Bias(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Gradient { x, y } => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, x, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, y, func_ctx)?;
                    }
                }

                if let Some(offset) = offset {
                    // The offset is written as a constant expression.
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }

            Expression::ImageSample {
                image,
                sampler,
                gather: Some(component),
                coordinate,
                array_index,
                offset,
                level: _,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };

                write!(self.out, "textureGather{suffix_cmp}(")?;
                // Depth images get no component argument; all other images
                // get the component index as the first argument.
                match *func_ctx.resolve_type(image, &module.types) {
                    TypeInner::Image {
                        class: crate::ImageClass::Depth { multi: _ },
                        ..
                    } => {}
                    _ => {
                        write!(self.out, "{}, ", component as u8)?;
                    }
                }
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }
            Expression::ImageQuery { image, query } => {
                use crate::ImageQuery as Iq;

                let texture_function = match query {
                    Iq::Size { .. } => "textureDimensions",
                    Iq::NumLevels => "textureNumLevels",
                    Iq::NumLayers => "textureNumLayers",
                    Iq::NumSamples => "textureNumSamples",
                };

                write!(self.out, "{texture_function}(")?;
                self.write_expr(module, image, func_ctx)?;
                // Only a size query with an explicit mip level takes a second argument.
                if let Iq::Size { level: Some(level) } = query {
                    write!(self.out, ", ")?;
                    self.write_expr(module, level, func_ctx)?;
                };
                write!(self.out, ")")?;
            }

            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                write!(self.out, "textureLoad(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }
                // At most one trailing argument is written; `sample` takes
                // precedence over `level`.
                if let Some(index) = sample.or(level) {
                    write!(self.out, ", ")?;
                    self.write_expr(module, index, func_ctx)?;
                }
                write!(self.out, ")")?;
            }
            Expression::GlobalVariable(handle) => {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                write!(self.out, "{name}")?;
            }

            Expression::As {
                expr,
                kind,
                convert,
            } => {
                // Emit the target type as a constructor or `bitcast<...>`
                // prefix; the operand is wrapped in parentheses afterwards.
                let inner = func_ctx.resolve_type(expr, &module.types);
                match *inner {
                    TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        // Matrices are always converted with a constructor.
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(scalar.width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        write!(
                            self.out,
                            "mat{}x{}<{}>",
                            common::vector_size_str(columns),
                            common::vector_size_str(rows),
                            scalar_kind_str
                        )?;
                    }
                    TypeInner::Vector {
                        size,
                        scalar: crate::Scalar { width, .. },
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let vector_size_str = common::vector_size_str(size);
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        // `convert` set: value conversion; otherwise reinterpret bits.
                        if convert.is_some() {
                            write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                        } else {
                            write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                        }
                    }
                    TypeInner::Scalar(crate::Scalar { width, .. }) => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "{scalar_kind_str}")?
                        } else {
                            write!(self.out, "bitcast<{scalar_kind_str}>")?
                        }
                    }
                    _ => {
                        return Err(Error::Unimplemented(format!(
                            "write_expr expression::as {inner:?}"
                        )));
                    }
                };
                write!(self.out, "(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Load { pointer } => {
                let is_atomic_pointer = func_ctx
                    .resolve_type(pointer, &module.types)
                    .is_atomic_pointer(&module.types);

                if is_atomic_pointer {
                    // Atomics are read through `atomicLoad`.
                    write!(self.out, "atomicLoad(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    // Any other load writes the pointer in reference form.
                    // A pointer whose plain form is a value gets `(*...)`.
                    self.write_expr_with_indirection(
                        module,
                        pointer,
                        func_ctx,
                        Indirection::Reference,
                    )?;
                }
            }
            Expression::LocalVariable(handle) => {
                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
            }
            Expression::ArrayLength(expr) => {
                write!(self.out, "arrayLength(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }

            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                // Either a builtin with a direct WGSL name, or a polyfill
                // helper that must be recorded as required.
                enum Function {
                    Regular(&'static str),
                    InversePolyfill(InversePolyfill),
                }

                let function = match fun.try_to_wgsl() {
                    Some(name) => Function::Regular(name),
                    None => match fun {
                        Mf::Inverse => {
                            let ty = func_ctx.resolve_type(arg, &module.types);
                            let Some(overload) = InversePolyfill::find_overload(ty) else {
                                return Err(Error::unsupported("math function", fun));
                            };

                            Function::InversePolyfill(overload)
                        }
                        _ => return Err(Error::unsupported("math function", fun)),
                    },
                };

                match function {
                    Function::Regular(fun_name) => {
                        write!(self.out, "{fun_name}(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        // Only the optional arguments that are present are written.
                        for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                            write!(self.out, ", ")?;
                            self.write_expr(module, arg, func_ctx)?;
                        }
                        write!(self.out, ")")?
                    }
                    Function::InversePolyfill(inverse) => {
                        write!(self.out, "{}(", inverse.fun_name)?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, ")")?;
                        // Recorded in `required_polyfills`; presumably its
                        // definition is emitted elsewhere — confirm.
                        self.required_polyfills.insert(inverse);
                    }
                }
            }

            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                self.write_expr(module, vector, func_ctx)?;
                write!(self.out, ".")?;
                // Only the first `size` entries of the pattern are used.
                for &sc in pattern[..size as usize].iter() {
                    self.out.write_char(back::COMPONENTS[sc as usize])?;
                }
            }
            Expression::Unary { op, expr } => {
                let unary = match op {
                    crate::UnaryOperator::Negate => "-",
                    crate::UnaryOperator::LogicalNot => "!",
                    crate::UnaryOperator::BitwiseNot => "~",
                };

                write!(self.out, "{unary}(")?;
                self.write_expr(module, expr, func_ctx)?;

                write!(self.out, ")")?
            }

            Expression::Select {
                condition,
                accept,
                reject,
            } => {
                // WGSL's `select(f, t, cond)` takes the false value first.
                write!(self.out, "select(")?;
                self.write_expr(module, reject, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, accept, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, condition, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Derivative { axis, ctrl, expr } => {
                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
                let op = match (axis, ctrl) {
                    (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                    (Axis::X, Ctrl::Fine) => "dpdxFine",
                    (Axis::X, Ctrl::None) => "dpdx",
                    (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                    (Axis::Y, Ctrl::Fine) => "dpdyFine",
                    (Axis::Y, Ctrl::None) => "dpdy",
                    (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                    (Axis::Width, Ctrl::Fine) => "fwidthFine",
                    (Axis::Width, Ctrl::None) => "fwidth",
                };
                write!(self.out, "{op}(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Relational { fun, argument } => {
                use crate::RelationalFunction as Rf;

                // Only `all` and `any` are handled; other functions are rejected.
                let fun_name = match fun {
                    Rf::All => "all",
                    Rf::Any => "any",
                    _ => return Err(Error::UnsupportedRelationalFunction(fun)),
                };
                write!(self.out, "{fun_name}(")?;

                self.write_expr(module, argument, func_ctx)?;

                write!(self.out, ")")?
            }
            // This writer treats these expression kinds as unreachable.
            Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => unreachable!(),
            // Nothing is written for result expressions. They are presumably
            // bound as named expressions by the statement that produces them,
            // which the `named_expressions` check above would catch — confirm.
            Expression::CallResult(_)
            | Expression::AtomicResult { .. }
            | Expression::RayQueryProceedResult
            | Expression::SubgroupBallotResult
            | Expression::SubgroupOperationResult { .. }
            | Expression::WorkGroupUniformLoadResult { .. } => {}
            Expression::CooperativeLoad {
                columns,
                rows,
                role,
                ref data,
            } => {
                // `T` suffix selects the row-major variant.
                let suffix = if data.row_major { "T" } else { "" };
                // Element scalar type is taken from the pointee of `data.pointer`.
                // NOTE(review): each `unwrap` panics if the pointer type or
                // scalar is missing; this relies on a validated module.
                let scalar = func_ctx.info[data.pointer]
                    .ty
                    .inner_with(&module.types)
                    .pointer_base_type()
                    .unwrap()
                    .inner_with(&module.types)
                    .scalar()
                    .unwrap();
                // NOTE(review): the role is written via its `Debug` spelling;
                // confirm it matches the expected WGSL enumerant names.
                write!(
                    self.out,
                    "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                    columns as u32,
                    rows as u32,
                    scalar.try_to_wgsl().unwrap(),
                    role,
                )?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::CooperativeMultiplyAdd { a, b, c } => {
                write!(self.out, "coopMultiplyAdd(")?;
                self.write_expr(module, a, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, b, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, c, func_ctx)?;
                write!(self.out, ")")?;
            }
        }

        Ok(())
    }
1986
1987 fn write_global(
1991 &mut self,
1992 module: &Module,
1993 global: &crate::GlobalVariable,
1994 handle: Handle<crate::GlobalVariable>,
1995 ) -> BackendResult {
1996 if let Some(ref binding) = global.binding {
1998 self.write_attributes(&[
1999 Attribute::Group(binding.group),
2000 Attribute::Binding(binding.binding),
2001 ])?;
2002 writeln!(self.out)?;
2003 }
2004
2005 if global
2006 .memory_decorations
2007 .contains(crate::MemoryDecorations::COHERENT)
2008 {
2009 write!(self.out, "@coherent ")?;
2010 }
2011 if global
2012 .memory_decorations
2013 .contains(crate::MemoryDecorations::VOLATILE)
2014 {
2015 write!(self.out, "@volatile ")?;
2016 }
2017
2018 write!(self.out, "var")?;
2020 let (address, maybe_access) = address_space_str(global.space);
2021 if let Some(space) = address {
2022 write!(self.out, "<{space}")?;
2023 if let Some(access) = maybe_access {
2024 write!(self.out, ", {access}")?;
2025 }
2026 write!(self.out, ">")?;
2027 }
2028 write!(
2029 self.out,
2030 " {}: ",
2031 &self.names[&NameKey::GlobalVariable(handle)]
2032 )?;
2033
2034 self.write_type(module, global.ty)?;
2036
2037 if let Some(init) = global.init {
2039 write!(self.out, " = ")?;
2040 self.write_const_expression(module, init, &module.global_expressions)?;
2041 }
2042
2043 writeln!(self.out, ";")?;
2045
2046 Ok(())
2047 }
2048
2049 fn write_global_constant(
2054 &mut self,
2055 module: &Module,
2056 handle: Handle<crate::Constant>,
2057 ) -> BackendResult {
2058 let name = &self.names[&NameKey::Constant(handle)];
2059 write!(self.out, "const {name}: ")?;
2061 self.write_type(module, module.constants[handle].ty)?;
2062 write!(self.out, " = ")?;
2063 let init = module.constants[handle].init;
2064 self.write_const_expression(module, init, &module.global_expressions)?;
2065 writeln!(self.out, ";")?;
2066
2067 Ok(())
2068 }
2069
2070 fn write_override(
2075 &mut self,
2076 module: &Module,
2077 handle: Handle<crate::Override>,
2078 ) -> BackendResult {
2079 let override_ = &module.overrides[handle];
2080 let name = &self.names[&NameKey::Override(handle)];
2081
2082 if let Some(id) = override_.id {
2084 write!(self.out, "@id({id}) ")?;
2085 }
2086
2087 write!(self.out, "override {name}: ")?;
2089 self.write_type(module, override_.ty)?;
2090
2091 if let Some(init) = override_.init {
2093 write!(self.out, " = ")?;
2094 self.write_const_expression(module, init, &module.global_expressions)?;
2095 }
2096
2097 writeln!(self.out, ";")?;
2098
2099 Ok(())
2100 }
2101
2102 pub fn finish(self) -> W {
2104 self.out
2105 }
2106}
2107
2108struct WriterTypeContext<'m> {
2109 module: &'m Module,
2110 names: &'m crate::FastHashMap<NameKey, String>,
2111}
2112
2113impl TypeContext for WriterTypeContext<'_> {
2114 fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
2115 &self.module.types[handle]
2116 }
2117
2118 fn type_name(&self, handle: Handle<crate::Type>) -> &str {
2119 self.names[&NameKey::Type(handle)].as_str()
2120 }
2121
2122 fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2123 unreachable!("the WGSL back end should always provide type handles");
2124 }
2125
2126 fn write_override<W: Write>(
2127 &self,
2128 handle: Handle<crate::Override>,
2129 out: &mut W,
2130 ) -> core::fmt::Result {
2131 write!(out, "{}", self.names[&NameKey::Override(handle)])
2132 }
2133
2134 fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2135 unreachable!("backends should only be passed validated modules");
2136 }
2137
2138 fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
2139 unreachable!("backends should only be passed validated modules");
2140 }
2141}
2142
2143fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2144 match *binding {
2145 crate::Binding::BuiltIn(built_in) => {
2146 if let crate::BuiltIn::Position { invariant: true } = built_in {
2147 vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2148 } else {
2149 vec![Attribute::BuiltIn(built_in)]
2150 }
2151 }
2152 crate::Binding::Location {
2153 location,
2154 interpolation,
2155 sampling,
2156 blend_src: None,
2157 per_primitive,
2158 } => {
2159 let mut attrs = vec![
2160 Attribute::Location(location),
2161 Attribute::Interpolate(interpolation, sampling),
2162 ];
2163 if per_primitive {
2164 attrs.push(Attribute::PerPrimitive);
2165 }
2166 attrs
2167 }
2168 crate::Binding::Location {
2169 location,
2170 interpolation,
2171 sampling,
2172 blend_src: Some(blend_src),
2173 per_primitive,
2174 } => {
2175 let mut attrs = vec![
2176 Attribute::Location(location),
2177 Attribute::BlendSrc(blend_src),
2178 Attribute::Interpolate(interpolation, sampling),
2179 ];
2180 if per_primitive {
2181 attrs.push(Attribute::PerPrimitive);
2182 }
2183 attrs
2184 }
2185 }
2186}