1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13 back::{self, Baked},
14 common::{
15 self,
16 wgsl::{address_space_str, ToWgsl, TryToWgsl},
17 },
18 proc::{self, NameKey},
19 valid, Handle, Module, ShaderStage, TypeInner,
20};
21
/// Result type returned by every `Writer` method that emits WGSL text.
///
/// Formatting failures from `core::fmt::Write` are converted into `Error`
/// via `?` throughout the writer.
type BackendResult = Result<(), Error>;
24
/// A WGSL attribute to be emitted by `Writer::write_attributes`.
///
/// Every variant is written as `@name(...)` followed by a single space.
enum Attribute {
    /// `@binding(N)`.
    Binding(u32),
    /// `@builtin(name)`; writing fails if the built-in has no WGSL spelling.
    BuiltIn(crate::BuiltIn),
    /// `@group(N)`.
    Group(u32),
    /// `@invariant`.
    Invariant,
    /// `@interpolate(kind)` or `@interpolate(kind, sampling)`.
    ///
    /// Nothing is written when sampling is absent or `Center` and the
    /// interpolation is absent or `Perspective`.
    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
    /// `@location(N)`.
    Location(u32),
    /// `@blend_src(N)`, used for dual-source blending outputs.
    BlendSrc(u32),
    /// Entry-point stage attribute such as `@vertex` or `@compute`.
    ///
    /// `ShaderStage::Mesh` must not be used here; mesh entry points use
    /// `MeshStage` instead.
    Stage(ShaderStage),
    /// `@workgroup_size(x, y, z)`.
    WorkGroupSize([u32; 3]),
    /// `@mesh(output_variable_name)`.
    MeshStage(String),
    /// `@payload(task_payload_variable_name)`.
    TaskPayload(String),
    /// `@per_primitive`.
    PerPrimitive,
    /// `@incoming_payload(ray_payload_variable_name)`.
    IncomingRayPayload(String),
}
41
/// Whether an expression's plain WGSL form is a reference or an ordinary value.
///
/// `Writer::plain_form_indirection` computes this for an expression.
/// `Writer::write_expr_with_indirection` compares it with what the caller
/// requested. It wraps the expression in `(&...)` when an ordinary value is
/// wanted but the plain form is a reference. It wraps it in `(*...)` in the
/// opposite case.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// The plain form is an ordinary value, which may itself be a WGSL pointer.
    /// Examples are handle-space globals, `let`-bound named expressions, and
    /// indexing into non-pointer values.
    Ordinary,

    /// The plain form denotes a WGSL reference. Examples are local variables,
    /// non-handle globals, and `Access`/`AccessIndex` through a pointer.
    Reference,
}
70
bitflags::bitflags! {
    /// Options controlling the WGSL text produced by `Writer`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct WriterFlags: u32 {
        /// Annotate every emitted `let` binding with its type (`let x: T = ...`).
        const EXPLICIT_TYPES = 0x1;
    }
}
80
/// Writes a naga `Module` as WGSL source text into `out`.
pub struct Writer<W> {
    /// Destination for the generated WGSL text.
    out: W,
    /// Output options; see `WriterFlags`.
    flags: WriterFlags,
    /// Identifier chosen for each module item, filled by `namer` in `reset`.
    names: crate::FastHashMap<NameKey, String>,
    /// Produces sanitized, unique identifiers.
    namer: proc::Namer,
    /// Expressions already bound to a `let` in the function being written.
    /// Cleared after each function.
    named_expressions: crate::NamedExpressions,
    /// Polyfills requested while writing; appended to the end of the output.
    required_polyfills: crate::FastIndexSet<InversePolyfill>,
}
89
90impl<W: Write> Writer<W> {
91 pub fn new(out: W, flags: WriterFlags) -> Self {
92 Writer {
93 out,
94 flags,
95 names: crate::FastHashMap::default(),
96 namer: proc::Namer::default(),
97 named_expressions: crate::NamedExpressions::default(),
98 required_polyfills: crate::FastIndexSet::default(),
99 }
100 }
101
102 fn reset(&mut self, module: &Module) {
103 self.names.clear();
104 self.namer.reset(
105 module,
106 &crate::keywords::wgsl::RESERVED_SET,
107 &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET,
108 proc::CaseInsensitiveKeywordSet::empty(),
110 &["__", "_naga"],
111 &mut self.names,
112 );
113 self.named_expressions.clear();
114 self.required_polyfills.clear();
115 }
116
117 fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
133 module
134 .special_types
135 .predeclared_types
136 .values()
137 .any(|t| *t == ty)
138 || Some(ty) == module.special_types.external_texture_params
139 || Some(ty) == module.special_types.external_texture_transfer_function
140 }
141
142 pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
143 self.reset(module);
144
145 self.write_enable_declarations(module)?;
147
148 for (handle, ty) in module.types.iter() {
150 if let TypeInner::Struct { ref members, .. } = ty.inner {
151 {
152 if !self.is_builtin_wgsl_struct(module, handle) {
153 self.write_struct(module, handle, members)?;
154 writeln!(self.out)?;
155 }
156 }
157 }
158 }
159
160 let mut constants = module
162 .constants
163 .iter()
164 .filter(|&(_, c)| c.name.is_some())
165 .peekable();
166 while let Some((handle, _)) = constants.next() {
167 self.write_global_constant(module, handle)?;
168 if constants.peek().is_none() {
170 writeln!(self.out)?;
171 }
172 }
173
174 let mut overrides = module.overrides.iter().peekable();
176 while let Some((handle, _)) = overrides.next() {
177 self.write_override(module, handle)?;
178 if overrides.peek().is_none() {
180 writeln!(self.out)?;
181 }
182 }
183
184 for (ty, global) in module.global_variables.iter() {
186 self.write_global(module, global, ty)?;
187 }
188
189 if !module.global_variables.is_empty() {
190 writeln!(self.out)?;
192 }
193
194 for (handle, function) in module.functions.iter() {
196 let fun_info = &info[handle];
197
198 let func_ctx = back::FunctionCtx {
199 ty: back::FunctionType::Function(handle),
200 info: fun_info,
201 expressions: &function.expressions,
202 named_expressions: &function.named_expressions,
203 };
204
205 self.write_function(module, function, &func_ctx)?;
207
208 writeln!(self.out)?;
209 }
210
211 for (index, ep) in module.entry_points.iter().enumerate() {
213 let attributes = match ep.stage {
214 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
215 ShaderStage::Compute => vec![
216 Attribute::Stage(ShaderStage::Compute),
217 Attribute::WorkGroupSize(ep.workgroup_size),
218 ],
219 ShaderStage::Mesh => {
220 let mesh_output_name = module.global_variables
221 [ep.mesh_info.as_ref().unwrap().output_variable]
222 .name
223 .clone()
224 .unwrap();
225 let mut mesh_attrs = vec![
226 Attribute::MeshStage(mesh_output_name),
227 Attribute::WorkGroupSize(ep.workgroup_size),
228 ];
229 if let Some(task_payload) = ep.task_payload {
230 let payload_name =
231 module.global_variables[task_payload].name.clone().unwrap();
232 mesh_attrs.push(Attribute::TaskPayload(payload_name));
233 }
234 mesh_attrs
235 }
236 ShaderStage::Task => {
237 let payload_name = module.global_variables[ep.task_payload.unwrap()]
238 .name
239 .clone()
240 .unwrap();
241 vec![
242 Attribute::Stage(ShaderStage::Task),
243 Attribute::TaskPayload(payload_name),
244 Attribute::WorkGroupSize(ep.workgroup_size),
245 ]
246 }
247 ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
248 ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
249 let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
250 .name
251 .clone()
252 .unwrap();
253 vec![
254 Attribute::Stage(ep.stage),
255 Attribute::IncomingRayPayload(payload_name),
256 ]
257 }
258 };
259 self.write_attributes(&attributes)?;
260 writeln!(self.out)?;
262
263 let func_ctx = back::FunctionCtx {
264 ty: back::FunctionType::EntryPoint(index as u16),
265 info: info.get_entry_point(index),
266 expressions: &ep.function.expressions,
267 named_expressions: &ep.function.named_expressions,
268 };
269 self.write_function(module, &ep.function, &func_ctx)?;
270
271 if index < module.entry_points.len() - 1 {
272 writeln!(self.out)?;
273 }
274 }
275
276 for polyfill in &self.required_polyfills {
278 writeln!(self.out)?;
279 write!(self.out, "{}", polyfill.source)?;
280 writeln!(self.out)?;
281 }
282
283 Ok(())
284 }
285
286 fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
289 #[derive(Default)]
290 struct RequiredEnabled {
291 f16: bool,
292 dual_source_blending: bool,
293 clip_distances: bool,
294 mesh_shaders: bool,
295 primitive_index: bool,
296 cooperative_matrix: bool,
297 draw_index: bool,
298 ray_tracing_pipeline: bool,
299 }
300 let mut needed = RequiredEnabled {
301 mesh_shaders: module.uses_mesh_shaders(),
302 ..Default::default()
303 };
304
305 let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding
306 {
307 crate::Binding::Location {
308 blend_src: Some(_), ..
309 } => {
310 needed.dual_source_blending = true;
311 }
312 crate::Binding::BuiltIn(crate::BuiltIn::ClipDistance) => {
313 needed.clip_distances = true;
314 }
315 crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => {
316 needed.primitive_index = true;
317 }
318 crate::Binding::Location {
319 per_primitive: true,
320 ..
321 } => {
322 needed.mesh_shaders = true;
323 }
324 crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
325 crate::Binding::BuiltIn(
326 crate::BuiltIn::RayInvocationId
327 | crate::BuiltIn::NumRayInvocations
328 | crate::BuiltIn::InstanceCustomData
329 | crate::BuiltIn::GeometryIndex
330 | crate::BuiltIn::WorldRayOrigin
331 | crate::BuiltIn::WorldRayDirection
332 | crate::BuiltIn::ObjectRayOrigin
333 | crate::BuiltIn::ObjectRayDirection
334 | crate::BuiltIn::RayTmin
335 | crate::BuiltIn::RayTCurrentMax
336 | crate::BuiltIn::ObjectToWorld
337 | crate::BuiltIn::WorldToObject,
338 ) => {
339 needed.ray_tracing_pipeline = true;
340 }
341 _ => {}
342 };
343
344 for (_, ty) in module.types.iter() {
346 match ty.inner {
347 TypeInner::Scalar(scalar)
348 | TypeInner::Vector { scalar, .. }
349 | TypeInner::Matrix { scalar, .. } => {
350 needed.f16 |= scalar == crate::Scalar::F16;
351 }
352 TypeInner::Struct { ref members, .. } => {
353 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
354 check_binding(binding, &mut needed);
355 }
356 }
357 TypeInner::CooperativeMatrix { .. } => {
358 needed.cooperative_matrix = true;
359 }
360 TypeInner::AccelerationStructure { .. } => {
361 needed.ray_tracing_pipeline = true;
362 }
363 _ => {}
364 }
365 }
366
367 for ep in &module.entry_points {
368 if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) {
369 check_binding(res, &mut needed);
370 }
371 for arg in ep
372 .function
373 .arguments
374 .iter()
375 .filter_map(|a| a.binding.as_ref())
376 {
377 check_binding(arg, &mut needed);
378 }
379 }
380
381 if module.global_variables.iter().any(|gv| {
382 gv.1.space == crate::AddressSpace::IncomingRayPayload
383 || gv.1.space == crate::AddressSpace::RayPayload
384 }) {
385 needed.ray_tracing_pipeline = true;
386 }
387
388 if module.entry_points.iter().any(|ep| {
389 matches!(
390 ep.stage,
391 ShaderStage::RayGeneration
392 | ShaderStage::AnyHit
393 | ShaderStage::ClosestHit
394 | ShaderStage::Miss
395 )
396 }) {
397 needed.ray_tracing_pipeline = true;
398 }
399
400 if module.global_variables.iter().any(|gv| {
401 gv.1.space == crate::AddressSpace::IncomingRayPayload
402 || gv.1.space == crate::AddressSpace::RayPayload
403 }) {
404 needed.ray_tracing_pipeline = true;
405 }
406
407 if module.entry_points.iter().any(|ep| {
408 matches!(
409 ep.stage,
410 ShaderStage::RayGeneration
411 | ShaderStage::AnyHit
412 | ShaderStage::ClosestHit
413 | ShaderStage::Miss
414 )
415 }) {
416 needed.ray_tracing_pipeline = true;
417 }
418
419 let mut any_written = false;
421 if needed.f16 {
422 writeln!(self.out, "enable f16;")?;
423 any_written = true;
424 }
425 if needed.dual_source_blending {
426 writeln!(self.out, "enable dual_source_blending;")?;
427 any_written = true;
428 }
429 if needed.clip_distances {
430 writeln!(self.out, "enable clip_distances;")?;
431 any_written = true;
432 }
433 if module.uses_mesh_shaders() {
434 writeln!(self.out, "enable wgpu_mesh_shader;")?;
435 any_written = true;
436 }
437 if needed.draw_index {
438 writeln!(self.out, "enable draw_index;")?;
439 any_written = true;
440 }
441 if needed.primitive_index {
442 writeln!(self.out, "enable primitive_index;")?;
443 any_written = true;
444 }
445 if needed.cooperative_matrix {
446 writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
447 any_written = true;
448 }
449 if needed.ray_tracing_pipeline {
450 writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
451 any_written = true;
452 }
453 if any_written {
454 writeln!(self.out)?;
456 }
457
458 Ok(())
459 }
460
461 fn write_function(
467 &mut self,
468 module: &Module,
469 func: &crate::Function,
470 func_ctx: &back::FunctionCtx<'_>,
471 ) -> BackendResult {
472 let func_name = match func_ctx.ty {
473 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
474 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
475 };
476
477 write!(self.out, "fn {func_name}(")?;
479
480 for (index, arg) in func.arguments.iter().enumerate() {
482 if let Some(ref binding) = arg.binding {
484 self.write_attributes(&map_binding_to_attribute(binding))?;
485 }
486 let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
488
489 write!(self.out, "{argument_name}: ")?;
490 self.write_type(module, arg.ty)?;
492 if index < func.arguments.len() - 1 {
493 write!(self.out, ", ")?;
495 }
496 }
497
498 write!(self.out, ")")?;
499
500 if let Some(ref result) = func.result {
502 write!(self.out, " -> ")?;
503 if let Some(ref binding) = result.binding {
504 self.write_attributes(&map_binding_to_attribute(binding))?;
505 }
506 self.write_type(module, result.ty)?;
507 }
508
509 write!(self.out, " {{")?;
510 writeln!(self.out)?;
511
512 for (handle, local) in func.local_variables.iter() {
514 write!(self.out, "{}", back::INDENT)?;
516
517 write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
520
521 self.write_type(module, local.ty)?;
523
524 if let Some(init) = local.init {
526 write!(self.out, " = ")?;
529
530 self.write_expr(module, init, func_ctx)?;
533 }
534
535 writeln!(self.out, ";")?
537 }
538
539 if !func.local_variables.is_empty() {
540 writeln!(self.out)?;
541 }
542
543 for sta in func.body.iter() {
545 self.write_stmt(module, sta, func_ctx, back::Level(1))?;
547 }
548
549 writeln!(self.out, "}}")?;
550
551 self.named_expressions.clear();
552
553 Ok(())
554 }
555
556 fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
558 for attribute in attributes {
559 match *attribute {
560 Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
561 Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
562 Attribute::BuiltIn(builtin_attrib) => {
563 let builtin = builtin_attrib.to_wgsl_if_implemented()?;
564 write!(self.out, "@builtin({builtin}) ")?;
565 }
566 Attribute::Stage(shader_stage) => {
567 let stage_str = match shader_stage {
568 ShaderStage::Vertex => "vertex",
569 ShaderStage::Fragment => "fragment",
570 ShaderStage::Compute => "compute",
571 ShaderStage::Task => "task",
572 ShaderStage::Mesh => unreachable!(),
574 ShaderStage::RayGeneration => "ray_generation",
575 ShaderStage::AnyHit => "any_hit",
576 ShaderStage::ClosestHit => "closest_hit",
577 ShaderStage::Miss => "miss",
578 };
579
580 write!(self.out, "@{stage_str} ")?;
581 }
582 Attribute::WorkGroupSize(size) => {
583 write!(
584 self.out,
585 "@workgroup_size({}, {}, {}) ",
586 size[0], size[1], size[2]
587 )?;
588 }
589 Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
590 Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
591 Attribute::Invariant => write!(self.out, "@invariant ")?,
592 Attribute::Interpolate(interpolation, sampling) => {
593 if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
594 let interpolation = interpolation
595 .unwrap_or(crate::Interpolation::Perspective)
596 .to_wgsl();
597 let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
598 write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
599 } else if interpolation.is_some()
600 && interpolation != Some(crate::Interpolation::Perspective)
601 {
602 let interpolation = interpolation
603 .unwrap_or(crate::Interpolation::Perspective)
604 .to_wgsl();
605 write!(self.out, "@interpolate({interpolation}) ")?;
606 }
607 }
608 Attribute::MeshStage(ref name) => {
609 write!(self.out, "@mesh({name}) ")?;
610 }
611 Attribute::TaskPayload(ref payload_name) => {
612 write!(self.out, "@payload({payload_name}) ")?;
613 }
614 Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
615 Attribute::IncomingRayPayload(ref payload_name) => {
616 write!(self.out, "@incoming_payload({payload_name}) ")?;
617 }
618 };
619 }
620 Ok(())
621 }
622
623 fn write_struct(
634 &mut self,
635 module: &Module,
636 handle: Handle<crate::Type>,
637 members: &[crate::StructMember],
638 ) -> BackendResult {
639 write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
640 write!(self.out, " {{")?;
641 writeln!(self.out)?;
642 for (index, member) in members.iter().enumerate() {
643 write!(self.out, "{}", back::INDENT)?;
645 if let Some(ref binding) = member.binding {
646 self.write_attributes(&map_binding_to_attribute(binding))?;
647 }
648 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
650 write!(self.out, "{member_name}: ")?;
651 self.write_type(module, member.ty)?;
652 write!(self.out, ",")?;
653 writeln!(self.out)?;
654 }
655
656 writeln!(self.out, "}}")?;
657
658 Ok(())
659 }
660
661 fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
662 let type_context = WriterTypeContext {
666 module,
667 names: &self.names,
668 };
669 type_context.write_type(ty, &mut self.out)?;
670
671 Ok(())
672 }
673
674 fn write_type_resolution(
675 &mut self,
676 module: &Module,
677 resolution: &proc::TypeResolution,
678 ) -> BackendResult {
679 let type_context = WriterTypeContext {
683 module,
684 names: &self.names,
685 };
686 type_context.write_type_resolution(resolution, &mut self.out)?;
687
688 Ok(())
689 }
690
    /// Write a single IR statement as WGSL, indented by `level`.
    ///
    /// For `Emit` ranges, each expression is bound to a `let` when the front
    /// end named it, when it is referenced at least `bake_ref_count()` times,
    /// or when it is an image load, query or sample. Otherwise it is left to be
    /// written inline at its use sites. Statements with a result (`Call`,
    /// `Atomic`, `WorkGroupUniformLoad`, subgroup operations) bind it to a
    /// `let` named after the baked expression. The name is recorded in
    /// `self.named_expressions`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Unimplemented` for a fall-through `switch` case with a
    /// non-empty body. Also propagates errors from writing sub-expressions or
    /// the output.
    ///
    /// # Panics
    ///
    /// Panics on `Statement::RayQuery`. Also panics on subgroup
    /// operation/collective combinations not listed in the match (the
    /// `unimplemented!()` arm).
    fn write_stmt(
        &mut self,
        module: &Module,
        stmt: &crate::Statement,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
    ) -> BackendResult {
        use crate::{Expression, Statement};

        match *stmt {
            Statement::Emit(ref range) => {
                for handle in range.clone() {
                    let info = &func_ctx.info[handle];
                    let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
                        // Pass front-end names through the namer so the emitted
                        // identifier is sanitized and unique.
                        Some(self.namer.call(name))
                    } else {
                        let expr = &func_ctx.expressions[handle];
                        let min_ref_count = expr.bake_ref_count();
                        // Image loads, queries and samples are always baked into
                        // a `let`, however often they are referenced.
                        let required_baking_expr = match *expr {
                            Expression::ImageLoad { .. }
                            | Expression::ImageQuery { .. }
                            | Expression::ImageSample { .. } => true,
                            _ => false,
                        };
                        if min_ref_count <= info.ref_count || required_baking_expr {
                            Some(Baked(handle).to_string())
                        } else {
                            None
                        }
                    };

                    // Expressions without a name are written inline where used.
                    if let Some(name) = expr_name {
                        write!(self.out, "{level}")?;
                        self.start_named_expr(module, handle, func_ctx, &name)?;
                        self.write_expr(module, handle, func_ctx)?;
                        // Register the name only after writing the initializer.
                        // Presumably this keeps the initializer from being
                        // written as a reference to its own name.
                        self.named_expressions.insert(handle, name);
                        writeln!(self.out, ";")?;
                    }
                }
            }
            Statement::If {
                condition,
                ref accept,
                ref reject,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "if ")?;
                self.write_expr(module, condition, func_ctx)?;
                writeln!(self.out, " {{")?;

                let l2 = level.next();
                for sta in accept {
                    self.write_stmt(module, sta, func_ctx, l2)?;
                }

                // The `else` branch is omitted when it would be empty.
                if !reject.is_empty() {
                    writeln!(self.out, "{level}}} else {{")?;

                    for sta in reject {
                        self.write_stmt(module, sta, func_ctx, l2)?;
                    }
                }

                writeln!(self.out, "{level}}}")?
            }
            Statement::Return { value } => {
                write!(self.out, "{level}")?;
                write!(self.out, "return")?;
                if let Some(return_value) = value {
                    write!(self.out, " ")?;
                    self.write_expr(module, return_value, func_ctx)?;
                }
                writeln!(self.out, ";")?;
            }
            Statement::Kill => {
                // WGSL spells `Kill` as `discard`.
                write!(self.out, "{level}")?;
                writeln!(self.out, "discard;")?
            }
            Statement::Store { pointer, value } => {
                write!(self.out, "{level}")?;

                let is_atomic_pointer = func_ctx
                    .resolve_type(pointer, &module.types)
                    .is_atomic_pointer(&module.types);

                if is_atomic_pointer {
                    write!(self.out, "atomicStore(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, value, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    // The left-hand side of an assignment must be a reference.
                    // A pointer value is therefore written as `(*p)`.
                    self.write_expr_with_indirection(
                        module,
                        pointer,
                        func_ctx,
                        Indirection::Reference,
                    )?;
                    write!(self.out, " = ")?;
                    self.write_expr(module, value, func_ctx)?;
                }
                writeln!(self.out, ";")?
            }
            Statement::Call {
                function,
                ref arguments,
                result,
            } => {
                write!(self.out, "{level}")?;
                if let Some(expr) = result {
                    let name = Baked(expr).to_string();
                    self.start_named_expr(module, expr, func_ctx, &name)?;
                    self.named_expressions.insert(expr, name);
                }
                let func_name = &self.names[&NameKey::Function(function)];
                write!(self.out, "{func_name}(")?;
                for (index, &argument) in arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_expr(module, argument, func_ctx)?;
                }
                writeln!(self.out, ");")?
            }
            Statement::Atomic {
                pointer,
                ref fun,
                value,
                result,
            } => {
                write!(self.out, "{level}")?;
                if let Some(result) = result {
                    let res_name = Baked(result).to_string();
                    self.start_named_expr(module, result, func_ctx, &res_name)?;
                    self.named_expressions.insert(result, res_name);
                }

                let fun_str = fun.to_wgsl();
                write!(self.out, "atomic{fun_str}(")?;
                self.write_expr(module, pointer, func_ctx)?;
                // Compare-exchange takes the comparand before the new value.
                if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                    write!(self.out, ", ")?;
                    self.write_expr(module, cmp, func_ctx)?;
                }
                write!(self.out, ", ")?;
                self.write_expr(module, value, func_ctx)?;
                writeln!(self.out, ");")?
            }
            Statement::ImageAtomic {
                image,
                coordinate,
                array_index,
                ref fun,
                value,
            } => {
                write!(self.out, "{level}")?;
                let fun_str = fun.to_wgsl();
                write!(self.out, "textureAtomic{fun_str}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index_expr) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index_expr, func_ctx)?;
                }
                write!(self.out, ", ")?;
                self.write_expr(module, value, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::WorkGroupUniformLoad { pointer, result } => {
                write!(self.out, "{level}")?;
                let res_name = Baked(result).to_string();
                self.start_named_expr(module, result, func_ctx, &res_name)?;
                self.named_expressions.insert(result, res_name);
                write!(self.out, "workgroupUniformLoad(")?;
                self.write_expr(module, pointer, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::ImageStore {
                image,
                coordinate,
                array_index,
                value,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "textureStore(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index_expr) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index_expr, func_ctx)?;
                }
                write!(self.out, ", ")?;
                self.write_expr(module, value, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::Block(ref block) => {
                // A nested block becomes a bare `{ ... }` scope.
                write!(self.out, "{level}")?;
                writeln!(self.out, "{{")?;
                for sta in block.iter() {
                    self.write_stmt(module, sta, func_ctx, level.next())?
                }
                writeln!(self.out, "{level}}}")?
            }
            Statement::Switch {
                selector,
                ref cases,
            } => {
                write!(self.out, "{level}")?;
                write!(self.out, "switch ")?;
                self.write_expr(module, selector, func_ctx)?;
                writeln!(self.out, " {{")?;

                let l2 = level.next();
                // `new_case` is false while continuing a selector list started by
                // a preceding fall-through case; its values share one `case` clause.
                let mut new_case = true;
                for case in cases {
                    // Only empty fall-through cases can be merged into a selector
                    // list; fall-through with a body has no direct WGSL form.
                    if case.fall_through && !case.body.is_empty() {
                        return Err(Error::Unimplemented(
                            "fall-through switch case block".into(),
                        ));
                    }

                    match case.value {
                        crate::SwitchValue::I32(value) => {
                            if new_case {
                                write!(self.out, "{l2}case ")?;
                            }
                            write!(self.out, "{value}")?;
                        }
                        crate::SwitchValue::U32(value) => {
                            if new_case {
                                write!(self.out, "{l2}case ")?;
                            }
                            write!(self.out, "{value}u")?;
                        }
                        crate::SwitchValue::Default => {
                            if new_case {
                                if case.fall_through {
                                    write!(self.out, "{l2}case ")?;
                                } else {
                                    write!(self.out, "{l2}")?;
                                }
                            }
                            write!(self.out, "default")?;
                        }
                    }

                    new_case = !case.fall_through;

                    if case.fall_through {
                        write!(self.out, ", ")?;
                    } else {
                        writeln!(self.out, ": {{")?;
                    }

                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, l2.next())?;
                    }

                    if !case.fall_through {
                        writeln!(self.out, "{l2}}}")?;
                    }
                }

                writeln!(self.out, "{level}}}")?
            }
            Statement::Loop {
                ref body,
                ref continuing,
                break_if,
            } => {
                write!(self.out, "{level}")?;
                writeln!(self.out, "loop {{")?;

                let l2 = level.next();
                for sta in body.iter() {
                    self.write_stmt(module, sta, func_ctx, l2)?;
                }

                // A `continuing` block is written only if it has statements
                // or must carry the `break if`.
                if !continuing.is_empty() || break_if.is_some() {
                    writeln!(self.out, "{l2}continuing {{")?;
                    for sta in continuing.iter() {
                        self.write_stmt(module, sta, func_ctx, l2.next())?;
                    }

                    // `break if` is written as the last statement of `continuing`.
                    if let Some(condition) = break_if {
                        write!(self.out, "{}break if ", l2.next())?;
                        self.write_expr(module, condition, func_ctx)?;
                        writeln!(self.out, ";")?;
                    }

                    writeln!(self.out, "{l2}}}")?;
                }

                writeln!(self.out, "{level}}}")?
            }
            Statement::Break => {
                writeln!(self.out, "{level}break;")?;
            }
            Statement::Continue => {
                writeln!(self.out, "{level}continue;")?;
            }
            Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
                // One barrier call per flag set. Control and memory barriers
                // are written the same way.
                if barrier.contains(crate::Barrier::STORAGE) {
                    writeln!(self.out, "{level}storageBarrier();")?;
                }

                if barrier.contains(crate::Barrier::WORK_GROUP) {
                    writeln!(self.out, "{level}workgroupBarrier();")?;
                }

                if barrier.contains(crate::Barrier::SUB_GROUP) {
                    writeln!(self.out, "{level}subgroupBarrier();")?;
                }

                if barrier.contains(crate::Barrier::TEXTURE) {
                    writeln!(self.out, "{level}textureBarrier();")?;
                }
            }
            Statement::RayQuery { .. } => unreachable!(),
            Statement::SubgroupBallot { result, predicate } => {
                write!(self.out, "{level}")?;
                let res_name = Baked(result).to_string();
                self.start_named_expr(module, result, func_ctx, &res_name)?;
                self.named_expressions.insert(result, res_name);

                // With no predicate, `subgroupBallot()` is written with no argument.
                write!(self.out, "subgroupBallot(")?;
                if let Some(predicate) = predicate {
                    self.write_expr(module, predicate, func_ctx)?;
                }
                writeln!(self.out, ");")?;
            }
            Statement::SubgroupCollectiveOperation {
                op,
                collective_op,
                argument,
                result,
            } => {
                write!(self.out, "{level}")?;
                let res_name = Baked(result).to_string();
                self.start_named_expr(module, result, func_ctx, &res_name)?;
                self.named_expressions.insert(result, res_name);

                // Only the combinations listed below have WGSL builtins.
                match (collective_op, op) {
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
                        write!(self.out, "subgroupAll(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
                        write!(self.out, "subgroupAny(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
                        write!(self.out, "subgroupAdd(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "subgroupMul(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
                        write!(self.out, "subgroupMax(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
                        write!(self.out, "subgroupMin(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
                        write!(self.out, "subgroupAnd(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
                        write!(self.out, "subgroupOr(")?
                    }
                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
                        write!(self.out, "subgroupXor(")?
                    }
                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
                        write!(self.out, "subgroupExclusiveAdd(")?
                    }
                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "subgroupExclusiveMul(")?
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
                        write!(self.out, "subgroupInclusiveAdd(")?
                    }
                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
                        write!(self.out, "subgroupInclusiveMul(")?
                    }
                    _ => unimplemented!(),
                }
                self.write_expr(module, argument, func_ctx)?;
                writeln!(self.out, ");")?;
            }
            Statement::SubgroupGather {
                mode,
                argument,
                result,
            } => {
                write!(self.out, "{level}")?;
                let res_name = Baked(result).to_string();
                self.start_named_expr(module, result, func_ctx, &res_name)?;
                self.named_expressions.insert(result, res_name);

                // First pass: the builtin name, chosen by gather mode.
                match mode {
                    crate::GatherMode::BroadcastFirst => {
                        write!(self.out, "subgroupBroadcastFirst(")?;
                    }
                    crate::GatherMode::Broadcast(_) => {
                        write!(self.out, "subgroupBroadcast(")?;
                    }
                    crate::GatherMode::Shuffle(_) => {
                        write!(self.out, "subgroupShuffle(")?;
                    }
                    crate::GatherMode::ShuffleDown(_) => {
                        write!(self.out, "subgroupShuffleDown(")?;
                    }
                    crate::GatherMode::ShuffleUp(_) => {
                        write!(self.out, "subgroupShuffleUp(")?;
                    }
                    crate::GatherMode::ShuffleXor(_) => {
                        write!(self.out, "subgroupShuffleXor(")?;
                    }
                    crate::GatherMode::QuadBroadcast(_) => {
                        write!(self.out, "quadBroadcast(")?;
                    }
                    crate::GatherMode::QuadSwap(direction) => match direction {
                        crate::Direction::X => {
                            write!(self.out, "quadSwapX(")?;
                        }
                        crate::Direction::Y => {
                            write!(self.out, "quadSwapY(")?;
                        }
                        crate::Direction::Diagonal => {
                            write!(self.out, "quadSwapDiagonal(")?;
                        }
                    },
                }
                self.write_expr(module, argument, func_ctx)?;
                // Second pass: modes that carry an index/lane operand append it.
                match mode {
                    crate::GatherMode::BroadcastFirst => {}
                    crate::GatherMode::Broadcast(index)
                    | crate::GatherMode::Shuffle(index)
                    | crate::GatherMode::ShuffleDown(index)
                    | crate::GatherMode::ShuffleUp(index)
                    | crate::GatherMode::ShuffleXor(index)
                    | crate::GatherMode::QuadBroadcast(index) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::QuadSwap(_) => {}
                }
                writeln!(self.out, ");")?;
            }
            Statement::CooperativeStore { target, ref data } => {
                // Row-major stores use the `coopStoreT` variant.
                let suffix = if data.row_major { "T" } else { "" };
                write!(self.out, "{level}coopStore{suffix}(")?;
                self.write_expr(module, target, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                writeln!(self.out, ");")?
            }
            Statement::RayPipelineFunction(fun) => match fun {
                crate::RayPipelineFunction::TraceRay {
                    acceleration_structure,
                    descriptor,
                    payload,
                } => {
                    write!(self.out, "{level}traceRay(")?;
                    self.write_expr(module, acceleration_structure, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, descriptor, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_expr(module, payload, func_ctx)?;
                    writeln!(self.out, ");")?
                }
            },
        }

        Ok(())
    }
1198
1199 fn plain_form_indirection(
1222 &self,
1223 expr: Handle<crate::Expression>,
1224 module: &Module,
1225 func_ctx: &back::FunctionCtx<'_>,
1226 ) -> Indirection {
1227 use crate::Expression as Ex;
1228
1229 if self.named_expressions.contains_key(&expr) {
1233 return Indirection::Ordinary;
1234 }
1235
1236 match func_ctx.expressions[expr] {
1237 Ex::LocalVariable(_) => Indirection::Reference,
1238 Ex::GlobalVariable(handle) => {
1239 let global = &module.global_variables[handle];
1240 match global.space {
1241 crate::AddressSpace::Handle => Indirection::Ordinary,
1242 _ => Indirection::Reference,
1243 }
1244 }
1245 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1246 let base_ty = func_ctx.resolve_type(base, &module.types);
1247 match *base_ty {
1248 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1249 Indirection::Reference
1250 }
1251 _ => Indirection::Ordinary,
1252 }
1253 }
1254 _ => Indirection::Ordinary,
1255 }
1256 }
1257
1258 fn start_named_expr(
1259 &mut self,
1260 module: &Module,
1261 handle: Handle<crate::Expression>,
1262 func_ctx: &back::FunctionCtx,
1263 name: &str,
1264 ) -> BackendResult {
1265 write!(self.out, "let {name}")?;
1267 if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1268 write!(self.out, ": ")?;
1269 self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1271 }
1272
1273 write!(self.out, " = ")?;
1274 Ok(())
1275 }
1276
1277 fn write_expr(
1281 &mut self,
1282 module: &Module,
1283 expr: Handle<crate::Expression>,
1284 func_ctx: &back::FunctionCtx<'_>,
1285 ) -> BackendResult {
1286 self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1287 }
1288
1289 fn write_expr_with_indirection(
1302 &mut self,
1303 module: &Module,
1304 expr: Handle<crate::Expression>,
1305 func_ctx: &back::FunctionCtx<'_>,
1306 requested: Indirection,
1307 ) -> BackendResult {
1308 let plain = self.plain_form_indirection(expr, module, func_ctx);
1311 log::trace!(
1312 "expression {:?}={:?} is {:?}, expected {:?}",
1313 expr,
1314 func_ctx.expressions[expr],
1315 plain,
1316 requested,
1317 );
1318 match (requested, plain) {
1319 (Indirection::Ordinary, Indirection::Reference) => {
1320 write!(self.out, "(&")?;
1321 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1322 write!(self.out, ")")?;
1323 }
1324 (Indirection::Reference, Indirection::Ordinary) => {
1325 write!(self.out, "(*")?;
1326 self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1327 write!(self.out, ")")?;
1328 }
1329 (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1330 }
1331
1332 Ok(())
1333 }
1334
1335 fn write_const_expression(
1336 &mut self,
1337 module: &Module,
1338 expr: Handle<crate::Expression>,
1339 arena: &crate::Arena<crate::Expression>,
1340 ) -> BackendResult {
1341 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1342 writer.write_const_expression(module, expr, arena)
1343 })
1344 }
1345
1346 fn write_possibly_const_expression<E>(
1347 &mut self,
1348 module: &Module,
1349 expr: Handle<crate::Expression>,
1350 expressions: &crate::Arena<crate::Expression>,
1351 write_expression: E,
1352 ) -> BackendResult
1353 where
1354 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1355 {
1356 use crate::Expression;
1357
1358 match expressions[expr] {
1359 Expression::Literal(literal) => match literal {
1360 crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1361 crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1362 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1363 crate::Literal::I32(value) => {
1364 if value == i32::MIN {
1368 write!(self.out, "i32({value})")?;
1369 } else {
1370 write!(self.out, "{value}i")?;
1371 }
1372 }
1373 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1374 crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1375 crate::Literal::I64(value) => {
1376 if value == i64::MIN {
1381 write!(self.out, "i64({} - 1)", value + 1)?;
1382 } else {
1383 write!(self.out, "{value}li")?;
1384 }
1385 }
1386 crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1387 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1388 return Err(Error::Custom(
1389 "Abstract types should not appear in IR presented to backends".into(),
1390 ));
1391 }
1392 },
1393 Expression::Constant(handle) => {
1394 let constant = &module.constants[handle];
1395 if constant.name.is_some() {
1396 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1397 } else {
1398 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1399 }
1400 }
1401 Expression::ZeroValue(ty) => {
1402 self.write_type(module, ty)?;
1403 write!(self.out, "()")?;
1404 }
1405 Expression::Compose { ty, ref components } => {
1406 self.write_type(module, ty)?;
1407 write!(self.out, "(")?;
1408 for (index, component) in components.iter().enumerate() {
1409 if index != 0 {
1410 write!(self.out, ", ")?;
1411 }
1412 write_expression(self, *component)?;
1413 }
1414 write!(self.out, ")")?
1415 }
1416 Expression::Splat { size, value } => {
1417 let size = common::vector_size_str(size);
1418 write!(self.out, "vec{size}(")?;
1419 write_expression(self, value)?;
1420 write!(self.out, ")")?;
1421 }
1422 Expression::Override(handle) => {
1423 write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1424 }
1425 _ => unreachable!(),
1426 }
1427
1428 Ok(())
1429 }
1430
    /// Write the "plain form" of `expr`: its WGSL text without any `&` or `*`
    /// added to adjust its indirection.
    ///
    /// `indirection` is the plain-form indirection of `expr`, as reported by
    /// [`Writer::plain_form_indirection`]. It is forwarded to the bases of
    /// `Access` and `AccessIndex`. As a result, an access chain through a
    /// reference is emitted as one reference expression (e.g. `a[i].b`)
    /// rather than with intermediate `&`/`*` operators.
    ///
    /// A named expression is written as its `let` name. Statement-result
    /// expressions (`CallResult`, `AtomicResult`, ...) write nothing here.
    /// NOTE(review): presumably they are always bound to a name by the
    /// statement that produces them; confirm against the statement writer.
    ///
    /// # Errors
    ///
    /// Returns an error for constructs this writer cannot express in WGSL:
    /// indexing a non-indexable type, `As` on a non-numeric type, math or
    /// relational functions without a WGSL counterpart, and abstract literals.
    fn write_expr_plain_form(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
        indirection: Indirection,
    ) -> BackendResult {
        use crate::Expression;

        // A named expression was already bound with `let`; refer to it by name
        // instead of re-emitting its definition.
        if let Some(name) = self.named_expressions.get(&expr) {
            write!(self.out, "{name}")?;
            return Ok(());
        }

        let expression = &func_ctx.expressions[expr];

        match *expression {
            // These forms are shared with constant expressions. Their operands
            // are written as ordinary function expressions via `write_expr`.
            Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)
            | Expression::Compose { .. }
            | Expression::Splat { .. } => {
                self.write_possibly_const_expression(
                    module,
                    expr,
                    func_ctx.expressions,
                    |writer, expr| writer.write_expr(module, expr, func_ctx),
                )?;
            }
            Expression::Override(handle) => {
                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
            }
            Expression::FunctionArgument(pos) => {
                let name_key = func_ctx.argument_key(pos);
                let name = &self.names[&name_key];
                write!(self.out, "{name}")?;
            }
            Expression::Binary { op, left, right } => {
                // Fully parenthesized, so operand precedence never matters.
                write!(self.out, "(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, " {} ", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Access { base, index } => {
                // The base takes the same indirection as this access, so
                // `ref[i]` stays a reference and `value[i]` stays a value.
                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
                write!(self.out, "[")?;
                self.write_expr(module, index, func_ctx)?;
                write!(self.out, "]")?
            }
            Expression::AccessIndex { base, index } => {
                let base_ty_res = &func_ctx.info[base].ty;
                let mut resolved = base_ty_res.inner_with(&module.types);

                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

                // Look through a pointer to the type actually being indexed.
                // Keep its handle, which is needed to look up struct member names.
                let base_ty_handle = match *resolved {
                    TypeInner::Pointer { base, space: _ } => {
                        resolved = &module.types[base].inner;
                        Some(base)
                    }
                    _ => base_ty_res.handle(),
                };

                match *resolved {
                    TypeInner::Vector { .. } => {
                        // Vector components use `.x`, `.y`, `.z`, `.w`.
                        write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                    }
                    TypeInner::Matrix { .. }
                    | TypeInner::Array { .. }
                    | TypeInner::BindingArray { .. }
                    | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                    TypeInner::Struct { .. } => {
                        // NOTE(review): assumes struct types always resolve to a
                        // type handle, so this `unwrap` cannot fail — confirm.
                        let ty = base_ty_handle.unwrap();

                        write!(
                            self.out,
                            ".{}",
                            &self.names[&NameKey::StructMember(ty, index)]
                        )?
                    }
                    ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
                }
            }
            Expression::ImageSample {
                image,
                sampler,
                gather: None,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => {
                use crate::SampleLevel as Sl;

                // The builtin's name is `textureSample` plus an optional
                // `Compare` plus a suffix chosen by the level of detail.
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };
                let suffix_level = match level {
                    Sl::Auto => "",
                    Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                    Sl::Zero | Sl::Exact(_) => "Level",
                    Sl::Bias(_) => "Bias",
                    Sl::Gradient { .. } => "Grad",
                };

                write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                match level {
                    Sl::Auto => {}
                    Sl::Zero => {
                        // `textureSampleCompareLevel` and
                        // `textureSampleBaseClampToEdge` take no level
                        // argument; only `textureSampleLevel` needs an
                        // explicit `0.0`.
                        if depth_ref.is_none() && !clamp_to_edge {
                            write!(self.out, ", 0.0")?;
                        }
                    }
                    Sl::Exact(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Bias(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Gradient { x, y } => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, x, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, y, func_ctx)?;
                    }
                }

                // The offset is written as a constant expression, because WGSL
                // requires texel offsets to be const-expressions.
                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }

            Expression::ImageSample {
                image,
                sampler,
                gather: Some(component),
                coordinate,
                array_index,
                offset,
                level: _,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };

                write!(self.out, "textureGather{suffix_cmp}(")?;
                // Gathering from a depth texture takes no component argument;
                // every other image class gets the component index first.
                match *func_ctx.resolve_type(image, &module.types) {
                    TypeInner::Image {
                        class: crate::ImageClass::Depth { multi: _ },
                        ..
                    } => {}
                    _ => {
                        write!(self.out, "{}, ", component as u8)?;
                    }
                }
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }
            Expression::ImageQuery { image, query } => {
                use crate::ImageQuery as Iq;

                let texture_function = match query {
                    Iq::Size { .. } => "textureDimensions",
                    Iq::NumLevels => "textureNumLevels",
                    Iq::NumLayers => "textureNumLayers",
                    Iq::NumSamples => "textureNumSamples",
                };

                write!(self.out, "{texture_function}(")?;
                self.write_expr(module, image, func_ctx)?;
                // Only a size query may take an explicit mip level.
                if let Iq::Size { level: Some(level) } = query {
                    write!(self.out, ", ")?;
                    self.write_expr(module, level, func_ctx)?;
                };
                write!(self.out, ")")?;
            }

            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                write!(self.out, "textureLoad(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }
                // The final argument is the sample index if present, else the
                // mip level. NOTE(review): assumes at most one is ever set.
                if let Some(index) = sample.or(level) {
                    write!(self.out, ", ")?;
                    self.write_expr(module, index, func_ctx)?;
                }
                write!(self.out, ")")?;
            }
            Expression::GlobalVariable(handle) => {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                write!(self.out, "{name}")?;
            }

            Expression::As {
                expr,
                kind,
                convert,
            } => {
                // With `convert`, emit a value-converting constructor such as
                // `f32(x)` or `vec3<f32>(x)`. Without it, vectors and scalars
                // are reinterpreted with `bitcast<T>(x)`. Matrices always use
                // the constructor form.
                let inner = func_ctx.resolve_type(expr, &module.types);
                match *inner {
                    TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(scalar.width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        write!(
                            self.out,
                            "mat{}x{}<{}>",
                            common::vector_size_str(columns),
                            common::vector_size_str(rows),
                            scalar_kind_str
                        )?;
                    }
                    TypeInner::Vector {
                        size,
                        scalar: crate::Scalar { width, .. },
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let vector_size_str = common::vector_size_str(size);
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                        } else {
                            write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                        }
                    }
                    TypeInner::Scalar(crate::Scalar { width, .. }) => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "{scalar_kind_str}")?
                        } else {
                            write!(self.out, "bitcast<{scalar_kind_str}>")?
                        }
                    }
                    _ => {
                        return Err(Error::Unimplemented(format!(
                            "write_expr expression::as {inner:?}"
                        )));
                    }
                };
                write!(self.out, "(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::Load { pointer } => {
                let is_atomic_pointer = func_ctx
                    .resolve_type(pointer, &module.types)
                    .is_atomic_pointer(&module.types);

                if is_atomic_pointer {
                    // Atomics must be read with `atomicLoad(&x)`; the pointer
                    // operand is written as an ordinary (pointer) value.
                    write!(self.out, "atomicLoad(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    // Requesting a reference either names the variable
                    // directly or dereferences a pointer value with `(*p)`.
                    // Either way, WGSL's load rule then yields the stored value.
                    self.write_expr_with_indirection(
                        module,
                        pointer,
                        func_ctx,
                        Indirection::Reference,
                    )?;
                }
            }
            Expression::LocalVariable(handle) => {
                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
            }
            Expression::ArrayLength(expr) => {
                write!(self.out, "arrayLength(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }

            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                // A math function is emitted either as a WGSL builtin or as a
                // call to a generated polyfill helper.
                enum Function {
                    Regular(&'static str),
                    InversePolyfill(InversePolyfill),
                }

                let function = match fun.try_to_wgsl() {
                    Some(name) => Function::Regular(name),
                    None => match fun {
                        // WGSL has no matrix `inverse` builtin; use a polyfill
                        // overload matching the argument type, if one exists.
                        Mf::Inverse => {
                            let ty = func_ctx.resolve_type(arg, &module.types);
                            let Some(overload) = InversePolyfill::find_overload(ty) else {
                                return Err(Error::unsupported("math function", fun));
                            };

                            Function::InversePolyfill(overload)
                        }
                        _ => return Err(Error::unsupported("math function", fun)),
                    },
                };

                match function {
                    Function::Regular(fun_name) => {
                        write!(self.out, "{fun_name}(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                            write!(self.out, ", ")?;
                            self.write_expr(module, arg, func_ctx)?;
                        }
                        write!(self.out, ")")?
                    }
                    Function::InversePolyfill(inverse) => {
                        write!(self.out, "{}(", inverse.fun_name)?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, ")")?;
                        // Record the helper so its definition is emitted too.
                        self.required_polyfills.insert(inverse);
                    }
                }
            }

            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                // Emit e.g. `v.zyx` using the first `size` entries of `pattern`.
                self.write_expr(module, vector, func_ctx)?;
                write!(self.out, ".")?;
                for &sc in pattern[..size as usize].iter() {
                    self.out.write_char(back::COMPONENTS[sc as usize])?;
                }
            }
            Expression::Unary { op, expr } => {
                let unary = match op {
                    crate::UnaryOperator::Negate => "-",
                    crate::UnaryOperator::LogicalNot => "!",
                    crate::UnaryOperator::BitwiseNot => "~",
                };

                // Parenthesize the operand so e.g. `-(-1i)` never becomes `--1i`.
                write!(self.out, "{unary}(")?;
                self.write_expr(module, expr, func_ctx)?;

                write!(self.out, ")")?
            }

            Expression::Select {
                condition,
                accept,
                reject,
            } => {
                // WGSL's `select(f, t, cond)` takes the false value first.
                write!(self.out, "select(")?;
                self.write_expr(module, reject, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, accept, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, condition, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Derivative { axis, ctrl, expr } => {
                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
                let op = match (axis, ctrl) {
                    (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                    (Axis::X, Ctrl::Fine) => "dpdxFine",
                    (Axis::X, Ctrl::None) => "dpdx",
                    (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                    (Axis::Y, Ctrl::Fine) => "dpdyFine",
                    (Axis::Y, Ctrl::None) => "dpdy",
                    (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                    (Axis::Width, Ctrl::Fine) => "fwidthFine",
                    (Axis::Width, Ctrl::None) => "fwidth",
                };
                write!(self.out, "{op}(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Relational { fun, argument } => {
                use crate::RelationalFunction as Rf;

                // Only `all` and `any` are handled here; any other relational
                // function is reported as unsupported.
                let fun_name = match fun {
                    Rf::All => "all",
                    Rf::Any => "any",
                    _ => return Err(Error::UnsupportedRelationalFunction(fun)),
                };
                write!(self.out, "{fun_name}(")?;

                self.write_expr(module, argument, func_ctx)?;

                write!(self.out, ")")?
            }
            // NOTE(review): treated as impossible here; presumably these are
            // always bound to names by the statement writer — confirm.
            Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => unreachable!(),
            // Results of statements: nothing to emit. See the doc comment above.
            Expression::CallResult(_)
            | Expression::AtomicResult { .. }
            | Expression::RayQueryProceedResult
            | Expression::SubgroupBallotResult
            | Expression::SubgroupOperationResult { .. }
            | Expression::WorkGroupUniformLoadResult { .. } => {}
            Expression::CooperativeLoad {
                columns,
                rows,
                role,
                ref data,
            } => {
                // Row-major loads use the `coopLoadT` variant.
                let suffix = if data.row_major { "T" } else { "" };
                // The matrix element type is the scalar type that `data.pointer`
                // points to.
                let scalar = func_ctx.info[data.pointer]
                    .ty
                    .inner_with(&module.types)
                    .pointer_base_type()
                    .unwrap()
                    .inner_with(&module.types)
                    .scalar()
                    .unwrap();
                write!(
                    self.out,
                    "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                    columns as u32,
                    rows as u32,
                    scalar.try_to_wgsl().unwrap(),
                    role,
                )?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::CooperativeMultiplyAdd { a, b, c } => {
                write!(self.out, "coopMultiplyAdd(")?;
                self.write_expr(module, a, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, b, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, c, func_ctx)?;
                write!(self.out, ")")?;
            }
        }

        Ok(())
    }
1958
1959 fn write_global(
1963 &mut self,
1964 module: &Module,
1965 global: &crate::GlobalVariable,
1966 handle: Handle<crate::GlobalVariable>,
1967 ) -> BackendResult {
1968 if let Some(ref binding) = global.binding {
1970 self.write_attributes(&[
1971 Attribute::Group(binding.group),
1972 Attribute::Binding(binding.binding),
1973 ])?;
1974 writeln!(self.out)?;
1975 }
1976
1977 write!(self.out, "var")?;
1979 let (address, maybe_access) = address_space_str(global.space);
1980 if let Some(space) = address {
1981 write!(self.out, "<{space}")?;
1982 if let Some(access) = maybe_access {
1983 write!(self.out, ", {access}")?;
1984 }
1985 write!(self.out, ">")?;
1986 }
1987 write!(
1988 self.out,
1989 " {}: ",
1990 &self.names[&NameKey::GlobalVariable(handle)]
1991 )?;
1992
1993 self.write_type(module, global.ty)?;
1995
1996 if let Some(init) = global.init {
1998 write!(self.out, " = ")?;
1999 self.write_const_expression(module, init, &module.global_expressions)?;
2000 }
2001
2002 writeln!(self.out, ";")?;
2004
2005 Ok(())
2006 }
2007
2008 fn write_global_constant(
2013 &mut self,
2014 module: &Module,
2015 handle: Handle<crate::Constant>,
2016 ) -> BackendResult {
2017 let name = &self.names[&NameKey::Constant(handle)];
2018 write!(self.out, "const {name}: ")?;
2020 self.write_type(module, module.constants[handle].ty)?;
2021 write!(self.out, " = ")?;
2022 let init = module.constants[handle].init;
2023 self.write_const_expression(module, init, &module.global_expressions)?;
2024 writeln!(self.out, ";")?;
2025
2026 Ok(())
2027 }
2028
2029 fn write_override(
2034 &mut self,
2035 module: &Module,
2036 handle: Handle<crate::Override>,
2037 ) -> BackendResult {
2038 let override_ = &module.overrides[handle];
2039 let name = &self.names[&NameKey::Override(handle)];
2040
2041 if let Some(id) = override_.id {
2043 write!(self.out, "@id({id}) ")?;
2044 }
2045
2046 write!(self.out, "override {name}: ")?;
2048 self.write_type(module, override_.ty)?;
2049
2050 if let Some(init) = override_.init {
2052 write!(self.out, " = ")?;
2053 self.write_const_expression(module, init, &module.global_expressions)?;
2054 }
2055
2056 writeln!(self.out, ";")?;
2057
2058 Ok(())
2059 }
2060
2061 pub fn finish(self) -> W {
2063 self.out
2064 }
2065}
2066
2067struct WriterTypeContext<'m> {
2068 module: &'m Module,
2069 names: &'m crate::FastHashMap<NameKey, String>,
2070}
2071
2072impl TypeContext for WriterTypeContext<'_> {
2073 fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
2074 &self.module.types[handle]
2075 }
2076
2077 fn type_name(&self, handle: Handle<crate::Type>) -> &str {
2078 self.names[&NameKey::Type(handle)].as_str()
2079 }
2080
2081 fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2082 unreachable!("the WGSL back end should always provide type handles");
2083 }
2084
2085 fn write_override<W: Write>(
2086 &self,
2087 handle: Handle<crate::Override>,
2088 out: &mut W,
2089 ) -> core::fmt::Result {
2090 write!(out, "{}", self.names[&NameKey::Override(handle)])
2091 }
2092
2093 fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2094 unreachable!("backends should only be passed validated modules");
2095 }
2096
2097 fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
2098 unreachable!("backends should only be passed validated modules");
2099 }
2100}
2101
2102fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2103 match *binding {
2104 crate::Binding::BuiltIn(built_in) => {
2105 if let crate::BuiltIn::Position { invariant: true } = built_in {
2106 vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2107 } else {
2108 vec![Attribute::BuiltIn(built_in)]
2109 }
2110 }
2111 crate::Binding::Location {
2112 location,
2113 interpolation,
2114 sampling,
2115 blend_src: None,
2116 per_primitive,
2117 } => {
2118 let mut attrs = vec![
2119 Attribute::Location(location),
2120 Attribute::Interpolate(interpolation, sampling),
2121 ];
2122 if per_primitive {
2123 attrs.push(Attribute::PerPrimitive);
2124 }
2125 attrs
2126 }
2127 crate::Binding::Location {
2128 location,
2129 interpolation,
2130 sampling,
2131 blend_src: Some(blend_src),
2132 per_primitive,
2133 } => {
2134 let mut attrs = vec![
2135 Attribute::Location(location),
2136 Attribute::BlendSrc(blend_src),
2137 Attribute::Interpolate(interpolation, sampling),
2138 ];
2139 if per_primitive {
2140 attrs.push(Attribute::PerPrimitive);
2141 }
2142 attrs
2143 }
2144 }
2145}