1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13 back::{self, Baked},
14 common::{
15 self,
16 wgsl::{address_space_str, ToWgsl, TryToWgsl},
17 },
18 proc::{self, NameKey},
19 valid, Handle, Module, ShaderStage, TypeInner,
20};
21
/// Result of every write operation in this backend: nothing on success,
/// a backend `Error` otherwise.
type BackendResult = Result<(), Error>;
24
/// A WGSL attribute to be written before a declaration, argument, result, or
/// struct member. Each variant maps to one `@...` form in `write_attributes`.
enum Attribute {
    /// `@binding(N)`.
    Binding(u32),
    /// `@builtin(name)`.
    BuiltIn(crate::BuiltIn),
    /// `@group(N)`.
    Group(u32),
    /// `@invariant`.
    Invariant,
    /// `@interpolate(...)`. Nothing is written when both parts are absent or
    /// at their defaults (perspective interpolation, center sampling).
    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
    /// `@location(N)`.
    Location(u32),
    /// `@blend_src(N)`.
    BlendSrc(u32),
    /// Shader stage attribute such as `@vertex` or `@compute`. Mesh entry
    /// points use `MeshStage` instead.
    Stage(ShaderStage),
    /// `@workgroup_size(x, y, z)`.
    WorkGroupSize([u32; 3]),
    /// `@mesh(name)`, where `name` identifies the mesh output global.
    MeshStage(String),
    /// `@payload(name)`, where `name` identifies the task payload global.
    TaskPayload(String),
    /// `@per_primitive`.
    PerPrimitive,
    /// `@incoming_payload(name)`, where `name` identifies the incoming ray
    /// payload global.
    IncomingRayPayload(String),
}
41
/// How the plain written form of an expression relates to the IR value it
/// stands for. `write_expr_with_indirection` compares the requested form with
/// the plain form and inserts `&` or `*` when they differ.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// The plain form is the expression's value as-is.
    Ordinary,

    /// The plain form denotes a WGSL reference (for example a variable name,
    /// or an access through a pointer). Producing the IR pointer value from it
    /// requires a leading `&`.
    Reference,
}
70
bitflags::bitflags! {
    /// Options controlling the generated WGSL.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct WriterFlags: u32 {
        /// Annotate every `let` binding the writer introduces for an
        /// expression with an explicit `: type`.
        const EXPLICIT_TYPES = 0x1;
    }
}
80
/// Writes a naga `Module`, together with its validation `ModuleInfo`, as WGSL
/// source text.
pub struct Writer<W> {
    /// Destination for the generated text.
    out: W,
    /// Output options.
    flags: WriterFlags,
    /// Identifiers assigned by `namer` to module items, keyed by `NameKey`.
    /// Rebuilt by `reset` at the start of each `write`.
    names: crate::FastHashMap<NameKey, String>,
    /// Generates identifiers. Reset per module with the WGSL keyword and
    /// builtin-identifier sets.
    namer: proc::Namer,
    /// Expressions of the function currently being written that are already
    /// bound to a `let` name. Cleared after each function.
    named_expressions: crate::NamedExpressions,
    /// Polyfill functions requested while writing. They are emitted after
    /// all entry points.
    required_polyfills: crate::FastIndexSet<InversePolyfill>,
}
89
90impl<W: Write> Writer<W> {
91 pub fn new(out: W, flags: WriterFlags) -> Self {
92 Writer {
93 out,
94 flags,
95 names: crate::FastHashMap::default(),
96 namer: proc::Namer::default(),
97 named_expressions: crate::NamedExpressions::default(),
98 required_polyfills: crate::FastIndexSet::default(),
99 }
100 }
101
102 fn reset(&mut self, module: &Module) {
103 self.names.clear();
104 self.namer.reset(
105 module,
106 &crate::keywords::wgsl::RESERVED_SET,
107 &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET,
108 proc::CaseInsensitiveKeywordSet::empty(),
110 &["__", "_naga"],
111 &mut self.names,
112 );
113 self.named_expressions.clear();
114 self.required_polyfills.clear();
115 }
116
117 fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
133 module
134 .special_types
135 .predeclared_types
136 .values()
137 .any(|t| *t == ty)
138 || Some(ty) == module.special_types.external_texture_params
139 || Some(ty) == module.special_types.external_texture_transfer_function
140 }
141
142 pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
143 self.reset(module);
144
145 self.write_enable_declarations(module)?;
147
148 for (handle, ty) in module.types.iter() {
150 if let TypeInner::Struct { ref members, .. } = ty.inner {
151 {
152 if !self.is_builtin_wgsl_struct(module, handle) {
153 self.write_struct(module, handle, members)?;
154 writeln!(self.out)?;
155 }
156 }
157 }
158 }
159
160 let mut constants = module
162 .constants
163 .iter()
164 .filter(|&(_, c)| c.name.is_some())
165 .peekable();
166 while let Some((handle, _)) = constants.next() {
167 self.write_global_constant(module, handle)?;
168 if constants.peek().is_none() {
170 writeln!(self.out)?;
171 }
172 }
173
174 let mut overrides = module.overrides.iter().peekable();
176 while let Some((handle, _)) = overrides.next() {
177 self.write_override(module, handle)?;
178 if overrides.peek().is_none() {
180 writeln!(self.out)?;
181 }
182 }
183
184 for (ty, global) in module.global_variables.iter() {
186 self.write_global(module, global, ty)?;
187 }
188
189 if !module.global_variables.is_empty() {
190 writeln!(self.out)?;
192 }
193
194 for (handle, function) in module.functions.iter() {
196 let fun_info = &info[handle];
197
198 let func_ctx = back::FunctionCtx {
199 ty: back::FunctionType::Function(handle),
200 info: fun_info,
201 expressions: &function.expressions,
202 named_expressions: &function.named_expressions,
203 };
204
205 self.write_function(module, function, &func_ctx)?;
207
208 writeln!(self.out)?;
209 }
210
211 for (index, ep) in module.entry_points.iter().enumerate() {
213 let attributes = match ep.stage {
214 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
215 ShaderStage::Compute => vec![
216 Attribute::Stage(ShaderStage::Compute),
217 Attribute::WorkGroupSize(ep.workgroup_size),
218 ],
219 ShaderStage::Mesh => {
220 let mesh_output_name = module.global_variables
221 [ep.mesh_info.as_ref().unwrap().output_variable]
222 .name
223 .clone()
224 .unwrap();
225 let mut mesh_attrs = vec![
226 Attribute::MeshStage(mesh_output_name),
227 Attribute::WorkGroupSize(ep.workgroup_size),
228 ];
229 if let Some(task_payload) = ep.task_payload {
230 let payload_name =
231 module.global_variables[task_payload].name.clone().unwrap();
232 mesh_attrs.push(Attribute::TaskPayload(payload_name));
233 }
234 mesh_attrs
235 }
236 ShaderStage::Task => {
237 let payload_name = module.global_variables[ep.task_payload.unwrap()]
238 .name
239 .clone()
240 .unwrap();
241 vec![
242 Attribute::Stage(ShaderStage::Task),
243 Attribute::TaskPayload(payload_name),
244 Attribute::WorkGroupSize(ep.workgroup_size),
245 ]
246 }
247 ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
248 ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
249 let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
250 .name
251 .clone()
252 .unwrap();
253 vec![
254 Attribute::Stage(ep.stage),
255 Attribute::IncomingRayPayload(payload_name),
256 ]
257 }
258 };
259 self.write_attributes(&attributes)?;
260 writeln!(self.out)?;
262
263 let func_ctx = back::FunctionCtx {
264 ty: back::FunctionType::EntryPoint(index as u16),
265 info: info.get_entry_point(index),
266 expressions: &ep.function.expressions,
267 named_expressions: &ep.function.named_expressions,
268 };
269 self.write_function(module, &ep.function, &func_ctx)?;
270
271 if index < module.entry_points.len() - 1 {
272 writeln!(self.out)?;
273 }
274 }
275
276 for polyfill in &self.required_polyfills {
278 writeln!(self.out)?;
279 write!(self.out, "{}", polyfill.source)?;
280 writeln!(self.out)?;
281 }
282
283 Ok(())
284 }
285
286 fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
289 #[derive(Default)]
290 struct RequiredEnabled {
291 f16: bool,
292 int16: bool,
293 dual_source_blending: bool,
294 clip_distances: bool,
295 mesh_shaders: bool,
296 primitive_index: bool,
297 cooperative_matrix: bool,
298 draw_index: bool,
299 ray_tracing_pipeline: bool,
300 per_vertex: bool,
301 binding_array: bool,
302 }
303 let mut needed = RequiredEnabled {
304 mesh_shaders: module.uses_mesh_shaders(),
305 ..Default::default()
306 };
307
308 let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding
309 {
310 crate::Binding::Location {
311 blend_src: Some(_), ..
312 } => {
313 needed.dual_source_blending = true;
314 }
315 crate::Binding::BuiltIn(crate::BuiltIn::ClipDistances) => {
316 needed.clip_distances = true;
317 }
318 crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => {
319 needed.primitive_index = true;
320 }
321 crate::Binding::Location {
322 per_primitive: true,
323 ..
324 } => {
325 needed.mesh_shaders = true;
326 }
327 crate::Binding::Location {
328 interpolation: Some(crate::Interpolation::PerVertex),
329 ..
330 } => {
331 needed.per_vertex = true;
332 }
333 crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
334 crate::Binding::BuiltIn(
335 crate::BuiltIn::RayInvocationId
336 | crate::BuiltIn::NumRayInvocations
337 | crate::BuiltIn::InstanceCustomData
338 | crate::BuiltIn::GeometryIndex
339 | crate::BuiltIn::WorldRayOrigin
340 | crate::BuiltIn::WorldRayDirection
341 | crate::BuiltIn::ObjectRayOrigin
342 | crate::BuiltIn::ObjectRayDirection
343 | crate::BuiltIn::RayTmin
344 | crate::BuiltIn::RayTCurrentMax
345 | crate::BuiltIn::ObjectToWorld
346 | crate::BuiltIn::WorldToObject,
347 ) => {
348 needed.ray_tracing_pipeline = true;
349 }
350 _ => {}
351 };
352
353 for (_, ty) in module.types.iter() {
355 match ty.inner {
356 TypeInner::Scalar(scalar)
357 | TypeInner::Vector { scalar, .. }
358 | TypeInner::Matrix { scalar, .. } => {
359 needed.f16 |= scalar == crate::Scalar::F16;
360 needed.int16 |= scalar == crate::Scalar::I16 || scalar == crate::Scalar::U16;
361 }
362 TypeInner::Struct { ref members, .. } => {
363 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
364 check_binding(binding, &mut needed);
365 }
366 }
367 TypeInner::CooperativeMatrix { .. } => {
368 needed.cooperative_matrix = true;
369 }
370 TypeInner::AccelerationStructure { .. } => {
371 needed.ray_tracing_pipeline = true;
372 }
373 TypeInner::BindingArray { .. } => {
374 needed.binding_array = true;
375 }
376 _ => {}
377 }
378 }
379
380 for ep in &module.entry_points {
381 if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) {
382 check_binding(res, &mut needed);
383 }
384 for arg in ep
385 .function
386 .arguments
387 .iter()
388 .filter_map(|a| a.binding.as_ref())
389 {
390 check_binding(arg, &mut needed);
391 }
392 }
393
394 if module.global_variables.iter().any(|gv| {
395 gv.1.space == crate::AddressSpace::IncomingRayPayload
396 || gv.1.space == crate::AddressSpace::RayPayload
397 }) {
398 needed.ray_tracing_pipeline = true;
399 }
400
401 if module.entry_points.iter().any(|ep| {
402 matches!(
403 ep.stage,
404 ShaderStage::RayGeneration
405 | ShaderStage::AnyHit
406 | ShaderStage::ClosestHit
407 | ShaderStage::Miss
408 )
409 }) {
410 needed.ray_tracing_pipeline = true;
411 }
412
413 if module.global_variables.iter().any(|gv| {
414 gv.1.space == crate::AddressSpace::IncomingRayPayload
415 || gv.1.space == crate::AddressSpace::RayPayload
416 }) {
417 needed.ray_tracing_pipeline = true;
418 }
419
420 if module.entry_points.iter().any(|ep| {
421 matches!(
422 ep.stage,
423 ShaderStage::RayGeneration
424 | ShaderStage::AnyHit
425 | ShaderStage::ClosestHit
426 | ShaderStage::Miss
427 )
428 }) {
429 needed.ray_tracing_pipeline = true;
430 }
431
432 let mut any_written = false;
434 if needed.f16 {
435 writeln!(self.out, "enable f16;")?;
436 any_written = true;
437 }
438 if needed.int16 {
439 writeln!(self.out, "enable wgpu_int16;")?;
440 any_written = true;
441 }
442 if needed.dual_source_blending {
443 writeln!(self.out, "enable dual_source_blending;")?;
444 any_written = true;
445 }
446 if needed.clip_distances {
447 writeln!(self.out, "enable clip_distances;")?;
448 any_written = true;
449 }
450 if module.uses_mesh_shaders() {
451 writeln!(self.out, "enable wgpu_mesh_shader;")?;
452 any_written = true;
453 }
454 if needed.binding_array {
455 writeln!(self.out, "enable wgpu_binding_array;")?;
456 any_written = true;
457 }
458 if needed.draw_index {
459 writeln!(self.out, "enable draw_index;")?;
460 any_written = true;
461 }
462 if needed.primitive_index {
463 writeln!(self.out, "enable primitive_index;")?;
464 any_written = true;
465 }
466 if needed.cooperative_matrix {
467 writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
468 any_written = true;
469 }
470 if needed.ray_tracing_pipeline {
471 writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
472 any_written = true;
473 }
474 if needed.per_vertex {
475 writeln!(self.out, "enable wgpu_per_vertex;")?;
476 any_written = true;
477 }
478 if any_written {
479 writeln!(self.out)?;
481 }
482
483 Ok(())
484 }
485
486 fn write_function(
492 &mut self,
493 module: &Module,
494 func: &crate::Function,
495 func_ctx: &back::FunctionCtx<'_>,
496 ) -> BackendResult {
497 let func_name = match func_ctx.ty {
498 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
499 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
500 };
501
502 write!(self.out, "fn {func_name}(")?;
504
505 for (index, arg) in func.arguments.iter().enumerate() {
507 if let Some(ref binding) = arg.binding {
509 self.write_attributes(&map_binding_to_attribute(binding))?;
510 }
511 let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
513
514 write!(self.out, "{argument_name}: ")?;
515 self.write_type(module, arg.ty)?;
517 if index < func.arguments.len() - 1 {
518 write!(self.out, ", ")?;
520 }
521 }
522
523 write!(self.out, ")")?;
524
525 if let Some(ref result) = func.result {
527 write!(self.out, " -> ")?;
528 if let Some(ref binding) = result.binding {
529 self.write_attributes(&map_binding_to_attribute(binding))?;
530 }
531 self.write_type(module, result.ty)?;
532 }
533
534 write!(self.out, " {{")?;
535 writeln!(self.out)?;
536
537 for (handle, local) in func.local_variables.iter() {
539 write!(self.out, "{}", back::INDENT)?;
541
542 write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
545
546 self.write_type(module, local.ty)?;
548
549 if let Some(init) = local.init {
551 write!(self.out, " = ")?;
554
555 self.write_expr(module, init, func_ctx)?;
558 }
559
560 writeln!(self.out, ";")?
562 }
563
564 if !func.local_variables.is_empty() {
565 writeln!(self.out)?;
566 }
567
568 for sta in func.body.iter() {
570 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
572 }
573
574 writeln!(self.out, "}}")?;
575
576 self.named_expressions.clear();
577
578 Ok(())
579 }
580
581 fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
583 for attribute in attributes {
584 match *attribute {
585 Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
586 Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
587 Attribute::BuiltIn(builtin_attrib) => {
588 let builtin = builtin_attrib.to_wgsl_if_implemented()?;
589 write!(self.out, "@builtin({builtin}) ")?;
590 }
591 Attribute::Stage(shader_stage) => {
592 let stage_str = match shader_stage {
593 ShaderStage::Vertex => "vertex",
594 ShaderStage::Fragment => "fragment",
595 ShaderStage::Compute => "compute",
596 ShaderStage::Task => "task",
597 ShaderStage::Mesh => unreachable!(),
599 ShaderStage::RayGeneration => "ray_generation",
600 ShaderStage::AnyHit => "any_hit",
601 ShaderStage::ClosestHit => "closest_hit",
602 ShaderStage::Miss => "miss",
603 };
604
605 write!(self.out, "@{stage_str} ")?;
606 }
607 Attribute::WorkGroupSize(size) => {
608 write!(
609 self.out,
610 "@workgroup_size({}, {}, {}) ",
611 size[0], size[1], size[2]
612 )?;
613 }
614 Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
615 Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
616 Attribute::Invariant => write!(self.out, "@invariant ")?,
617 Attribute::Interpolate(interpolation, sampling) => {
618 if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
619 let interpolation = interpolation
620 .unwrap_or(crate::Interpolation::Perspective)
621 .to_wgsl();
622 let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
623 write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
624 } else if interpolation.is_some()
625 && interpolation != Some(crate::Interpolation::Perspective)
626 {
627 let interpolation = interpolation
628 .unwrap_or(crate::Interpolation::Perspective)
629 .to_wgsl();
630 write!(self.out, "@interpolate({interpolation}) ")?;
631 }
632 }
633 Attribute::MeshStage(ref name) => {
634 write!(self.out, "@mesh({name}) ")?;
635 }
636 Attribute::TaskPayload(ref payload_name) => {
637 write!(self.out, "@payload({payload_name}) ")?;
638 }
639 Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
640 Attribute::IncomingRayPayload(ref payload_name) => {
641 write!(self.out, "@incoming_payload({payload_name}) ")?;
642 }
643 };
644 }
645 Ok(())
646 }
647
648 fn write_struct(
659 &mut self,
660 module: &Module,
661 handle: Handle<crate::Type>,
662 members: &[crate::StructMember],
663 ) -> BackendResult {
664 write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
665 write!(self.out, " {{")?;
666 writeln!(self.out)?;
667 for (index, member) in members.iter().enumerate() {
668 write!(self.out, "{}", back::INDENT)?;
670 if let Some(ref binding) = member.binding {
671 self.write_attributes(&map_binding_to_attribute(binding))?;
672 }
673 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
675 write!(self.out, "{member_name}: ")?;
676 self.write_type(module, member.ty)?;
677 write!(self.out, ",")?;
678 writeln!(self.out)?;
679 }
680
681 writeln!(self.out, "}}")?;
682
683 Ok(())
684 }
685
686 fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
687 let type_context = WriterTypeContext {
691 module,
692 names: &self.names,
693 };
694 type_context.write_type(ty, &mut self.out)?;
695
696 Ok(())
697 }
698
699 fn write_type_resolution(
700 &mut self,
701 module: &Module,
702 resolution: &proc::TypeResolution,
703 ) -> BackendResult {
704 let type_context = WriterTypeContext {
708 module,
709 names: &self.names,
710 };
711 type_context.write_type_resolution(resolution, &mut self.out)?;
712
713 Ok(())
714 }
715
716 fn write_stmt(
721 &mut self,
722 module: &Module,
723 stmt: &crate::Statement,
724 func_ctx: &back::FunctionCtx<'_>,
725 level: back::Level,
726 ) -> BackendResult {
727 use crate::{Expression, Statement};
728
729 match *stmt {
730 Statement::Emit(ref range) => {
731 for handle in range.clone() {
732 let info = &func_ctx.info[handle];
733 let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
734 Some(self.namer.call(name))
739 } else {
740 let expr = &func_ctx.expressions[handle];
741 let min_ref_count = expr.bake_ref_count();
742 let required_baking_expr = match *expr {
744 Expression::ImageLoad { .. }
745 | Expression::ImageQuery { .. }
746 | Expression::ImageSample { .. } => true,
747 _ => false,
748 };
749 if min_ref_count <= info.ref_count || required_baking_expr {
750 Some(Baked(handle).to_string())
751 } else {
752 None
753 }
754 };
755
756 if let Some(name) = expr_name {
757 write!(self.out, "{level}")?;
758 self.start_named_expr(module, handle, func_ctx, &name)?;
759 self.write_expr(module, handle, func_ctx)?;
760 self.named_expressions.insert(handle, name);
761 writeln!(self.out, ";")?;
762 }
763 }
764 }
765 Statement::If {
767 condition,
768 ref accept,
769 ref reject,
770 } => {
771 write!(self.out, "{level}")?;
772 write!(self.out, "if ")?;
773 self.write_expr(module, condition, func_ctx)?;
774 writeln!(self.out, " {{")?;
775
776 let l2 = level.next();
777 for sta in accept {
778 self.write_stmt(module, sta, func_ctx, l2)?;
780 }
781
782 if !reject.is_empty() {
785 writeln!(self.out, "{level}}} else {{")?;
786
787 for sta in reject {
788 self.write_stmt(module, sta, func_ctx, l2)?;
790 }
791 }
792
793 writeln!(self.out, "{level}}}")?
794 }
795 Statement::Return { value } => {
796 write!(self.out, "{level}")?;
797 write!(self.out, "return")?;
798 if let Some(return_value) = value {
799 write!(self.out, " ")?;
801 self.write_expr(module, return_value, func_ctx)?;
802 }
803 writeln!(self.out, ";")?;
804 }
805 Statement::Kill => {
807 write!(self.out, "{level}")?;
808 writeln!(self.out, "discard;")?
809 }
810 Statement::Store { pointer, value } => {
811 write!(self.out, "{level}")?;
812
813 let is_atomic_pointer = func_ctx
814 .resolve_type(pointer, &module.types)
815 .is_atomic_pointer(&module.types);
816
817 if is_atomic_pointer {
818 write!(self.out, "atomicStore(")?;
819 self.write_expr(module, pointer, func_ctx)?;
820 write!(self.out, ", ")?;
821 self.write_expr(module, value, func_ctx)?;
822 write!(self.out, ")")?;
823 } else {
824 self.write_expr_with_indirection(
825 module,
826 pointer,
827 func_ctx,
828 Indirection::Reference,
829 )?;
830 write!(self.out, " = ")?;
831 self.write_expr(module, value, func_ctx)?;
832 }
833 writeln!(self.out, ";")?
834 }
835 Statement::Call {
836 function,
837 ref arguments,
838 result,
839 } => {
840 write!(self.out, "{level}")?;
841 if let Some(expr) = result {
842 let name = Baked(expr).to_string();
843 self.start_named_expr(module, expr, func_ctx, &name)?;
844 self.named_expressions.insert(expr, name);
845 }
846 let func_name = &self.names[&NameKey::Function(function)];
847 write!(self.out, "{func_name}(")?;
848 for (index, &argument) in arguments.iter().enumerate() {
849 if index != 0 {
850 write!(self.out, ", ")?;
851 }
852 self.write_expr(module, argument, func_ctx)?;
853 }
854 writeln!(self.out, ");")?
855 }
856 Statement::Atomic {
857 pointer,
858 ref fun,
859 value,
860 result,
861 } => {
862 write!(self.out, "{level}")?;
863 if let Some(result) = result {
864 let res_name = Baked(result).to_string();
865 self.start_named_expr(module, result, func_ctx, &res_name)?;
866 self.named_expressions.insert(result, res_name);
867 }
868
869 let fun_str = fun.to_wgsl();
870 write!(self.out, "atomic{fun_str}(")?;
871 self.write_expr(module, pointer, func_ctx)?;
872 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
873 write!(self.out, ", ")?;
874 self.write_expr(module, cmp, func_ctx)?;
875 }
876 write!(self.out, ", ")?;
877 self.write_expr(module, value, func_ctx)?;
878 writeln!(self.out, ");")?
879 }
880 Statement::ImageAtomic {
881 image,
882 coordinate,
883 array_index,
884 ref fun,
885 value,
886 } => {
887 write!(self.out, "{level}")?;
888 let fun_str = fun.to_wgsl();
889 write!(self.out, "textureAtomic{fun_str}(")?;
890 self.write_expr(module, image, func_ctx)?;
891 write!(self.out, ", ")?;
892 self.write_expr(module, coordinate, func_ctx)?;
893 if let Some(array_index_expr) = array_index {
894 write!(self.out, ", ")?;
895 self.write_expr(module, array_index_expr, func_ctx)?;
896 }
897 write!(self.out, ", ")?;
898 self.write_expr(module, value, func_ctx)?;
899 writeln!(self.out, ");")?;
900 }
901 Statement::WorkGroupUniformLoad { pointer, result } => {
902 write!(self.out, "{level}")?;
903 let res_name = Baked(result).to_string();
905 self.start_named_expr(module, result, func_ctx, &res_name)?;
906 self.named_expressions.insert(result, res_name);
907 write!(self.out, "workgroupUniformLoad(")?;
908 self.write_expr(module, pointer, func_ctx)?;
909 writeln!(self.out, ");")?;
910 }
911 Statement::ImageStore {
912 image,
913 coordinate,
914 array_index,
915 value,
916 } => {
917 write!(self.out, "{level}")?;
918 write!(self.out, "textureStore(")?;
919 self.write_expr(module, image, func_ctx)?;
920 write!(self.out, ", ")?;
921 self.write_expr(module, coordinate, func_ctx)?;
922 if let Some(array_index_expr) = array_index {
923 write!(self.out, ", ")?;
924 self.write_expr(module, array_index_expr, func_ctx)?;
925 }
926 write!(self.out, ", ")?;
927 self.write_expr(module, value, func_ctx)?;
928 writeln!(self.out, ");")?;
929 }
930 Statement::Block(ref block) => {
932 write!(self.out, "{level}")?;
933 writeln!(self.out, "{{")?;
934 for sta in block.iter() {
935 self.write_stmt(module, sta, func_ctx, level.next())?
937 }
938 writeln!(self.out, "{level}}}")?
939 }
940 Statement::Switch {
941 selector,
942 ref cases,
943 } => {
944 write!(self.out, "{level}")?;
946 write!(self.out, "switch ")?;
947 self.write_expr(module, selector, func_ctx)?;
948 writeln!(self.out, " {{")?;
949
950 let l2 = level.next();
951 let mut new_case = true;
952 for case in cases {
953 if case.fall_through && !case.body.is_empty() {
954 return Err(Error::Unimplemented(
956 "fall-through switch case block".into(),
957 ));
958 }
959
960 match case.value {
961 crate::SwitchValue::I32(value) => {
962 if new_case {
963 write!(self.out, "{l2}case ")?;
964 }
965 write!(self.out, "{value}")?;
966 }
967 crate::SwitchValue::U32(value) => {
968 if new_case {
969 write!(self.out, "{l2}case ")?;
970 }
971 write!(self.out, "{value}u")?;
972 }
973 crate::SwitchValue::Default => {
974 if new_case {
975 if case.fall_through {
976 write!(self.out, "{l2}case ")?;
977 } else {
978 write!(self.out, "{l2}")?;
979 }
980 }
981 write!(self.out, "default")?;
982 }
983 }
984
985 new_case = !case.fall_through;
986
987 if case.fall_through {
988 write!(self.out, ", ")?;
989 } else {
990 writeln!(self.out, ": {{")?;
991 }
992
993 for sta in case.body.iter() {
994 self.write_stmt(module, sta, func_ctx, l2.next())?;
995 }
996
997 if !case.fall_through {
998 writeln!(self.out, "{l2}}}")?;
999 }
1000 }
1001
1002 writeln!(self.out, "{level}}}")?
1003 }
1004 Statement::Loop {
1005 ref body,
1006 ref continuing,
1007 break_if,
1008 } => {
1009 write!(self.out, "{level}")?;
1010 writeln!(self.out, "loop {{")?;
1011
1012 let l2 = level.next();
1013 for sta in body.iter() {
1014 self.write_stmt(module, sta, func_ctx, l2)?;
1015 }
1016
1017 if !continuing.is_empty() || break_if.is_some() {
1022 writeln!(self.out, "{l2}continuing {{")?;
1023 for sta in continuing.iter() {
1024 self.write_stmt(module, sta, func_ctx, l2.next())?;
1025 }
1026
1027 if let Some(condition) = break_if {
1030 write!(self.out, "{}break if ", l2.next())?;
1032 self.write_expr(module, condition, func_ctx)?;
1033 writeln!(self.out, ";")?;
1035 }
1036
1037 writeln!(self.out, "{l2}}}")?;
1038 }
1039
1040 writeln!(self.out, "{level}}}")?
1041 }
1042 Statement::Break => {
1043 writeln!(self.out, "{level}break;")?;
1044 }
1045 Statement::Continue => {
1046 writeln!(self.out, "{level}continue;")?;
1047 }
1048 Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
1049 if barrier.contains(crate::Barrier::STORAGE) {
1050 writeln!(self.out, "{level}storageBarrier();")?;
1051 }
1052
1053 if barrier.contains(crate::Barrier::WORK_GROUP) {
1054 writeln!(self.out, "{level}workgroupBarrier();")?;
1055 }
1056
1057 if barrier.contains(crate::Barrier::SUB_GROUP) {
1058 writeln!(self.out, "{level}subgroupBarrier();")?;
1059 }
1060
1061 if barrier.contains(crate::Barrier::TEXTURE) {
1062 writeln!(self.out, "{level}textureBarrier();")?;
1063 }
1064 }
1065 Statement::RayQuery { .. } => unreachable!(),
1066 Statement::SubgroupBallot { result, predicate } => {
1067 write!(self.out, "{level}")?;
1068 let res_name = Baked(result).to_string();
1069 self.start_named_expr(module, result, func_ctx, &res_name)?;
1070 self.named_expressions.insert(result, res_name);
1071
1072 write!(self.out, "subgroupBallot(")?;
1073 if let Some(predicate) = predicate {
1074 self.write_expr(module, predicate, func_ctx)?;
1075 }
1076 writeln!(self.out, ");")?;
1077 }
1078 Statement::SubgroupCollectiveOperation {
1079 op,
1080 collective_op,
1081 argument,
1082 result,
1083 } => {
1084 write!(self.out, "{level}")?;
1085 let res_name = Baked(result).to_string();
1086 self.start_named_expr(module, result, func_ctx, &res_name)?;
1087 self.named_expressions.insert(result, res_name);
1088
1089 match (collective_op, op) {
1090 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
1091 write!(self.out, "subgroupAll(")?
1092 }
1093 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
1094 write!(self.out, "subgroupAny(")?
1095 }
1096 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
1097 write!(self.out, "subgroupAdd(")?
1098 }
1099 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
1100 write!(self.out, "subgroupMul(")?
1101 }
1102 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
1103 write!(self.out, "subgroupMax(")?
1104 }
1105 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
1106 write!(self.out, "subgroupMin(")?
1107 }
1108 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
1109 write!(self.out, "subgroupAnd(")?
1110 }
1111 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
1112 write!(self.out, "subgroupOr(")?
1113 }
1114 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1115 write!(self.out, "subgroupXor(")?
1116 }
1117 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1118 write!(self.out, "subgroupExclusiveAdd(")?
1119 }
1120 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1121 write!(self.out, "subgroupExclusiveMul(")?
1122 }
1123 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1124 write!(self.out, "subgroupInclusiveAdd(")?
1125 }
1126 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1127 write!(self.out, "subgroupInclusiveMul(")?
1128 }
1129 _ => unimplemented!(),
1130 }
1131 self.write_expr(module, argument, func_ctx)?;
1132 writeln!(self.out, ");")?;
1133 }
1134 Statement::SubgroupGather {
1135 mode,
1136 argument,
1137 result,
1138 } => {
1139 write!(self.out, "{level}")?;
1140 let res_name = Baked(result).to_string();
1141 self.start_named_expr(module, result, func_ctx, &res_name)?;
1142 self.named_expressions.insert(result, res_name);
1143
1144 match mode {
1145 crate::GatherMode::BroadcastFirst => {
1146 write!(self.out, "subgroupBroadcastFirst(")?;
1147 }
1148 crate::GatherMode::Broadcast(_) => {
1149 write!(self.out, "subgroupBroadcast(")?;
1150 }
1151 crate::GatherMode::Shuffle(_) => {
1152 write!(self.out, "subgroupShuffle(")?;
1153 }
1154 crate::GatherMode::ShuffleDown(_) => {
1155 write!(self.out, "subgroupShuffleDown(")?;
1156 }
1157 crate::GatherMode::ShuffleUp(_) => {
1158 write!(self.out, "subgroupShuffleUp(")?;
1159 }
1160 crate::GatherMode::ShuffleXor(_) => {
1161 write!(self.out, "subgroupShuffleXor(")?;
1162 }
1163 crate::GatherMode::QuadBroadcast(_) => {
1164 write!(self.out, "quadBroadcast(")?;
1165 }
1166 crate::GatherMode::QuadSwap(direction) => match direction {
1167 crate::Direction::X => {
1168 write!(self.out, "quadSwapX(")?;
1169 }
1170 crate::Direction::Y => {
1171 write!(self.out, "quadSwapY(")?;
1172 }
1173 crate::Direction::Diagonal => {
1174 write!(self.out, "quadSwapDiagonal(")?;
1175 }
1176 },
1177 }
1178 self.write_expr(module, argument, func_ctx)?;
1179 match mode {
1180 crate::GatherMode::BroadcastFirst => {}
1181 crate::GatherMode::Broadcast(index)
1182 | crate::GatherMode::Shuffle(index)
1183 | crate::GatherMode::ShuffleDown(index)
1184 | crate::GatherMode::ShuffleUp(index)
1185 | crate::GatherMode::ShuffleXor(index)
1186 | crate::GatherMode::QuadBroadcast(index) => {
1187 write!(self.out, ", ")?;
1188 self.write_expr(module, index, func_ctx)?;
1189 }
1190 crate::GatherMode::QuadSwap(_) => {}
1191 }
1192 writeln!(self.out, ");")?;
1193 }
1194 Statement::CooperativeStore { target, ref data } => {
1195 let suffix = if data.row_major { "T" } else { "" };
1196 write!(self.out, "{level}coopStore{suffix}(")?;
1197 self.write_expr(module, target, func_ctx)?;
1198 write!(self.out, ", ")?;
1199 self.write_expr(module, data.pointer, func_ctx)?;
1200 write!(self.out, ", ")?;
1201 self.write_expr(module, data.stride, func_ctx)?;
1202 writeln!(self.out, ");")?
1203 }
1204 Statement::RayPipelineFunction(fun) => match fun {
1205 crate::RayPipelineFunction::TraceRay {
1206 acceleration_structure,
1207 descriptor,
1208 payload,
1209 } => {
1210 write!(self.out, "{level}traceRay(")?;
1211 self.write_expr(module, acceleration_structure, func_ctx)?;
1212 write!(self.out, ", ")?;
1213 self.write_expr(module, descriptor, func_ctx)?;
1214 write!(self.out, ", ")?;
1215 self.write_expr(module, payload, func_ctx)?;
1216 writeln!(self.out, ");")?
1217 }
1218 },
1219 }
1220
1221 Ok(())
1222 }
1223
1224 fn plain_form_indirection(
1247 &self,
1248 expr: Handle<crate::Expression>,
1249 module: &Module,
1250 func_ctx: &back::FunctionCtx<'_>,
1251 ) -> Indirection {
1252 use crate::Expression as Ex;
1253
1254 if self.named_expressions.contains_key(&expr) {
1258 return Indirection::Ordinary;
1259 }
1260
1261 match func_ctx.expressions[expr] {
1262 Ex::LocalVariable(_) => Indirection::Reference,
1263 Ex::GlobalVariable(handle) => {
1264 let global = &module.global_variables[handle];
1265 match global.space {
1266 crate::AddressSpace::Handle => Indirection::Ordinary,
1267 _ => Indirection::Reference,
1268 }
1269 }
1270 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1271 let base_ty = func_ctx.resolve_type(base, &module.types);
1272 match *base_ty {
1273 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1274 Indirection::Reference
1275 }
1276 _ => Indirection::Ordinary,
1277 }
1278 }
1279 _ => Indirection::Ordinary,
1280 }
1281 }
1282
1283 fn start_named_expr(
1284 &mut self,
1285 module: &Module,
1286 handle: Handle<crate::Expression>,
1287 func_ctx: &back::FunctionCtx,
1288 name: &str,
1289 ) -> BackendResult {
1290 write!(self.out, "let {name}")?;
1292 if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1293 write!(self.out, ": ")?;
1294 self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1296 }
1297
1298 write!(self.out, " = ")?;
1299 Ok(())
1300 }
1301
1302 fn write_expr(
1306 &mut self,
1307 module: &Module,
1308 expr: Handle<crate::Expression>,
1309 func_ctx: &back::FunctionCtx<'_>,
1310 ) -> BackendResult {
1311 self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1312 }
1313
1314 fn write_expr_with_indirection(
1327 &mut self,
1328 module: &Module,
1329 expr: Handle<crate::Expression>,
1330 func_ctx: &back::FunctionCtx<'_>,
1331 requested: Indirection,
1332 ) -> BackendResult {
1333 let plain = self.plain_form_indirection(expr, module, func_ctx);
1336 log::trace!(
1337 "expression {:?}={:?} is {:?}, expected {:?}",
1338 expr,
1339 func_ctx.expressions[expr],
1340 plain,
1341 requested,
1342 );
1343 match (requested, plain) {
1344 (Indirection::Ordinary, Indirection::Reference) => {
1345 write!(self.out, "(&")?;
1346 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1347 write!(self.out, ")")?;
1348 }
1349 (Indirection::Reference, Indirection::Ordinary) => {
1350 write!(self.out, "(*")?;
1351 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1352 write!(self.out, ")")?;
1353 }
1354 (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1355 }
1356
1357 Ok(())
1358 }
1359
1360 fn write_const_expression(
1361 &mut self,
1362 module: &Module,
1363 expr: Handle<crate::Expression>,
1364 arena: &crate::Arena<crate::Expression>,
1365 ) -> BackendResult {
1366 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1367 writer.write_const_expression(module, expr, arena)
1368 })
1369 }
1370
1371 fn write_possibly_const_expression<E>(
1372 &mut self,
1373 module: &Module,
1374 expr: Handle<crate::Expression>,
1375 expressions: &crate::Arena<crate::Expression>,
1376 write_expression: E,
1377 ) -> BackendResult
1378 where
1379 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1380 {
1381 use crate::Expression;
1382
1383 match expressions[expr] {
1384 Expression::Literal(literal) => match literal {
1385 crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1386 crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1387 crate::Literal::U16(value) => write!(self.out, "u16({value})")?,
1388 crate::Literal::I16(value) => write!(self.out, "i16({value})")?,
1389 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1390 crate::Literal::I32(value) => {
1391 if value == i32::MIN {
1395 write!(self.out, "i32({value})")?;
1396 } else {
1397 write!(self.out, "{value}i")?;
1398 }
1399 }
1400 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1401 crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1402 crate::Literal::I64(value) => {
1403 if value == i64::MIN {
1408 write!(self.out, "i64({} - 1)", value + 1)?;
1409 } else {
1410 write!(self.out, "{value}li")?;
1411 }
1412 }
1413 crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1414 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1415 return Err(Error::Custom(
1416 "Abstract types should not appear in IR presented to backends".into(),
1417 ));
1418 }
1419 },
1420 Expression::Constant(handle) => {
1421 let constant = &module.constants[handle];
1422 if constant.name.is_some() {
1423 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1424 } else {
1425 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1426 }
1427 }
1428 Expression::ZeroValue(ty) => {
1429 self.write_type(module, ty)?;
1430 write!(self.out, "()")?;
1431 }
1432 Expression::Compose { ty, ref components } => {
1433 self.write_type(module, ty)?;
1434 write!(self.out, "(")?;
1435 for (index, component) in components.iter().enumerate() {
1436 if index != 0 {
1437 write!(self.out, ", ")?;
1438 }
1439 write_expression(self, *component)?;
1440 }
1441 write!(self.out, ")")?
1442 }
1443 Expression::Splat { size, value } => {
1444 let size = common::vector_size_str(size);
1445 write!(self.out, "vec{size}(")?;
1446 write_expression(self, value)?;
1447 write!(self.out, ")")?;
1448 }
1449 Expression::Override(handle) => {
1450 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1451 }
1452 _ => unreachable!(),
1453 }
1454
1455 Ok(())
1456 }
1457
/// Write the "plain form" of `expr`: its WGSL text without any outer `&`/`*`
/// adjustment.
///
/// `indirection` is the plain form that [`Self::plain_form_indirection`]
/// computed for `expr`. It is requested for the base of `Access` and
/// `AccessIndex`, so the base is written in the same form as the access.
///
/// If `expr` has an entry in `named_expressions`, only that name is written.
///
/// # Errors
/// Returns an error when indexing an unsupported type, for an unsupported
/// `As` operand type, for an unsupported math function, or for a relational
/// function other than `all`/`any`.
///
/// # Panics
/// Panics (`unreachable!`) on ray-query intersection/vertex-position
/// expressions. May also panic (`unwrap`) on a struct `AccessIndex` with no
/// type handle, or on a `CooperativeLoad` whose pointer does not resolve to
/// a pointer-to-scalar type. Both are presumably excluded by validation;
/// confirm if relying on this.
fn write_expr_plain_form(
    &mut self,
    module: &Module,
    expr: Handle<crate::Expression>,
    func_ctx: &back::FunctionCtx<'_>,
    indirection: Indirection,
) -> BackendResult {
    use crate::Expression;

    // Expressions that have been given a name are referred to by that name.
    if let Some(name) = self.named_expressions.get(&expr) {
        write!(self.out, "{name}")?;
        return Ok(());
    }

    let expression = &func_ctx.expressions[expr];

    match *expression {
        // Kinds that may also occur in constant expressions share that
        // writer. Operands are written as ordinary runtime expressions.
        Expression::Literal(_)
        | Expression::Constant(_)
        | Expression::ZeroValue(_)
        | Expression::Compose { .. }
        | Expression::Splat { .. } => {
            self.write_possibly_const_expression(
                module,
                expr,
                func_ctx.expressions,
                |writer, expr| writer.write_expr(module, expr, func_ctx),
            )?;
        }
        Expression::Override(handle) => {
            write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
        }
        Expression::FunctionArgument(pos) => {
            let name_key = func_ctx.argument_key(pos);
            let name = &self.names[&name_key];
            write!(self.out, "{name}")?;
        }
        Expression::Binary { op, left, right } => {
            // Always parenthesize, so precedence never depends on context.
            write!(self.out, "(")?;
            self.write_expr(module, left, func_ctx)?;
            write!(self.out, " {} ", back::binary_operation_str(op))?;
            self.write_expr(module, right, func_ctx)?;
            write!(self.out, ")")?;
        }
        Expression::Access { base, index } => {
            self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
            write!(self.out, "[")?;
            self.write_expr(module, index, func_ctx)?;
            write!(self.out, "]")?
        }
        Expression::AccessIndex { base, index } => {
            let base_ty_res = &func_ctx.info[base].ty;
            let mut resolved = base_ty_res.inner_with(&module.types);

            self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

            // Look through a pointer to the pointee type. Keep the pointee's
            // handle, which is needed to name struct members.
            let base_ty_handle = match *resolved {
                TypeInner::Pointer { base, space: _ } => {
                    resolved = &module.types[base].inner;
                    Some(base)
                }
                _ => base_ty_res.handle(),
            };

            match *resolved {
                // Vector components use swizzle syntax, e.g. `.x`.
                TypeInner::Vector { .. } => {
                    write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                }
                TypeInner::Matrix { .. }
                | TypeInner::Array { .. }
                | TypeInner::BindingArray { .. }
                | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                // Struct members use the member name from the namer.
                TypeInner::Struct { .. } => {
                    let ty = base_ty_handle.unwrap();

                    write!(
                        self.out,
                        ".{}",
                        &self.names[&NameKey::StructMember(ty, index)]
                    )?
                }
                ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
            }
        }
        Expression::ImageSample {
            image,
            sampler,
            gather: None,
            coordinate,
            array_index,
            offset,
            level,
            depth_ref,
            clamp_to_edge,
        } => {
            use crate::SampleLevel as Sl;

            // Build the builtin name, e.g. `textureSampleCompareLevel`.
            let suffix_cmp = match depth_ref {
                Some(_) => "Compare",
                None => "",
            };
            let suffix_level = match level {
                Sl::Auto => "",
                Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                Sl::Zero | Sl::Exact(_) => "Level",
                Sl::Bias(_) => "Bias",
                Sl::Gradient { .. } => "Grad",
            };

            write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, sampler, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;

            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }

            if let Some(depth_ref) = depth_ref {
                write!(self.out, ", ")?;
                self.write_expr(module, depth_ref, func_ctx)?;
            }

            match level {
                Sl::Auto => {}
                Sl::Zero => {
                    // The compare and base-clamp-to-edge variants take no
                    // level argument. Only plain `textureSampleLevel` needs
                    // an explicit `0.0`.
                    if depth_ref.is_none() && !clamp_to_edge {
                        write!(self.out, ", 0.0")?;
                    }
                }
                Sl::Exact(expr) => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, expr, func_ctx)?;
                }
                Sl::Bias(expr) => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, expr, func_ctx)?;
                }
                Sl::Gradient { x, y } => {
                    write!(self.out, ", ")?;
                    self.write_expr(module, x, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, y, func_ctx)?;
                }
            }

            // The offset is written via the constant-expression path.
            if let Some(offset) = offset {
                write!(self.out, ", ")?;
                self.write_const_expression(module, offset, func_ctx.expressions)?;
            }

            write!(self.out, ")")?;
        }

        Expression::ImageSample {
            image,
            sampler,
            gather: Some(component),
            coordinate,
            array_index,
            offset,
            level: _,
            depth_ref,
            clamp_to_edge: _,
        } => {
            let suffix_cmp = match depth_ref {
                Some(_) => "Compare",
                None => "",
            };

            write!(self.out, "textureGather{suffix_cmp}(")?;
            // Depth images take no component argument. For all other images,
            // the component index comes first.
            match *func_ctx.resolve_type(image, &module.types) {
                TypeInner::Image {
                    class: crate::ImageClass::Depth { multi: _ },
                    ..
                } => {}
                _ => {
                    write!(self.out, "{}, ", component as u8)?;
                }
            }
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, sampler, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;

            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }

            if let Some(depth_ref) = depth_ref {
                write!(self.out, ", ")?;
                self.write_expr(module, depth_ref, func_ctx)?;
            }

            if let Some(offset) = offset {
                write!(self.out, ", ")?;
                self.write_const_expression(module, offset, func_ctx.expressions)?;
            }

            write!(self.out, ")")?;
        }
        Expression::ImageQuery { image, query } => {
            use crate::ImageQuery as Iq;

            let texture_function = match query {
                Iq::Size { .. } => "textureDimensions",
                Iq::NumLevels => "textureNumLevels",
                Iq::NumLayers => "textureNumLayers",
                Iq::NumSamples => "textureNumSamples",
            };

            write!(self.out, "{texture_function}(")?;
            self.write_expr(module, image, func_ctx)?;
            // Only a size query may carry an explicit mip level.
            if let Iq::Size { level: Some(level) } = query {
                write!(self.out, ", ")?;
                self.write_expr(module, level, func_ctx)?;
            };
            write!(self.out, ")")?;
        }

        Expression::ImageLoad {
            image,
            coordinate,
            array_index,
            sample,
            level,
        } => {
            write!(self.out, "textureLoad(")?;
            self.write_expr(module, image, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, coordinate, func_ctx)?;
            if let Some(array_index) = array_index {
                write!(self.out, ", ")?;
                self.write_expr(module, array_index, func_ctx)?;
            }
            // A sample index and a mip level share the last argument slot.
            // The sample index takes priority.
            if let Some(index) = sample.or(level) {
                write!(self.out, ", ")?;
                self.write_expr(module, index, func_ctx)?;
            }
            write!(self.out, ")")?;
        }
        Expression::GlobalVariable(handle) => {
            let name = &self.names[&NameKey::GlobalVariable(handle)];
            write!(self.out, "{name}")?;
        }

        Expression::As {
            expr,
            kind,
            convert,
        } => {
            // With `convert`, this is a value-conversion constructor.
            // Without it, vectors and scalars are reinterpreted with `bitcast`.
            // Matrices always use the constructor form.
            let inner = func_ctx.resolve_type(expr, &module.types);
            match *inner {
                TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                } => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(scalar.width),
                    };
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    write!(
                        self.out,
                        "mat{}x{}<{}>",
                        common::vector_size_str(columns),
                        common::vector_size_str(rows),
                        scalar_kind_str
                    )?;
                }
                TypeInner::Vector {
                    size,
                    scalar: crate::Scalar { width, .. },
                } => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    };
                    let vector_size_str = common::vector_size_str(size);
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    if convert.is_some() {
                        write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                    } else {
                        write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                    }
                }
                TypeInner::Scalar(crate::Scalar { width, .. }) => {
                    let scalar = crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    };
                    let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                    if convert.is_some() {
                        write!(self.out, "{scalar_kind_str}")?
                    } else {
                        write!(self.out, "bitcast<{scalar_kind_str}>")?
                    }
                }
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "write_expr expression::as {inner:?}"
                    )));
                }
            };
            write!(self.out, "(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?;
        }
        Expression::Load { pointer } => {
            let is_atomic_pointer = func_ctx
                .resolve_type(pointer, &module.types)
                .is_atomic_pointer(&module.types);

            if is_atomic_pointer {
                write!(self.out, "atomicLoad(")?;
                self.write_expr(module, pointer, func_ctx)?;
                write!(self.out, ")")?;
            } else {
                // Request the pointer in reference form. A reference is used
                // as-is; any other pointer value is wrapped in `(*...)`.
                self.write_expr_with_indirection(
                    module,
                    pointer,
                    func_ctx,
                    Indirection::Reference,
                )?;
            }
        }
        Expression::LocalVariable(handle) => {
            write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
        }
        Expression::ArrayLength(expr) => {
            write!(self.out, "arrayLength(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?;
        }

        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            use crate::MathFunction as Mf;

            // Either a WGSL builtin name, or a polyfill emitted separately.
            enum Function {
                Regular(&'static str),
                InversePolyfill(InversePolyfill),
            }

            // Of the functions with no WGSL builtin, only `inverse` is
            // supported, and only for types `InversePolyfill` has an overload for.
            let function = match fun.try_to_wgsl() {
                Some(name) => Function::Regular(name),
                None => match fun {
                    Mf::Inverse => {
                        let ty = func_ctx.resolve_type(arg, &module.types);
                        let Some(overload) = InversePolyfill::find_overload(ty) else {
                            return Err(Error::unsupported("math function", fun));
                        };

                        Function::InversePolyfill(overload)
                    }
                    _ => return Err(Error::unsupported("math function", fun)),
                },
            };

            match function {
                Function::Regular(fun_name) => {
                    write!(self.out, "{fun_name}(")?;
                    self.write_expr(module, arg, func_ctx)?;
                    // Append whichever optional arguments are present, in order.
                    for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                        write!(self.out, ", ")?;
                        self.write_expr(module, arg, func_ctx)?;
                    }
                    write!(self.out, ")")?
                }
                Function::InversePolyfill(inverse) => {
                    write!(self.out, "{}(", inverse.fun_name)?;
                    self.write_expr(module, arg, func_ctx)?;
                    write!(self.out, ")")?;
                    // Record the polyfill so its definition can be emitted.
                    self.required_polyfills.insert(inverse);
                }
            }
        }

        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            self.write_expr(module, vector, func_ctx)?;
            write!(self.out, ".")?;
            // `size as usize` is presumably the component count of the
            // result. Only that many pattern entries are written.
            for &sc in pattern[..size as usize].iter() {
                self.out.write_char(back::COMPONENTS[sc as usize])?;
            }
        }
        Expression::Unary { op, expr } => {
            let unary = match op {
                crate::UnaryOperator::Negate => "-",
                crate::UnaryOperator::LogicalNot => "!",
                crate::UnaryOperator::BitwiseNot => "~",
            };

            write!(self.out, "{unary}(")?;
            self.write_expr(module, expr, func_ctx)?;

            write!(self.out, ")")?
        }

        Expression::Select {
            condition,
            accept,
            reject,
        } => {
            // WGSL's `select(f, t, cond)` takes the false value first.
            write!(self.out, "select(")?;
            self.write_expr(module, reject, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, accept, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, condition, func_ctx)?;
            write!(self.out, ")")?
        }
        Expression::Derivative { axis, ctrl, expr } => {
            use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
            let op = match (axis, ctrl) {
                (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                (Axis::X, Ctrl::Fine) => "dpdxFine",
                (Axis::X, Ctrl::None) => "dpdx",
                (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                (Axis::Y, Ctrl::Fine) => "dpdyFine",
                (Axis::Y, Ctrl::None) => "dpdy",
                (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                (Axis::Width, Ctrl::Fine) => "fwidthFine",
                (Axis::Width, Ctrl::None) => "fwidth",
            };
            write!(self.out, "{op}(")?;
            self.write_expr(module, expr, func_ctx)?;
            write!(self.out, ")")?
        }
        Expression::Relational { fun, argument } => {
            use crate::RelationalFunction as Rf;

            // Only `all` and `any` are written here. Other relational
            // functions are reported as unsupported.
            let fun_name = match fun {
                Rf::All => "all",
                Rf::Any => "any",
                _ => return Err(Error::UnsupportedRelationalFunction(fun)),
            };
            write!(self.out, "{fun_name}(")?;

            self.write_expr(module, argument, func_ctx)?;

            write!(self.out, ")")?
        }
        Expression::RayQueryGetIntersection { .. }
        | Expression::RayQueryVertexPositions { .. } => unreachable!(),
        // Nothing is written for these result expressions. They are presumably
        // always named by the statement that produces them, so the early
        // `named_expressions` lookup handles them. Confirm against `write_stmt`.
        Expression::CallResult(_)
        | Expression::AtomicResult { .. }
        | Expression::RayQueryProceedResult
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. }
        | Expression::WorkGroupUniformLoadResult { .. } => {}
        Expression::CooperativeLoad {
            columns,
            rows,
            role,
            ref data,
        } => {
            // A `T` suffix selects the row-major variant.
            let suffix = if data.row_major { "T" } else { "" };
            // The element scalar comes from the type that `data.pointer` points to.
            let scalar = func_ctx.info[data.pointer]
                .ty
                .inner_with(&module.types)
                .pointer_base_type()
                .unwrap()
                .inner_with(&module.types)
                .scalar()
                .unwrap();
            write!(
                self.out,
                "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                columns as u32,
                rows as u32,
                scalar.try_to_wgsl().unwrap(),
                role,
            )?;
            self.write_expr(module, data.pointer, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, data.stride, func_ctx)?;
            write!(self.out, ")")?;
        }
        Expression::CooperativeMultiplyAdd { a, b, c } => {
            write!(self.out, "coopMultiplyAdd(")?;
            self.write_expr(module, a, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, b, func_ctx)?;
            write!(self.out, ", ")?;
            self.write_expr(module, c, func_ctx)?;
            write!(self.out, ")")?;
        }
    }

    Ok(())
}
1985
1986 fn write_global(
1990 &mut self,
1991 module: &Module,
1992 global: &crate::GlobalVariable,
1993 handle: Handle<crate::GlobalVariable>,
1994 ) -> BackendResult {
1995 if let Some(ref binding) = global.binding {
1997 self.write_attributes(&[
1998 Attribute::Group(binding.group),
1999 Attribute::Binding(binding.binding),
2000 ])?;
2001 writeln!(self.out)?;
2002 }
2003
2004 if global
2005 .memory_decorations
2006 .contains(crate::MemoryDecorations::COHERENT)
2007 {
2008 write!(self.out, "@coherent ")?;
2009 }
2010 if global
2011 .memory_decorations
2012 .contains(crate::MemoryDecorations::VOLATILE)
2013 {
2014 write!(self.out, "@volatile ")?;
2015 }
2016
2017 write!(self.out, "var")?;
2019 let (address, maybe_access) = address_space_str(global.space);
2020 if let Some(space) = address {
2021 write!(self.out, "<{space}")?;
2022 if let Some(access) = maybe_access {
2023 write!(self.out, ", {access}")?;
2024 }
2025 write!(self.out, ">")?;
2026 }
2027 write!(
2028 self.out,
2029 " {}: ",
2030 &self.names[&NameKey::GlobalVariable(handle)]
2031 )?;
2032
2033 self.write_type(module, global.ty)?;
2035
2036 if let Some(init) = global.init {
2038 write!(self.out, " = ")?;
2039 self.write_const_expression(module, init, &module.global_expressions)?;
2040 }
2041
2042 writeln!(self.out, ";")?;
2044
2045 Ok(())
2046 }
2047
2048 fn write_global_constant(
2053 &mut self,
2054 module: &Module,
2055 handle: Handle<crate::Constant>,
2056 ) -> BackendResult {
2057 let name = &self.names[&NameKey::Constant(handle)];
2058 write!(self.out, "const {name}: ")?;
2060 self.write_type(module, module.constants[handle].ty)?;
2061 write!(self.out, " = ")?;
2062 let init = module.constants[handle].init;
2063 self.write_const_expression(module, init, &module.global_expressions)?;
2064 writeln!(self.out, ";")?;
2065
2066 Ok(())
2067 }
2068
2069 fn write_override(
2074 &mut self,
2075 module: &Module,
2076 handle: Handle<crate::Override>,
2077 ) -> BackendResult {
2078 let override_ = &module.overrides[handle];
2079 let name = &self.names[&NameKey::Override(handle)];
2080
2081 if let Some(id) = override_.id {
2083 write!(self.out, "@id({id}) ")?;
2084 }
2085
2086 write!(self.out, "override {name}: ")?;
2088 self.write_type(module, override_.ty)?;
2089
2090 if let Some(init) = override_.init {
2092 write!(self.out, " = ")?;
2093 self.write_const_expression(module, init, &module.global_expressions)?;
2094 }
2095
2096 writeln!(self.out, ";")?;
2097
2098 Ok(())
2099 }
2100
2101 pub fn finish(self) -> W {
2103 self.out
2104 }
2105}
2106
2107struct WriterTypeContext<'m> {
2108 module: &'m Module,
2109 names: &'m crate::FastHashMap<NameKey, String>,
2110}
2111
2112impl TypeContext for WriterTypeContext<'_> {
2113 fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
2114 &self.module.types[handle]
2115 }
2116
2117 fn type_name(&self, handle: Handle<crate::Type>) -> &str {
2118 self.names[&NameKey::Type(handle)].as_str()
2119 }
2120
2121 fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2122 unreachable!("the WGSL back end should always provide type handles");
2123 }
2124
2125 fn write_override<W: Write>(
2126 &self,
2127 handle: Handle<crate::Override>,
2128 out: &mut W,
2129 ) -> core::fmt::Result {
2130 write!(out, "{}", self.names[&NameKey::Override(handle)])
2131 }
2132
2133 fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2134 unreachable!("backends should only be passed validated modules");
2135 }
2136
2137 fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
2138 unreachable!("backends should only be passed validated modules");
2139 }
2140}
2141
2142fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2143 match *binding {
2144 crate::Binding::BuiltIn(built_in) => {
2145 if let crate::BuiltIn::Position { invariant: true } = built_in {
2146 vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2147 } else {
2148 vec![Attribute::BuiltIn(built_in)]
2149 }
2150 }
2151 crate::Binding::Location {
2152 location,
2153 interpolation,
2154 sampling,
2155 blend_src: None,
2156 per_primitive,
2157 } => {
2158 let mut attrs = vec![
2159 Attribute::Location(location),
2160 Attribute::Interpolate(interpolation, sampling),
2161 ];
2162 if per_primitive {
2163 attrs.push(Attribute::PerPrimitive);
2164 }
2165 attrs
2166 }
2167 crate::Binding::Location {
2168 location,
2169 interpolation,
2170 sampling,
2171 blend_src: Some(blend_src),
2172 per_primitive,
2173 } => {
2174 let mut attrs = vec![
2175 Attribute::Location(location),
2176 Attribute::BlendSrc(blend_src),
2177 Attribute::Interpolate(interpolation, sampling),
2178 ];
2179 if per_primitive {
2180 attrs.push(Attribute::PerPrimitive);
2181 }
2182 attrs
2183 }
2184 }
2185}