1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13 back::{self, Baked},
14 common::{
15 self,
16 wgsl::{address_space_str, ToWgsl, TryToWgsl},
17 },
18 proc::{self, NameKey},
19 valid, Handle, Module, ShaderStage, TypeInner,
20};
21
22type BackendResult = Result<(), Error>;
24
25enum Attribute {
27 Binding(u32),
28 BuiltIn(crate::BuiltIn),
29 Group(u32),
30 Invariant,
31 Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
32 Location(u32),
33 BlendSrc(u32),
34 Stage(ShaderStage),
35 WorkGroupSize([u32; 3]),
36 MeshStage(String),
37 TaskPayload(String),
38 PerPrimitive,
39 IncomingRayPayload(String),
40}
41
/// Describes whether the plain WGSL spelling of an expression is a value or
/// a reference.
///
/// Naga IR treats variables (and access chains rooted in them) as pointers,
/// while in WGSL a variable's name denotes a reference to its storage. The
/// writer compares the form an expression naturally has with the form a
/// context requires, and inserts `&` or `*` to bridge the difference.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// The plain form is a WGSL reference, e.g. the name of a `var` or an
    /// access through a pointer. Prefixing it with `&` yields the IR pointer.
    Reference,

    /// The plain form is an ordinary value, e.g. a `let`-bound name or a
    /// literal. Prefixing it with `*` dereferences it when it is a pointer.
    Ordinary,
}
70
71bitflags::bitflags! {
72 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
73 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
74 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
75 pub struct WriterFlags: u32 {
76 const EXPLICIT_TYPES = 0x1;
78 }
79}
80
81pub struct Writer<W> {
82 out: W,
83 flags: WriterFlags,
84 names: crate::FastHashMap<NameKey, String>,
85 namer: proc::Namer,
86 named_expressions: crate::NamedExpressions,
87 required_polyfills: crate::FastIndexSet<InversePolyfill>,
88}
89
90impl<W: Write> Writer<W> {
91 pub fn new(out: W, flags: WriterFlags) -> Self {
92 Writer {
93 out,
94 flags,
95 names: crate::FastHashMap::default(),
96 namer: proc::Namer::default(),
97 named_expressions: crate::NamedExpressions::default(),
98 required_polyfills: crate::FastIndexSet::default(),
99 }
100 }
101
102 fn reset(&mut self, module: &Module) {
103 self.names.clear();
104 self.namer.reset(
105 module,
106 &crate::keywords::wgsl::RESERVED_SET,
107 &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET,
108 proc::CaseInsensitiveKeywordSet::empty(),
110 &["__", "_naga"],
111 &mut self.names,
112 );
113 self.named_expressions.clear();
114 self.required_polyfills.clear();
115 }
116
117 fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
133 module
134 .special_types
135 .predeclared_types
136 .values()
137 .any(|t| *t == ty)
138 || Some(ty) == module.special_types.external_texture_params
139 || Some(ty) == module.special_types.external_texture_transfer_function
140 }
141
142 pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
143 self.reset(module);
144
145 self.write_enable_declarations(module)?;
147
148 for (handle, ty) in module.types.iter() {
150 if let TypeInner::Struct { ref members, .. } = ty.inner {
151 {
152 if !self.is_builtin_wgsl_struct(module, handle) {
153 self.write_struct(module, handle, members)?;
154 writeln!(self.out)?;
155 }
156 }
157 }
158 }
159
160 let mut constants = module
162 .constants
163 .iter()
164 .filter(|&(_, c)| c.name.is_some())
165 .peekable();
166 while let Some((handle, _)) = constants.next() {
167 self.write_global_constant(module, handle)?;
168 if constants.peek().is_none() {
170 writeln!(self.out)?;
171 }
172 }
173
174 let mut overrides = module.overrides.iter().peekable();
176 while let Some((handle, _)) = overrides.next() {
177 self.write_override(module, handle)?;
178 if overrides.peek().is_none() {
180 writeln!(self.out)?;
181 }
182 }
183
184 for (ty, global) in module.global_variables.iter() {
186 self.write_global(module, global, ty)?;
187 }
188
189 if !module.global_variables.is_empty() {
190 writeln!(self.out)?;
192 }
193
194 for (handle, function) in module.functions.iter() {
196 let fun_info = &info[handle];
197
198 let func_ctx = back::FunctionCtx {
199 ty: back::FunctionType::Function(handle),
200 info: fun_info,
201 expressions: &function.expressions,
202 named_expressions: &function.named_expressions,
203 };
204
205 self.write_function(module, function, &func_ctx)?;
207
208 writeln!(self.out)?;
209 }
210
211 for (index, ep) in module.entry_points.iter().enumerate() {
213 let attributes = match ep.stage {
214 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
215 ShaderStage::Compute => vec![
216 Attribute::Stage(ShaderStage::Compute),
217 Attribute::WorkGroupSize(ep.workgroup_size),
218 ],
219 ShaderStage::Mesh => {
220 let mesh_output_name = module.global_variables
221 [ep.mesh_info.as_ref().unwrap().output_variable]
222 .name
223 .clone()
224 .unwrap();
225 let mut mesh_attrs = vec![
226 Attribute::MeshStage(mesh_output_name),
227 Attribute::WorkGroupSize(ep.workgroup_size),
228 ];
229 if let Some(task_payload) = ep.task_payload {
230 let payload_name =
231 module.global_variables[task_payload].name.clone().unwrap();
232 mesh_attrs.push(Attribute::TaskPayload(payload_name));
233 }
234 mesh_attrs
235 }
236 ShaderStage::Task => {
237 let payload_name = module.global_variables[ep.task_payload.unwrap()]
238 .name
239 .clone()
240 .unwrap();
241 vec![
242 Attribute::Stage(ShaderStage::Task),
243 Attribute::TaskPayload(payload_name),
244 Attribute::WorkGroupSize(ep.workgroup_size),
245 ]
246 }
247 ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
248 ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
249 let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
250 .name
251 .clone()
252 .unwrap();
253 vec![
254 Attribute::Stage(ep.stage),
255 Attribute::IncomingRayPayload(payload_name),
256 ]
257 }
258 };
259 self.write_attributes(&attributes)?;
260 writeln!(self.out)?;
262
263 let func_ctx = back::FunctionCtx {
264 ty: back::FunctionType::EntryPoint(index as u16),
265 info: info.get_entry_point(index),
266 expressions: &ep.function.expressions,
267 named_expressions: &ep.function.named_expressions,
268 };
269 self.write_function(module, &ep.function, &func_ctx)?;
270
271 if index < module.entry_points.len() - 1 {
272 writeln!(self.out)?;
273 }
274 }
275
276 for polyfill in &self.required_polyfills {
278 writeln!(self.out)?;
279 write!(self.out, "{}", polyfill.source)?;
280 writeln!(self.out)?;
281 }
282
283 Ok(())
284 }
285
286 fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
289 #[derive(Default)]
290 struct RequiredEnabled {
291 f16: bool,
292 dual_source_blending: bool,
293 clip_distances: bool,
294 mesh_shaders: bool,
295 primitive_index: bool,
296 cooperative_matrix: bool,
297 draw_index: bool,
298 ray_tracing_pipeline: bool,
299 per_vertex: bool,
300 binding_array: bool,
301 }
302 let mut needed = RequiredEnabled {
303 mesh_shaders: module.uses_mesh_shaders(),
304 ..Default::default()
305 };
306
307 let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding
308 {
309 crate::Binding::Location {
310 blend_src: Some(_), ..
311 } => {
312 needed.dual_source_blending = true;
313 }
314 crate::Binding::BuiltIn(crate::BuiltIn::ClipDistances) => {
315 needed.clip_distances = true;
316 }
317 crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => {
318 needed.primitive_index = true;
319 }
320 crate::Binding::Location {
321 per_primitive: true,
322 ..
323 } => {
324 needed.mesh_shaders = true;
325 }
326 crate::Binding::Location {
327 interpolation: Some(crate::Interpolation::PerVertex),
328 ..
329 } => {
330 needed.per_vertex = true;
331 }
332 crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
333 crate::Binding::BuiltIn(
334 crate::BuiltIn::RayInvocationId
335 | crate::BuiltIn::NumRayInvocations
336 | crate::BuiltIn::InstanceCustomData
337 | crate::BuiltIn::GeometryIndex
338 | crate::BuiltIn::WorldRayOrigin
339 | crate::BuiltIn::WorldRayDirection
340 | crate::BuiltIn::ObjectRayOrigin
341 | crate::BuiltIn::ObjectRayDirection
342 | crate::BuiltIn::RayTmin
343 | crate::BuiltIn::RayTCurrentMax
344 | crate::BuiltIn::ObjectToWorld
345 | crate::BuiltIn::WorldToObject,
346 ) => {
347 needed.ray_tracing_pipeline = true;
348 }
349 _ => {}
350 };
351
352 for (_, ty) in module.types.iter() {
354 match ty.inner {
355 TypeInner::Scalar(scalar)
356 | TypeInner::Vector { scalar, .. }
357 | TypeInner::Matrix { scalar, .. } => {
358 needed.f16 |= scalar == crate::Scalar::F16;
359 }
360 TypeInner::Struct { ref members, .. } => {
361 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
362 check_binding(binding, &mut needed);
363 }
364 }
365 TypeInner::CooperativeMatrix { .. } => {
366 needed.cooperative_matrix = true;
367 }
368 TypeInner::AccelerationStructure { .. } => {
369 needed.ray_tracing_pipeline = true;
370 }
371 TypeInner::BindingArray { .. } => {
372 needed.binding_array = true;
373 }
374 _ => {}
375 }
376 }
377
378 for ep in &module.entry_points {
379 if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) {
380 check_binding(res, &mut needed);
381 }
382 for arg in ep
383 .function
384 .arguments
385 .iter()
386 .filter_map(|a| a.binding.as_ref())
387 {
388 check_binding(arg, &mut needed);
389 }
390 }
391
392 if module.global_variables.iter().any(|gv| {
393 gv.1.space == crate::AddressSpace::IncomingRayPayload
394 || gv.1.space == crate::AddressSpace::RayPayload
395 }) {
396 needed.ray_tracing_pipeline = true;
397 }
398
399 if module.entry_points.iter().any(|ep| {
400 matches!(
401 ep.stage,
402 ShaderStage::RayGeneration
403 | ShaderStage::AnyHit
404 | ShaderStage::ClosestHit
405 | ShaderStage::Miss
406 )
407 }) {
408 needed.ray_tracing_pipeline = true;
409 }
410
411 if module.global_variables.iter().any(|gv| {
412 gv.1.space == crate::AddressSpace::IncomingRayPayload
413 || gv.1.space == crate::AddressSpace::RayPayload
414 }) {
415 needed.ray_tracing_pipeline = true;
416 }
417
418 if module.entry_points.iter().any(|ep| {
419 matches!(
420 ep.stage,
421 ShaderStage::RayGeneration
422 | ShaderStage::AnyHit
423 | ShaderStage::ClosestHit
424 | ShaderStage::Miss
425 )
426 }) {
427 needed.ray_tracing_pipeline = true;
428 }
429
430 let mut any_written = false;
432 if needed.f16 {
433 writeln!(self.out, "enable f16;")?;
434 any_written = true;
435 }
436 if needed.dual_source_blending {
437 writeln!(self.out, "enable dual_source_blending;")?;
438 any_written = true;
439 }
440 if needed.clip_distances {
441 writeln!(self.out, "enable clip_distances;")?;
442 any_written = true;
443 }
444 if module.uses_mesh_shaders() {
445 writeln!(self.out, "enable wgpu_mesh_shader;")?;
446 any_written = true;
447 }
448 if needed.binding_array {
449 writeln!(self.out, "enable wgpu_binding_array;")?;
450 any_written = true;
451 }
452 if needed.draw_index {
453 writeln!(self.out, "enable draw_index;")?;
454 any_written = true;
455 }
456 if needed.primitive_index {
457 writeln!(self.out, "enable primitive_index;")?;
458 any_written = true;
459 }
460 if needed.cooperative_matrix {
461 writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
462 any_written = true;
463 }
464 if needed.ray_tracing_pipeline {
465 writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
466 any_written = true;
467 }
468 if needed.per_vertex {
469 writeln!(self.out, "enable wgpu_per_vertex;")?;
470 any_written = true;
471 }
472 if any_written {
473 writeln!(self.out)?;
475 }
476
477 Ok(())
478 }
479
480 fn write_function(
486 &mut self,
487 module: &Module,
488 func: &crate::Function,
489 func_ctx: &back::FunctionCtx<'_>,
490 ) -> BackendResult {
491 let func_name = match func_ctx.ty {
492 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
493 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
494 };
495
496 write!(self.out, "fn {func_name}(")?;
498
499 for (index, arg) in func.arguments.iter().enumerate() {
501 if let Some(ref binding) = arg.binding {
503 self.write_attributes(&map_binding_to_attribute(binding))?;
504 }
505 let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
507
508 write!(self.out, "{argument_name}: ")?;
509 self.write_type(module, arg.ty)?;
511 if index < func.arguments.len() - 1 {
512 write!(self.out, ", ")?;
514 }
515 }
516
517 write!(self.out, ")")?;
518
519 if let Some(ref result) = func.result {
521 write!(self.out, " -> ")?;
522 if let Some(ref binding) = result.binding {
523 self.write_attributes(&map_binding_to_attribute(binding))?;
524 }
525 self.write_type(module, result.ty)?;
526 }
527
528 write!(self.out, " {{")?;
529 writeln!(self.out)?;
530
531 for (handle, local) in func.local_variables.iter() {
533 write!(self.out, "{}", back::INDENT)?;
535
536 write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
539
540 self.write_type(module, local.ty)?;
542
543 if let Some(init) = local.init {
545 write!(self.out, " = ")?;
548
549 self.write_expr(module, init, func_ctx)?;
552 }
553
554 writeln!(self.out, ";")?
556 }
557
558 if !func.local_variables.is_empty() {
559 writeln!(self.out)?;
560 }
561
562 for sta in func.body.iter() {
564 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
566 }
567
568 writeln!(self.out, "}}")?;
569
570 self.named_expressions.clear();
571
572 Ok(())
573 }
574
575 fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
577 for attribute in attributes {
578 match *attribute {
579 Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
580 Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
581 Attribute::BuiltIn(builtin_attrib) => {
582 let builtin = builtin_attrib.to_wgsl_if_implemented()?;
583 write!(self.out, "@builtin({builtin}) ")?;
584 }
585 Attribute::Stage(shader_stage) => {
586 let stage_str = match shader_stage {
587 ShaderStage::Vertex => "vertex",
588 ShaderStage::Fragment => "fragment",
589 ShaderStage::Compute => "compute",
590 ShaderStage::Task => "task",
591 ShaderStage::Mesh => unreachable!(),
593 ShaderStage::RayGeneration => "ray_generation",
594 ShaderStage::AnyHit => "any_hit",
595 ShaderStage::ClosestHit => "closest_hit",
596 ShaderStage::Miss => "miss",
597 };
598
599 write!(self.out, "@{stage_str} ")?;
600 }
601 Attribute::WorkGroupSize(size) => {
602 write!(
603 self.out,
604 "@workgroup_size({}, {}, {}) ",
605 size[0], size[1], size[2]
606 )?;
607 }
608 Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
609 Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
610 Attribute::Invariant => write!(self.out, "@invariant ")?,
611 Attribute::Interpolate(interpolation, sampling) => {
612 if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
613 let interpolation = interpolation
614 .unwrap_or(crate::Interpolation::Perspective)
615 .to_wgsl();
616 let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
617 write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
618 } else if interpolation.is_some()
619 && interpolation != Some(crate::Interpolation::Perspective)
620 {
621 let interpolation = interpolation
622 .unwrap_or(crate::Interpolation::Perspective)
623 .to_wgsl();
624 write!(self.out, "@interpolate({interpolation}) ")?;
625 }
626 }
627 Attribute::MeshStage(ref name) => {
628 write!(self.out, "@mesh({name}) ")?;
629 }
630 Attribute::TaskPayload(ref payload_name) => {
631 write!(self.out, "@payload({payload_name}) ")?;
632 }
633 Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
634 Attribute::IncomingRayPayload(ref payload_name) => {
635 write!(self.out, "@incoming_payload({payload_name}) ")?;
636 }
637 };
638 }
639 Ok(())
640 }
641
642 fn write_struct(
653 &mut self,
654 module: &Module,
655 handle: Handle<crate::Type>,
656 members: &[crate::StructMember],
657 ) -> BackendResult {
658 write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
659 write!(self.out, " {{")?;
660 writeln!(self.out)?;
661 for (index, member) in members.iter().enumerate() {
662 write!(self.out, "{}", back::INDENT)?;
664 if let Some(ref binding) = member.binding {
665 self.write_attributes(&map_binding_to_attribute(binding))?;
666 }
667 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
669 write!(self.out, "{member_name}: ")?;
670 self.write_type(module, member.ty)?;
671 write!(self.out, ",")?;
672 writeln!(self.out)?;
673 }
674
675 writeln!(self.out, "}}")?;
676
677 Ok(())
678 }
679
680 fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
681 let type_context = WriterTypeContext {
685 module,
686 names: &self.names,
687 };
688 type_context.write_type(ty, &mut self.out)?;
689
690 Ok(())
691 }
692
693 fn write_type_resolution(
694 &mut self,
695 module: &Module,
696 resolution: &proc::TypeResolution,
697 ) -> BackendResult {
698 let type_context = WriterTypeContext {
702 module,
703 names: &self.names,
704 };
705 type_context.write_type_resolution(resolution, &mut self.out)?;
706
707 Ok(())
708 }
709
710 fn write_stmt(
715 &mut self,
716 module: &Module,
717 stmt: &crate::Statement,
718 func_ctx: &back::FunctionCtx<'_>,
719 level: back::Level,
720 ) -> BackendResult {
721 use crate::{Expression, Statement};
722
723 match *stmt {
724 Statement::Emit(ref range) => {
725 for handle in range.clone() {
726 let info = &func_ctx.info[handle];
727 let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
728 Some(self.namer.call(name))
733 } else {
734 let expr = &func_ctx.expressions[handle];
735 let min_ref_count = expr.bake_ref_count();
736 let required_baking_expr = match *expr {
738 Expression::ImageLoad { .. }
739 | Expression::ImageQuery { .. }
740 | Expression::ImageSample { .. } => true,
741 _ => false,
742 };
743 if min_ref_count <= info.ref_count || required_baking_expr {
744 Some(Baked(handle).to_string())
745 } else {
746 None
747 }
748 };
749
750 if let Some(name) = expr_name {
751 write!(self.out, "{level}")?;
752 self.start_named_expr(module, handle, func_ctx, &name)?;
753 self.write_expr(module, handle, func_ctx)?;
754 self.named_expressions.insert(handle, name);
755 writeln!(self.out, ";")?;
756 }
757 }
758 }
759 Statement::If {
761 condition,
762 ref accept,
763 ref reject,
764 } => {
765 write!(self.out, "{level}")?;
766 write!(self.out, "if ")?;
767 self.write_expr(module, condition, func_ctx)?;
768 writeln!(self.out, " {{")?;
769
770 let l2 = level.next();
771 for sta in accept {
772 self.write_stmt(module, sta, func_ctx, l2)?;
774 }
775
776 if !reject.is_empty() {
779 writeln!(self.out, "{level}}} else {{")?;
780
781 for sta in reject {
782 self.write_stmt(module, sta, func_ctx, l2)?;
784 }
785 }
786
787 writeln!(self.out, "{level}}}")?
788 }
789 Statement::Return { value } => {
790 write!(self.out, "{level}")?;
791 write!(self.out, "return")?;
792 if let Some(return_value) = value {
793 write!(self.out, " ")?;
795 self.write_expr(module, return_value, func_ctx)?;
796 }
797 writeln!(self.out, ";")?;
798 }
799 Statement::Kill => {
801 write!(self.out, "{level}")?;
802 writeln!(self.out, "discard;")?
803 }
804 Statement::Store { pointer, value } => {
805 write!(self.out, "{level}")?;
806
807 let is_atomic_pointer = func_ctx
808 .resolve_type(pointer, &module.types)
809 .is_atomic_pointer(&module.types);
810
811 if is_atomic_pointer {
812 write!(self.out, "atomicStore(")?;
813 self.write_expr(module, pointer, func_ctx)?;
814 write!(self.out, ", ")?;
815 self.write_expr(module, value, func_ctx)?;
816 write!(self.out, ")")?;
817 } else {
818 self.write_expr_with_indirection(
819 module,
820 pointer,
821 func_ctx,
822 Indirection::Reference,
823 )?;
824 write!(self.out, " = ")?;
825 self.write_expr(module, value, func_ctx)?;
826 }
827 writeln!(self.out, ";")?
828 }
829 Statement::Call {
830 function,
831 ref arguments,
832 result,
833 } => {
834 write!(self.out, "{level}")?;
835 if let Some(expr) = result {
836 let name = Baked(expr).to_string();
837 self.start_named_expr(module, expr, func_ctx, &name)?;
838 self.named_expressions.insert(expr, name);
839 }
840 let func_name = &self.names[&NameKey::Function(function)];
841 write!(self.out, "{func_name}(")?;
842 for (index, &argument) in arguments.iter().enumerate() {
843 if index != 0 {
844 write!(self.out, ", ")?;
845 }
846 self.write_expr(module, argument, func_ctx)?;
847 }
848 writeln!(self.out, ");")?
849 }
850 Statement::Atomic {
851 pointer,
852 ref fun,
853 value,
854 result,
855 } => {
856 write!(self.out, "{level}")?;
857 if let Some(result) = result {
858 let res_name = Baked(result).to_string();
859 self.start_named_expr(module, result, func_ctx, &res_name)?;
860 self.named_expressions.insert(result, res_name);
861 }
862
863 let fun_str = fun.to_wgsl();
864 write!(self.out, "atomic{fun_str}(")?;
865 self.write_expr(module, pointer, func_ctx)?;
866 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
867 write!(self.out, ", ")?;
868 self.write_expr(module, cmp, func_ctx)?;
869 }
870 write!(self.out, ", ")?;
871 self.write_expr(module, value, func_ctx)?;
872 writeln!(self.out, ");")?
873 }
874 Statement::ImageAtomic {
875 image,
876 coordinate,
877 array_index,
878 ref fun,
879 value,
880 } => {
881 write!(self.out, "{level}")?;
882 let fun_str = fun.to_wgsl();
883 write!(self.out, "textureAtomic{fun_str}(")?;
884 self.write_expr(module, image, func_ctx)?;
885 write!(self.out, ", ")?;
886 self.write_expr(module, coordinate, func_ctx)?;
887 if let Some(array_index_expr) = array_index {
888 write!(self.out, ", ")?;
889 self.write_expr(module, array_index_expr, func_ctx)?;
890 }
891 write!(self.out, ", ")?;
892 self.write_expr(module, value, func_ctx)?;
893 writeln!(self.out, ");")?;
894 }
895 Statement::WorkGroupUniformLoad { pointer, result } => {
896 write!(self.out, "{level}")?;
897 let res_name = Baked(result).to_string();
899 self.start_named_expr(module, result, func_ctx, &res_name)?;
900 self.named_expressions.insert(result, res_name);
901 write!(self.out, "workgroupUniformLoad(")?;
902 self.write_expr(module, pointer, func_ctx)?;
903 writeln!(self.out, ");")?;
904 }
905 Statement::ImageStore {
906 image,
907 coordinate,
908 array_index,
909 value,
910 } => {
911 write!(self.out, "{level}")?;
912 write!(self.out, "textureStore(")?;
913 self.write_expr(module, image, func_ctx)?;
914 write!(self.out, ", ")?;
915 self.write_expr(module, coordinate, func_ctx)?;
916 if let Some(array_index_expr) = array_index {
917 write!(self.out, ", ")?;
918 self.write_expr(module, array_index_expr, func_ctx)?;
919 }
920 write!(self.out, ", ")?;
921 self.write_expr(module, value, func_ctx)?;
922 writeln!(self.out, ");")?;
923 }
924 Statement::Block(ref block) => {
926 write!(self.out, "{level}")?;
927 writeln!(self.out, "{{")?;
928 for sta in block.iter() {
929 self.write_stmt(module, sta, func_ctx, level.next())?
931 }
932 writeln!(self.out, "{level}}}")?
933 }
934 Statement::Switch {
935 selector,
936 ref cases,
937 } => {
938 write!(self.out, "{level}")?;
940 write!(self.out, "switch ")?;
941 self.write_expr(module, selector, func_ctx)?;
942 writeln!(self.out, " {{")?;
943
944 let l2 = level.next();
945 let mut new_case = true;
946 for case in cases {
947 if case.fall_through && !case.body.is_empty() {
948 return Err(Error::Unimplemented(
950 "fall-through switch case block".into(),
951 ));
952 }
953
954 match case.value {
955 crate::SwitchValue::I32(value) => {
956 if new_case {
957 write!(self.out, "{l2}case ")?;
958 }
959 write!(self.out, "{value}")?;
960 }
961 crate::SwitchValue::U32(value) => {
962 if new_case {
963 write!(self.out, "{l2}case ")?;
964 }
965 write!(self.out, "{value}u")?;
966 }
967 crate::SwitchValue::Default => {
968 if new_case {
969 if case.fall_through {
970 write!(self.out, "{l2}case ")?;
971 } else {
972 write!(self.out, "{l2}")?;
973 }
974 }
975 write!(self.out, "default")?;
976 }
977 }
978
979 new_case = !case.fall_through;
980
981 if case.fall_through {
982 write!(self.out, ", ")?;
983 } else {
984 writeln!(self.out, ": {{")?;
985 }
986
987 for sta in case.body.iter() {
988 self.write_stmt(module, sta, func_ctx, l2.next())?;
989 }
990
991 if !case.fall_through {
992 writeln!(self.out, "{l2}}}")?;
993 }
994 }
995
996 writeln!(self.out, "{level}}}")?
997 }
998 Statement::Loop {
999 ref body,
1000 ref continuing,
1001 break_if,
1002 } => {
1003 write!(self.out, "{level}")?;
1004 writeln!(self.out, "loop {{")?;
1005
1006 let l2 = level.next();
1007 for sta in body.iter() {
1008 self.write_stmt(module, sta, func_ctx, l2)?;
1009 }
1010
1011 if !continuing.is_empty() || break_if.is_some() {
1016 writeln!(self.out, "{l2}continuing {{")?;
1017 for sta in continuing.iter() {
1018 self.write_stmt(module, sta, func_ctx, l2.next())?;
1019 }
1020
1021 if let Some(condition) = break_if {
1024 write!(self.out, "{}break if ", l2.next())?;
1026 self.write_expr(module, condition, func_ctx)?;
1027 writeln!(self.out, ";")?;
1029 }
1030
1031 writeln!(self.out, "{l2}}}")?;
1032 }
1033
1034 writeln!(self.out, "{level}}}")?
1035 }
1036 Statement::Break => {
1037 writeln!(self.out, "{level}break;")?;
1038 }
1039 Statement::Continue => {
1040 writeln!(self.out, "{level}continue;")?;
1041 }
1042 Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
1043 if barrier.contains(crate::Barrier::STORAGE) {
1044 writeln!(self.out, "{level}storageBarrier();")?;
1045 }
1046
1047 if barrier.contains(crate::Barrier::WORK_GROUP) {
1048 writeln!(self.out, "{level}workgroupBarrier();")?;
1049 }
1050
1051 if barrier.contains(crate::Barrier::SUB_GROUP) {
1052 writeln!(self.out, "{level}subgroupBarrier();")?;
1053 }
1054
1055 if barrier.contains(crate::Barrier::TEXTURE) {
1056 writeln!(self.out, "{level}textureBarrier();")?;
1057 }
1058 }
1059 Statement::RayQuery { .. } => unreachable!(),
1060 Statement::SubgroupBallot { result, predicate } => {
1061 write!(self.out, "{level}")?;
1062 let res_name = Baked(result).to_string();
1063 self.start_named_expr(module, result, func_ctx, &res_name)?;
1064 self.named_expressions.insert(result, res_name);
1065
1066 write!(self.out, "subgroupBallot(")?;
1067 if let Some(predicate) = predicate {
1068 self.write_expr(module, predicate, func_ctx)?;
1069 }
1070 writeln!(self.out, ");")?;
1071 }
1072 Statement::SubgroupCollectiveOperation {
1073 op,
1074 collective_op,
1075 argument,
1076 result,
1077 } => {
1078 write!(self.out, "{level}")?;
1079 let res_name = Baked(result).to_string();
1080 self.start_named_expr(module, result, func_ctx, &res_name)?;
1081 self.named_expressions.insert(result, res_name);
1082
1083 match (collective_op, op) {
1084 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
1085 write!(self.out, "subgroupAll(")?
1086 }
1087 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
1088 write!(self.out, "subgroupAny(")?
1089 }
1090 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
1091 write!(self.out, "subgroupAdd(")?
1092 }
1093 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
1094 write!(self.out, "subgroupMul(")?
1095 }
1096 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
1097 write!(self.out, "subgroupMax(")?
1098 }
1099 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
1100 write!(self.out, "subgroupMin(")?
1101 }
1102 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
1103 write!(self.out, "subgroupAnd(")?
1104 }
1105 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
1106 write!(self.out, "subgroupOr(")?
1107 }
1108 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1109 write!(self.out, "subgroupXor(")?
1110 }
1111 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1112 write!(self.out, "subgroupExclusiveAdd(")?
1113 }
1114 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1115 write!(self.out, "subgroupExclusiveMul(")?
1116 }
1117 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1118 write!(self.out, "subgroupInclusiveAdd(")?
1119 }
1120 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1121 write!(self.out, "subgroupInclusiveMul(")?
1122 }
1123 _ => unimplemented!(),
1124 }
1125 self.write_expr(module, argument, func_ctx)?;
1126 writeln!(self.out, ");")?;
1127 }
1128 Statement::SubgroupGather {
1129 mode,
1130 argument,
1131 result,
1132 } => {
1133 write!(self.out, "{level}")?;
1134 let res_name = Baked(result).to_string();
1135 self.start_named_expr(module, result, func_ctx, &res_name)?;
1136 self.named_expressions.insert(result, res_name);
1137
1138 match mode {
1139 crate::GatherMode::BroadcastFirst => {
1140 write!(self.out, "subgroupBroadcastFirst(")?;
1141 }
1142 crate::GatherMode::Broadcast(_) => {
1143 write!(self.out, "subgroupBroadcast(")?;
1144 }
1145 crate::GatherMode::Shuffle(_) => {
1146 write!(self.out, "subgroupShuffle(")?;
1147 }
1148 crate::GatherMode::ShuffleDown(_) => {
1149 write!(self.out, "subgroupShuffleDown(")?;
1150 }
1151 crate::GatherMode::ShuffleUp(_) => {
1152 write!(self.out, "subgroupShuffleUp(")?;
1153 }
1154 crate::GatherMode::ShuffleXor(_) => {
1155 write!(self.out, "subgroupShuffleXor(")?;
1156 }
1157 crate::GatherMode::QuadBroadcast(_) => {
1158 write!(self.out, "quadBroadcast(")?;
1159 }
1160 crate::GatherMode::QuadSwap(direction) => match direction {
1161 crate::Direction::X => {
1162 write!(self.out, "quadSwapX(")?;
1163 }
1164 crate::Direction::Y => {
1165 write!(self.out, "quadSwapY(")?;
1166 }
1167 crate::Direction::Diagonal => {
1168 write!(self.out, "quadSwapDiagonal(")?;
1169 }
1170 },
1171 }
1172 self.write_expr(module, argument, func_ctx)?;
1173 match mode {
1174 crate::GatherMode::BroadcastFirst => {}
1175 crate::GatherMode::Broadcast(index)
1176 | crate::GatherMode::Shuffle(index)
1177 | crate::GatherMode::ShuffleDown(index)
1178 | crate::GatherMode::ShuffleUp(index)
1179 | crate::GatherMode::ShuffleXor(index)
1180 | crate::GatherMode::QuadBroadcast(index) => {
1181 write!(self.out, ", ")?;
1182 self.write_expr(module, index, func_ctx)?;
1183 }
1184 crate::GatherMode::QuadSwap(_) => {}
1185 }
1186 writeln!(self.out, ");")?;
1187 }
1188 Statement::CooperativeStore { target, ref data } => {
1189 let suffix = if data.row_major { "T" } else { "" };
1190 write!(self.out, "{level}coopStore{suffix}(")?;
1191 self.write_expr(module, target, func_ctx)?;
1192 write!(self.out, ", ")?;
1193 self.write_expr(module, data.pointer, func_ctx)?;
1194 write!(self.out, ", ")?;
1195 self.write_expr(module, data.stride, func_ctx)?;
1196 writeln!(self.out, ");")?
1197 }
1198 Statement::RayPipelineFunction(fun) => match fun {
1199 crate::RayPipelineFunction::TraceRay {
1200 acceleration_structure,
1201 descriptor,
1202 payload,
1203 } => {
1204 write!(self.out, "{level}traceRay(")?;
1205 self.write_expr(module, acceleration_structure, func_ctx)?;
1206 write!(self.out, ", ")?;
1207 self.write_expr(module, descriptor, func_ctx)?;
1208 write!(self.out, ", ")?;
1209 self.write_expr(module, payload, func_ctx)?;
1210 writeln!(self.out, ");")?
1211 }
1212 },
1213 }
1214
1215 Ok(())
1216 }
1217
1218 fn plain_form_indirection(
1241 &self,
1242 expr: Handle<crate::Expression>,
1243 module: &Module,
1244 func_ctx: &back::FunctionCtx<'_>,
1245 ) -> Indirection {
1246 use crate::Expression as Ex;
1247
1248 if self.named_expressions.contains_key(&expr) {
1252 return Indirection::Ordinary;
1253 }
1254
1255 match func_ctx.expressions[expr] {
1256 Ex::LocalVariable(_) => Indirection::Reference,
1257 Ex::GlobalVariable(handle) => {
1258 let global = &module.global_variables[handle];
1259 match global.space {
1260 crate::AddressSpace::Handle => Indirection::Ordinary,
1261 _ => Indirection::Reference,
1262 }
1263 }
1264 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1265 let base_ty = func_ctx.resolve_type(base, &module.types);
1266 match *base_ty {
1267 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1268 Indirection::Reference
1269 }
1270 _ => Indirection::Ordinary,
1271 }
1272 }
1273 _ => Indirection::Ordinary,
1274 }
1275 }
1276
1277 fn start_named_expr(
1278 &mut self,
1279 module: &Module,
1280 handle: Handle<crate::Expression>,
1281 func_ctx: &back::FunctionCtx,
1282 name: &str,
1283 ) -> BackendResult {
1284 write!(self.out, "let {name}")?;
1286 if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1287 write!(self.out, ": ")?;
1288 self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1290 }
1291
1292 write!(self.out, " = ")?;
1293 Ok(())
1294 }
1295
1296 fn write_expr(
1300 &mut self,
1301 module: &Module,
1302 expr: Handle<crate::Expression>,
1303 func_ctx: &back::FunctionCtx<'_>,
1304 ) -> BackendResult {
1305 self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1306 }
1307
1308 fn write_expr_with_indirection(
1321 &mut self,
1322 module: &Module,
1323 expr: Handle<crate::Expression>,
1324 func_ctx: &back::FunctionCtx<'_>,
1325 requested: Indirection,
1326 ) -> BackendResult {
1327 let plain = self.plain_form_indirection(expr, module, func_ctx);
1330 log::trace!(
1331 "expression {:?}={:?} is {:?}, expected {:?}",
1332 expr,
1333 func_ctx.expressions[expr],
1334 plain,
1335 requested,
1336 );
1337 match (requested, plain) {
1338 (Indirection::Ordinary, Indirection::Reference) => {
1339 write!(self.out, "(&")?;
1340 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1341 write!(self.out, ")")?;
1342 }
1343 (Indirection::Reference, Indirection::Ordinary) => {
1344 write!(self.out, "(*")?;
1345 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1346 write!(self.out, ")")?;
1347 }
1348 (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1349 }
1350
1351 Ok(())
1352 }
1353
1354 fn write_const_expression(
1355 &mut self,
1356 module: &Module,
1357 expr: Handle<crate::Expression>,
1358 arena: &crate::Arena<crate::Expression>,
1359 ) -> BackendResult {
1360 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1361 writer.write_const_expression(module, expr, arena)
1362 })
1363 }
1364
1365 fn write_possibly_const_expression<E>(
1366 &mut self,
1367 module: &Module,
1368 expr: Handle<crate::Expression>,
1369 expressions: &crate::Arena<crate::Expression>,
1370 write_expression: E,
1371 ) -> BackendResult
1372 where
1373 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1374 {
1375 use crate::Expression;
1376
1377 match expressions[expr] {
1378 Expression::Literal(literal) => match literal {
1379 crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1380 crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1381 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1382 crate::Literal::I32(value) => {
1383 if value == i32::MIN {
1387 write!(self.out, "i32({value})")?;
1388 } else {
1389 write!(self.out, "{value}i")?;
1390 }
1391 }
1392 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1393 crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1394 crate::Literal::I64(value) => {
1395 if value == i64::MIN {
1400 write!(self.out, "i64({} - 1)", value + 1)?;
1401 } else {
1402 write!(self.out, "{value}li")?;
1403 }
1404 }
1405 crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1406 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1407 return Err(Error::Custom(
1408 "Abstract types should not appear in IR presented to backends".into(),
1409 ));
1410 }
1411 },
1412 Expression::Constant(handle) => {
1413 let constant = &module.constants[handle];
1414 if constant.name.is_some() {
1415 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1416 } else {
1417 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1418 }
1419 }
1420 Expression::ZeroValue(ty) => {
1421 self.write_type(module, ty)?;
1422 write!(self.out, "()")?;
1423 }
1424 Expression::Compose { ty, ref components } => {
1425 self.write_type(module, ty)?;
1426 write!(self.out, "(")?;
1427 for (index, component) in components.iter().enumerate() {
1428 if index != 0 {
1429 write!(self.out, ", ")?;
1430 }
1431 write_expression(self, *component)?;
1432 }
1433 write!(self.out, ")")?
1434 }
1435 Expression::Splat { size, value } => {
1436 let size = common::vector_size_str(size);
1437 write!(self.out, "vec{size}(")?;
1438 write_expression(self, value)?;
1439 write!(self.out, ")")?;
1440 }
1441 Expression::Override(handle) => {
1442 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1443 }
1444 _ => unreachable!(),
1445 }
1446
1447 Ok(())
1448 }
1449
/// Write `expr` in its "plain form": the WGSL text for the expression
/// itself, without any `&`/`*` adapter.
///
/// `indirection` is forwarded as the requested indirection when writing
/// the base of `Access`/`AccessIndex`. Within this file it is the value
/// computed by `plain_form_indirection` for `expr`.
///
/// Expressions bound by `start_named_expr` are written as their name.
fn write_expr_plain_form(
    &mut self,
    module: &Module,
    expr: Handle<crate::Expression>,
    func_ctx: &back::FunctionCtx<'_>,
    indirection: Indirection,
) -> BackendResult {
    use crate::Expression;

    // Already bound to a `let` name: refer to it by that name.
    if let Some(name) = self.named_expressions.get(&expr) {
        write!(self.out, "{name}")?;
        return Ok(());
    }

    let expression = &func_ctx.expressions[expr];

    match *expression {
        // Constant-like expressions share formatting with the constant
        // writer. Their sub-expressions live in the function's arena, so
        // they are written with `write_expr`.
        Expression::Literal(_)
        | Expression::Constant(_)
        | Expression::ZeroValue(_)
        | Expression::Compose { .. }
        | Expression::Splat { .. } => {
            self.write_possibly_const_expression(
                module,
                expr,
                func_ctx.expressions,
                |writer, expr| writer.write_expr(module, expr, func_ctx),
            )?;
        }
        Expression::Override(handle) => {
            write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
        }
        Expression::FunctionArgument(pos) => {
            let name_key = func_ctx.argument_key(pos);
            let name = &self.names[&name_key];
            write!(self.out, "{name}")?;
        }
        // Binary operations are always parenthesized.
        Expression::Binary { op, left, right } => {
            write!(self.out, "(")?;
            self.write_expr(module, left, func_ctx)?;
            write!(self.out, " {} ", back::binary_operation_str(op))?;
            self.write_expr(module, right, func_ctx)?;
            write!(self.out, ")")?;
        }
        // The base keeps the caller's indirection; the index is an ordinary value.
        Expression::Access { base, index } => {
            self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
            write!(self.out, "[")?;
            self.write_expr(module, index, func_ctx)?;
            write!(self.out, "]")?
        }
        Expression::AccessIndex { base, index } => {
            let base_ty_res = &func_ctx.info[base].ty;
            let mut resolved = base_ty_res.inner_with(&module.types);

            self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

            // Look through a pointer to the pointee type. The type handle
            // is kept for struct member name lookup.
            let base_ty_handle = match *resolved {
                TypeInner::Pointer { base, space: _ } => {
                    resolved = &module.types[base].inner;
                    Some(base)
                }
                _ => base_ty_res.handle(),
            };

            match *resolved {
                // Vectors use component syntax (`.x`, `.y`, ...).
                TypeInner::Vector { .. } => {
                    write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                }
                TypeInner::Matrix { .. }
                | TypeInner::Array { .. }
                | TypeInner::BindingArray { .. }
                | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                // Struct members are written by their assigned member name.
                TypeInner::Struct { .. } => {
                    let ty = base_ty_handle.unwrap();

                    write!(
                        self.out,
                        ".{}",
                        &self.names[&NameKey::StructMember(ty, index)]
                    )?
                }
                ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
            }
        }
        // Non-gather sampling: the function name is built from a
        // comparison suffix and a level suffix.
        Expression::ImageSample {
            image,
            sampler,
            gather: None,
            coordinate,
            array_index,
            offset,
            level,
            depth_ref,
            clamp_to_edge,
        } => {
            use crate::SampleLevel as Sl;

            let suffix_cmp = match depth_ref {
                Some(_) => "Compare",
                None => "",
            };
            let suffix_level = match level {
                Sl::Auto => "",
                Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                Sl::Zero | Sl::Exact(_) => "Level",
                Sl::Bias(_) => "Bias",
                Sl::Gradient { .. } => "Grad",
            };

            write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, sampler, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;

            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }

            if let Some(depth_ref) = depth_ref {
                write!(self.out, ", ")?;
                self.write_expr(module, depth_ref, func_ctx)?;
            }

            match level {
                Sl::Auto => {}
                // `Level` with no depth reference and no clamp-to-edge needs
                // an explicit level argument; `0.0` is written for it.
                Sl::Zero => {
                    if depth_ref.is_none() && !clamp_to_edge {
                        write!(self.out, ", 0.0")?;
                    }
                }
                Sl::Exact(expr) => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, expr, func_ctx)?;
                }
                Sl::Bias(expr) => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, expr, func_ctx)?;
                }
                Sl::Gradient { x, y } => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, x, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, y, func_ctx)?;
                }
            }

            // The offset is written as a constant expression.
            if let Some(offset) = offset {
                write!(self.out, ", ")?;
                self.write_const_expression(module, offset, func_ctx.expressions)?;
            }

            write!(self.out, ")")?;
        }

        // Gather: the level and clamp_to_edge fields are ignored here.
        Expression::ImageSample {
            image,
            sampler,
            gather: Some(component),
            coordinate,
            array_index,
            offset,
            level: _,
            depth_ref,
            clamp_to_edge: _,
        } => {
            let suffix_cmp = match depth_ref {
                Some(_) => "Compare",
                None => "",
            };

            write!(self.out, "textureGather{suffix_cmp}(")?;
            // A component index is written first, except for depth images.
            match *func_ctx.resolve_type(image, &module.types) {
                TypeInner::Image {
                    class: crate::ImageClass::Depth { multi: _ },
                    ..
                } => {}
                _ => {
                    write!(self.out, "{}, ", component as u8)?;
                }
            }
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, sampler, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;

            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }

            if let Some(depth_ref) = depth_ref {
                write!(self.out, ", ")?;
                self.write_expr(module, depth_ref, func_ctx)?;
            }

            if let Some(offset) = offset {
                write!(self.out, ", ")?;
                self.write_const_expression(module, offset, func_ctx.expressions)?;
            }

            write!(self.out, ")")?;
        }
        Expression::ImageQuery { image, query } => {
            use crate::ImageQuery as Iq;

            let texture_function = match query {
                Iq::Size { .. } => "textureDimensions",
                Iq::NumLevels => "textureNumLevels",
                Iq::NumLayers => "textureNumLayers",
                Iq::NumSamples => "textureNumSamples",
            };

            write!(self.out, "{texture_function}(")?;
            self.write_expr(module, image, func_ctx)?;
            // Only a size query may carry an extra mip-level argument.
            if let Iq::Size { level: Some(level) } = query {
                write!(self.out, ", ")?;
                self.write_expr(module, level, func_ctx)?;
            };
            write!(self.out, ")")?;
        }

        Expression::ImageLoad {
            image,
            coordinate,
            array_index,
            sample,
            level,
        } => {
            write!(self.out, "textureLoad(")?;
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;
            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }
            // At most one trailing argument: the sample index if present,
            // otherwise the level.
            if let Some(index) = sample.or(level) {
                write!(self.out, ", ")?;
                self.write_expr(module, index, func_ctx)?;
            }
            write!(self.out, ")")?;
        }
        Expression::GlobalVariable(handle) => {
            let name = &self.names[&NameKey::GlobalVariable(handle)];
            write!(self.out, "{name}")?;
        }

        // Casts. Matrices always use a `matCxR<T>(...)` constructor. Vectors
        // and scalars use a constructor when `convert` is `Some` and
        // `bitcast<...>` when it is `None`.
        Expression::As {
            expr,
            kind,
            convert,
        } => {
            let inner = func_ctx.resolve_type(expr, &module.types);
            match *inner {
                TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                } => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(scalar.width),
                    };
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    write!(
                        self.out,
                        "mat{}x{}<{}>",
                        common::vector_size_str(columns),
                        common::vector_size_str(rows),
                        scalar_kind_str
                    )?;
                }
                TypeInner::Vector {
                    size,
                    scalar: crate::Scalar { width, .. },
                } => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    };
                    let vector_size_str = common::vector_size_str(size);
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    if convert.is_some() {
                        write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                    } else {
                        write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                    }
                }
                TypeInner::Scalar(crate::Scalar { width, .. }) => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    };
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    if convert.is_some() {
                        write!(self.out, "{scalar_kind_str}")?
                    } else {
                        write!(self.out, "bitcast<{scalar_kind_str}>")?
                    }
                }
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "write_expr expression::as {inner:?}"
                    )));
                }
            };
            write!(self.out, "(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?;
        }
        // Atomic pointers are read with `atomicLoad`. Other loads request
        // the pointer in `Reference` form, so `write_expr_with_indirection`
        // adds a `*` when the pointer's plain form is an ordinary value.
        Expression::Load { pointer } => {
            let is_atomic_pointer = func_ctx
                .resolve_type(pointer, &module.types)
                .is_atomic_pointer(&module.types);

            if is_atomic_pointer {
                write!(self.out, "atomicLoad(")?;
                self.write_expr(module, pointer, func_ctx)?;
                write!(self.out, ")")?;
            } else {
                self.write_expr_with_indirection(
                    module,
                    pointer,
                    func_ctx,
                    Indirection::Reference,
                )?;
            }
        }
        Expression::LocalVariable(handle) => {
            write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
        }
        Expression::ArrayLength(expr) => {
            write!(self.out, "arrayLength(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?;
        }

        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            use crate::MathFunction as Mf;

            // Either a WGSL builtin name, or an `Inverse` overload that must be
            // emitted as a polyfill function.
            enum Function {
                Regular(&'static str),
                InversePolyfill(InversePolyfill),
            }

            let function = match fun.try_to_wgsl() {
                Some(name) => Function::Regular(name),
                None => match fun {
                    Mf::Inverse => {
                        let ty = func_ctx.resolve_type(arg, &module.types);
                        let Some(overload) = InversePolyfill::find_overload(ty) else {
                            return Err(Error::unsupported("math function", fun));
                        };

                        Function::InversePolyfill(overload)
                    }
                    _ => return Err(Error::unsupported("math function", fun)),
                },
            };

            match function {
                Function::Regular(fun_name) => {
                    write!(self.out, "{fun_name}(")?;
                    self.write_expr(module, arg, func_ctx)?;
                    // Optional extra arguments, in order, skipping absent ones.
                    for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                        write!(self.out, ", ")?;
                        self.write_expr(module, arg, func_ctx)?;
                    }
                    write!(self.out, ")")?
                }
                Function::InversePolyfill(inverse) => {
                    write!(self.out, "{}(", inverse.fun_name)?;
                    self.write_expr(module, arg, func_ctx)?;
                    write!(self.out, ")")?;
                    // Record the overload so its definition can be emitted later.
                    self.required_polyfills.insert(inverse);
                }
            }
        }

        // Only the first `size` entries of `pattern` are used.
        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            self.write_expr(module, vector, func_ctx)?;
            write!(self.out, ".")?;
            for &sc in pattern[..size as usize].iter() {
                self.out.write_char(back::COMPONENTS[sc as usize])?;
            }
        }
        // The operand is always parenthesized.
        Expression::Unary { op, expr } => {
            let unary = match op {
                crate::UnaryOperator::Negate => "-",
                crate::UnaryOperator::LogicalNot => "!",
                crate::UnaryOperator::BitwiseNot => "~",
            };

            write!(self.out, "{unary}(")?;
            self.write_expr(module, expr, func_ctx)?;

            write!(self.out, ")")?
        }

        // Note the argument order: reject, accept, condition. This matches
        // WGSL's `select(f, t, cond)`.
        Expression::Select {
            condition,
            accept,
            reject,
        } => {
            write!(self.out, "select(")?;
            self.write_expr(module, reject, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, accept, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, condition, func_ctx)?;
            write!(self.out, ")")?
        }
        Expression::Derivative { axis, ctrl, expr } => {
            use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
            let op = match (axis, ctrl) {
                (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                (Axis::X, Ctrl::Fine) => "dpdxFine",
                (Axis::X, Ctrl::None) => "dpdx",
                (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                (Axis::Y, Ctrl::Fine) => "dpdyFine",
                (Axis::Y, Ctrl::None) => "dpdy",
                (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                (Axis::Width, Ctrl::Fine) => "fwidthFine",
                (Axis::Width, Ctrl::None) => "fwidth",
            };
            write!(self.out, "{op}(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?
        }
        // Only `all` and `any` are supported; other relational functions
        // produce an error.
        Expression::Relational { fun, argument } => {
            use crate::RelationalFunction as Rf;

            let fun_name = match fun {
                Rf::All => "all",
                Rf::Any => "any",
                _ => return Err(Error::UnsupportedRelationalFunction(fun)),
            };
            write!(self.out, "{fun_name}(")?;

            self.write_expr(module, argument, func_ctx)?;

            write!(self.out, ")")?
        }
        // These are not expected to reach this writer.
        Expression::RayQueryGetIntersection { .. }
        | Expression::RayQueryVertexPositions { .. } => unreachable!(),
        // Results produced by statements: nothing is written here.
        // NOTE(review): presumably they are always emitted through their
        // `named_expressions` entry (handled above) — confirm.
        Expression::CallResult(_)
        | Expression::AtomicResult { .. }
        | Expression::RayQueryProceedResult
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. }
        | Expression::WorkGroupUniformLoadResult { .. } => {}
        // The element scalar type is taken from the type `data.pointer`
        // points to. A `T` suffix is added for row-major layouts.
        Expression::CooperativeLoad {
            columns,
            rows,
            role,
            ref data,
        } => {
            let suffix = if data.row_major { "T" } else { "" };
            let scalar = func_ctx.info[data.pointer]
                .ty
                .inner_with(&module.types)
                .pointer_base_type()
                .unwrap()
                .inner_with(&module.types)
                .scalar()
                .unwrap();
            write!(
                self.out,
                "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                columns as u32,
                rows as u32,
                scalar.try_to_wgsl().unwrap(),
                role,
            )?;
            self.write_expr(module, data.pointer, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, data.stride, func_ctx)?;
            write!(self.out, ")")?;
        }
        Expression::CooperativeMultiplyAdd { a, b, c } => {
            write!(self.out, "coopMultiplyAdd(")?;
            self.write_expr(module, a, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, b, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, c, func_ctx)?;
            write!(self.out, ")")?;
        }
    }

    Ok(())
}
1977
1978 fn write_global(
1982 &mut self,
1983 module: &Module,
1984 global: &crate::GlobalVariable,
1985 handle: Handle<crate::GlobalVariable>,
1986 ) -> BackendResult {
1987 if let Some(ref binding) = global.binding {
1989 self.write_attributes(&[
1990 Attribute::Group(binding.group),
1991 Attribute::Binding(binding.binding),
1992 ])?;
1993 writeln!(self.out)?;
1994 }
1995
1996 if global
1997 .memory_decorations
1998 .contains(crate::MemoryDecorations::COHERENT)
1999 {
2000 write!(self.out, "@coherent ")?;
2001 }
2002 if global
2003 .memory_decorations
2004 .contains(crate::MemoryDecorations::VOLATILE)
2005 {
2006 write!(self.out, "@volatile ")?;
2007 }
2008
2009 write!(self.out, "var")?;
2011 let (address, maybe_access) = address_space_str(global.space);
2012 if let Some(space) = address {
2013 write!(self.out, "<{space}")?;
2014 if let Some(access) = maybe_access {
2015 write!(self.out, ", {access}")?;
2016 }
2017 write!(self.out, ">")?;
2018 }
2019 write!(
2020 self.out,
2021 " {}: ",
2022 &self.names[&NameKey::GlobalVariable(handle)]
2023 )?;
2024
2025 self.write_type(module, global.ty)?;
2027
2028 if let Some(init) = global.init {
2030 write!(self.out, " = ")?;
2031 self.write_const_expression(module, init, &module.global_expressions)?;
2032 }
2033
2034 writeln!(self.out, ";")?;
2036
2037 Ok(())
2038 }
2039
2040 fn write_global_constant(
2045 &mut self,
2046 module: &Module,
2047 handle: Handle<crate::Constant>,
2048 ) -> BackendResult {
2049 let name = &self.names[&NameKey::Constant(handle)];
2050 write!(self.out, "const {name}: ")?;
2052 self.write_type(module, module.constants[handle].ty)?;
2053 write!(self.out, " = ")?;
2054 let init = module.constants[handle].init;
2055 self.write_const_expression(module, init, &module.global_expressions)?;
2056 writeln!(self.out, ";")?;
2057
2058 Ok(())
2059 }
2060
2061 fn write_override(
2066 &mut self,
2067 module: &Module,
2068 handle: Handle<crate::Override>,
2069 ) -> BackendResult {
2070 let override_ = &module.overrides[handle];
2071 let name = &self.names[&NameKey::Override(handle)];
2072
2073 if let Some(id) = override_.id {
2075 write!(self.out, "@id({id}) ")?;
2076 }
2077
2078 write!(self.out, "override {name}: ")?;
2080 self.write_type(module, override_.ty)?;
2081
2082 if let Some(init) = override_.init {
2084 write!(self.out, " = ")?;
2085 self.write_const_expression(module, init, &module.global_expressions)?;
2086 }
2087
2088 writeln!(self.out, ";")?;
2089
2090 Ok(())
2091 }
2092
2093 pub fn finish(self) -> W {
2095 self.out
2096 }
2097}
2098
2099struct WriterTypeContext<'m> {
2100 module: &'m Module,
2101 names: &'m crate::FastHashMap<NameKey, String>,
2102}
2103
2104impl TypeContext for WriterTypeContext<'_> {
2105 fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
2106 &self.module.types[handle]
2107 }
2108
2109 fn type_name(&self, handle: Handle<crate::Type>) -> &str {
2110 self.names[&NameKey::Type(handle)].as_str()
2111 }
2112
2113 fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2114 unreachable!("the WGSL back end should always provide type handles");
2115 }
2116
2117 fn write_override<W: Write>(
2118 &self,
2119 handle: Handle<crate::Override>,
2120 out: &mut W,
2121 ) -> core::fmt::Result {
2122 write!(out, "{}", self.names[&NameKey::Override(handle)])
2123 }
2124
2125 fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2126 unreachable!("backends should only be passed validated modules");
2127 }
2128
2129 fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
2130 unreachable!("backends should only be passed validated modules");
2131 }
2132}
2133
2134fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2135 match *binding {
2136 crate::Binding::BuiltIn(built_in) => {
2137 if let crate::BuiltIn::Position { invariant: true } = built_in {
2138 vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2139 } else {
2140 vec![Attribute::BuiltIn(built_in)]
2141 }
2142 }
2143 crate::Binding::Location {
2144 location,
2145 interpolation,
2146 sampling,
2147 blend_src: None,
2148 per_primitive,
2149 } => {
2150 let mut attrs = vec![
2151 Attribute::Location(location),
2152 Attribute::Interpolate(interpolation, sampling),
2153 ];
2154 if per_primitive {
2155 attrs.push(Attribute::PerPrimitive);
2156 }
2157 attrs
2158 }
2159 crate::Binding::Location {
2160 location,
2161 interpolation,
2162 sampling,
2163 blend_src: Some(blend_src),
2164 per_primitive,
2165 } => {
2166 let mut attrs = vec![
2167 Attribute::Location(location),
2168 Attribute::BlendSrc(blend_src),
2169 Attribute::Interpolate(interpolation, sampling),
2170 ];
2171 if per_primitive {
2172 attrs.push(Attribute::PerPrimitive);
2173 }
2174 attrs
2175 }
2176 }
2177}