naga/back/wgsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13    back::{self, Baked},
14    common::{
15        self,
16        wgsl::{address_space_str, ToWgsl, TryToWgsl},
17    },
18    proc::{self, NameKey},
19    valid, Handle, Module, ShaderStage, TypeInner,
20};
21
22/// Shorthand result used internally by the backend
23type BackendResult = Result<(), Error>;
24
25/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
26enum Attribute {
27    Binding(u32),
28    BuiltIn(crate::BuiltIn),
29    Group(u32),
30    Invariant,
31    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
32    Location(u32),
33    BlendSrc(u32),
34    Stage(ShaderStage),
35    WorkGroupSize([u32; 3]),
36    MeshStage(String),
37    TaskPayload(String),
38    PerPrimitive,
39}
40
/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
/// expression.
///
/// Sometimes a Naga `Expression` alone doesn't provide enough information to
/// choose the right rendering for it in WGSL. For example, one natural WGSL
/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
/// `LocalVariable` produces a pointer to the local variable's storage. But when
/// rendering a `Store` statement, the `pointer` operand must be the left hand
/// side of a WGSL assignment, so the proper rendering is `x`.
///
/// The caller of `write_expr_with_indirection` must provide an `Indirection`
/// value to indicate how ambiguous expressions should be rendered.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
    ///
    /// This is the right choice for most cases. Whenever a Naga pointer
    /// expression is not the `pointer` operand of a `Load` or `Store`, it
    /// must be a WGSL pointer expression.
    Ordinary,

    /// Render pointer-construction expressions as WGSL reference-typed
    /// expressions.
    ///
    /// For example, this is the right choice for the `pointer` operand when
    /// rendering a `Store` statement as a WGSL assignment.
    Reference,
}
69
70bitflags::bitflags! {
71    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
72    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
73    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
74    pub struct WriterFlags: u32 {
75        /// Always annotate the type information instead of inferring.
76        const EXPLICIT_TYPES = 0x1;
77    }
78}
79
/// Writes a Naga [`Module`] as WGSL source text into `out`.
pub struct Writer<W> {
    /// Destination for the generated WGSL text.
    out: W,
    /// Output options; see [`WriterFlags`].
    flags: WriterFlags,
    /// Identifiers chosen for module items, keyed by [`NameKey`]; rebuilt by
    /// `reset` through `namer`.
    names: crate::FastHashMap<NameKey, String>,
    /// Name generator, seeded in `reset` with WGSL's reserved words.
    namer: proc::Namer,
    /// Expressions already bound to a name in the current function's output,
    /// mapped to that name; cleared after each function is written.
    named_expressions: crate::NamedExpressions,
    /// Polyfills the generated code depends on; appended at the end of the
    /// output by `write`.
    required_polyfills: crate::FastIndexSet<InversePolyfill>,
}
88
89impl<W: Write> Writer<W> {
90    pub fn new(out: W, flags: WriterFlags) -> Self {
91        Writer {
92            out,
93            flags,
94            names: crate::FastHashMap::default(),
95            namer: proc::Namer::default(),
96            named_expressions: crate::NamedExpressions::default(),
97            required_polyfills: crate::FastIndexSet::default(),
98        }
99    }
100
101    fn reset(&mut self, module: &Module) {
102        self.names.clear();
103        self.namer.reset(
104            module,
105            &crate::keywords::wgsl::RESERVED_SET,
106            // an identifier must not start with two underscore
107            proc::CaseInsensitiveKeywordSet::empty(),
108            &["__", "_naga"],
109            &mut self.names,
110        );
111        self.named_expressions.clear();
112        self.required_polyfills.clear();
113    }
114
115    /// Determine if `ty` is the Naga IR presentation of a WGSL builtin type.
116    ///
117    /// Return true if `ty` refers to the Naga IR form of a WGSL builtin type
118    /// like `__atomic_compare_exchange_result`.
119    ///
120    /// Even though the module may use the type, the WGSL backend should avoid
121    /// emitting a definition for it, since it is [predeclared] in WGSL.
122    ///
123    /// This also covers types like [`NagaExternalTextureParams`], which other
124    /// backends use to lower WGSL constructs like external textures to their
125    /// implementations. WGSL can express these directly, so the types need not
126    /// be emitted.
127    ///
128    /// [predeclared]: https://www.w3.org/TR/WGSL/#predeclared
129    /// [`NagaExternalTextureParams`]: crate::ir::SpecialTypes::external_texture_params
130    fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
131        module
132            .special_types
133            .predeclared_types
134            .values()
135            .any(|t| *t == ty)
136            || Some(ty) == module.special_types.external_texture_params
137            || Some(ty) == module.special_types.external_texture_transfer_function
138    }
139
140    pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
141        self.reset(module);
142
143        // Write all `enable` declarations
144        self.write_enable_declarations(module)?;
145
146        // Write all structs
147        for (handle, ty) in module.types.iter() {
148            if let TypeInner::Struct { ref members, .. } = ty.inner {
149                {
150                    if !self.is_builtin_wgsl_struct(module, handle) {
151                        self.write_struct(module, handle, members)?;
152                        writeln!(self.out)?;
153                    }
154                }
155            }
156        }
157
158        // Write all named constants
159        let mut constants = module
160            .constants
161            .iter()
162            .filter(|&(_, c)| c.name.is_some())
163            .peekable();
164        while let Some((handle, _)) = constants.next() {
165            self.write_global_constant(module, handle)?;
166            // Add extra newline for readability on last iteration
167            if constants.peek().is_none() {
168                writeln!(self.out)?;
169            }
170        }
171
172        // Write all overrides
173        let mut overrides = module.overrides.iter().peekable();
174        while let Some((handle, _)) = overrides.next() {
175            self.write_override(module, handle)?;
176            // Add extra newline for readability on last iteration
177            if overrides.peek().is_none() {
178                writeln!(self.out)?;
179            }
180        }
181
182        // Write all globals
183        for (ty, global) in module.global_variables.iter() {
184            self.write_global(module, global, ty)?;
185        }
186
187        if !module.global_variables.is_empty() {
188            // Add extra newline for readability
189            writeln!(self.out)?;
190        }
191
192        // Write all regular functions
193        for (handle, function) in module.functions.iter() {
194            let fun_info = &info[handle];
195
196            let func_ctx = back::FunctionCtx {
197                ty: back::FunctionType::Function(handle),
198                info: fun_info,
199                expressions: &function.expressions,
200                named_expressions: &function.named_expressions,
201            };
202
203            // Write the function
204            self.write_function(module, function, &func_ctx)?;
205
206            writeln!(self.out)?;
207        }
208
209        // Write all entry points
210        for (index, ep) in module.entry_points.iter().enumerate() {
211            let attributes = match ep.stage {
212                ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
213                ShaderStage::Compute => vec![
214                    Attribute::Stage(ShaderStage::Compute),
215                    Attribute::WorkGroupSize(ep.workgroup_size),
216                ],
217                ShaderStage::Mesh => {
218                    let mesh_output_name = module.global_variables
219                        [ep.mesh_info.as_ref().unwrap().output_variable]
220                        .name
221                        .clone()
222                        .unwrap();
223                    let mut mesh_attrs = vec![
224                        Attribute::MeshStage(mesh_output_name),
225                        Attribute::WorkGroupSize(ep.workgroup_size),
226                    ];
227                    if ep.task_payload.is_some() {
228                        let payload_name = module.global_variables[ep.task_payload.unwrap()]
229                            .name
230                            .clone()
231                            .unwrap();
232                        mesh_attrs.push(Attribute::TaskPayload(payload_name));
233                    }
234                    mesh_attrs
235                }
236                ShaderStage::Task => {
237                    let payload_name = module.global_variables[ep.task_payload.unwrap()]
238                        .name
239                        .clone()
240                        .unwrap();
241                    vec![
242                        Attribute::Stage(ShaderStage::Task),
243                        Attribute::TaskPayload(payload_name),
244                        Attribute::WorkGroupSize(ep.workgroup_size),
245                    ]
246                }
247            };
248            self.write_attributes(&attributes)?;
249            // Add a newline after attribute
250            writeln!(self.out)?;
251
252            let func_ctx = back::FunctionCtx {
253                ty: back::FunctionType::EntryPoint(index as u16),
254                info: info.get_entry_point(index),
255                expressions: &ep.function.expressions,
256                named_expressions: &ep.function.named_expressions,
257            };
258            self.write_function(module, &ep.function, &func_ctx)?;
259
260            if index < module.entry_points.len() - 1 {
261                writeln!(self.out)?;
262            }
263        }
264
265        // Write any polyfills that were required.
266        for polyfill in &self.required_polyfills {
267            writeln!(self.out)?;
268            write!(self.out, "{}", polyfill.source)?;
269            writeln!(self.out)?;
270        }
271
272        Ok(())
273    }
274
275    /// Helper method which writes all the `enable` declarations
276    /// needed for a module.
277    fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
278        let mut needs_f16 = false;
279        let mut needs_dual_source_blending = false;
280        let mut needs_clip_distances = false;
281        let mut needs_mesh_shaders = false;
282        let mut needs_cooperative_matrix = false;
283
284        // Determine which `enable` declarations are needed
285        for (_, ty) in module.types.iter() {
286            match ty.inner {
287                TypeInner::Scalar(scalar)
288                | TypeInner::Vector { scalar, .. }
289                | TypeInner::Matrix { scalar, .. } => {
290                    needs_f16 |= scalar == crate::Scalar::F16;
291                }
292                TypeInner::Struct { ref members, .. } => {
293                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
294                        match *binding {
295                            crate::Binding::Location {
296                                blend_src: Some(_), ..
297                            } => {
298                                needs_dual_source_blending = true;
299                            }
300                            crate::Binding::BuiltIn(crate::BuiltIn::ClipDistance) => {
301                                needs_clip_distances = true;
302                            }
303                            crate::Binding::Location {
304                                per_primitive: true,
305                                ..
306                            } => {
307                                needs_mesh_shaders = true;
308                            }
309                            crate::Binding::BuiltIn(
310                                crate::BuiltIn::MeshTaskSize
311                                | crate::BuiltIn::CullPrimitive
312                                | crate::BuiltIn::PointIndex
313                                | crate::BuiltIn::LineIndices
314                                | crate::BuiltIn::TriangleIndices
315                                | crate::BuiltIn::VertexCount
316                                | crate::BuiltIn::Vertices
317                                | crate::BuiltIn::PrimitiveCount
318                                | crate::BuiltIn::Primitives,
319                            ) => {
320                                needs_mesh_shaders = true;
321                            }
322                            _ => {}
323                        }
324                    }
325                }
326                TypeInner::CooperativeMatrix { .. } => {
327                    needs_cooperative_matrix = true;
328                }
329                _ => {}
330            }
331        }
332
333        if module
334            .entry_points
335            .iter()
336            .any(|ep| matches!(ep.stage, ShaderStage::Mesh | ShaderStage::Task))
337        {
338            needs_mesh_shaders = true;
339        }
340
341        if module
342            .global_variables
343            .iter()
344            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
345        {
346            needs_mesh_shaders = true;
347        }
348
349        // Write required declarations
350        let mut any_written = false;
351        if needs_f16 {
352            writeln!(self.out, "enable f16;")?;
353            any_written = true;
354        }
355        if needs_dual_source_blending {
356            writeln!(self.out, "enable dual_source_blending;")?;
357            any_written = true;
358        }
359        if needs_clip_distances {
360            writeln!(self.out, "enable clip_distances;")?;
361            any_written = true;
362        }
363        if needs_mesh_shaders {
364            writeln!(self.out, "enable wgpu_mesh_shader;")?;
365            any_written = true;
366        }
367        if needs_cooperative_matrix {
368            writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
369            any_written = true;
370        }
371        if any_written {
372            // Empty line for readability
373            writeln!(self.out)?;
374        }
375
376        Ok(())
377    }
378
379    /// Helper method used to write
380    /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
381    ///
382    /// # Notes
383    /// Ends in a newline
384    fn write_function(
385        &mut self,
386        module: &Module,
387        func: &crate::Function,
388        func_ctx: &back::FunctionCtx<'_>,
389    ) -> BackendResult {
390        let func_name = match func_ctx.ty {
391            back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
392            back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
393        };
394
395        // Write function name
396        write!(self.out, "fn {func_name}(")?;
397
398        // Write function arguments
399        for (index, arg) in func.arguments.iter().enumerate() {
400            // Write argument attribute if a binding is present
401            if let Some(ref binding) = arg.binding {
402                self.write_attributes(&map_binding_to_attribute(binding))?;
403            }
404            // Write argument name
405            let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
406
407            write!(self.out, "{argument_name}: ")?;
408            // Write argument type
409            self.write_type(module, arg.ty)?;
410            if index < func.arguments.len() - 1 {
411                // Add a separator between args
412                write!(self.out, ", ")?;
413            }
414        }
415
416        write!(self.out, ")")?;
417
418        // Write function return type
419        if let Some(ref result) = func.result {
420            write!(self.out, " -> ")?;
421            if let Some(ref binding) = result.binding {
422                self.write_attributes(&map_binding_to_attribute(binding))?;
423            }
424            self.write_type(module, result.ty)?;
425        }
426
427        write!(self.out, " {{")?;
428        writeln!(self.out)?;
429
430        // Write function local variables
431        for (handle, local) in func.local_variables.iter() {
432            // Write indentation (only for readability)
433            write!(self.out, "{}", back::INDENT)?;
434
435            // Write the local name
436            // The leading space is important
437            write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
438
439            // Write the local type
440            self.write_type(module, local.ty)?;
441
442            // Write the local initializer if needed
443            if let Some(init) = local.init {
444                // Put the equal signal only if there's a initializer
445                // The leading and trailing spaces aren't needed but help with readability
446                write!(self.out, " = ")?;
447
448                // Write the constant
449                // `write_constant` adds no trailing or leading space/newline
450                self.write_expr(module, init, func_ctx)?;
451            }
452
453            // Finish the local with `;` and add a newline (only for readability)
454            writeln!(self.out, ";")?
455        }
456
457        if !func.local_variables.is_empty() {
458            writeln!(self.out)?;
459        }
460
461        // Write the function body (statement list)
462        for sta in func.body.iter() {
463            // The indentation should always be 1 when writing the function body
464            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
465        }
466
467        writeln!(self.out, "}}")?;
468
469        self.named_expressions.clear();
470
471        Ok(())
472    }
473
474    /// Helper method to write a attribute
475    fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
476        for attribute in attributes {
477            match *attribute {
478                Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
479                Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
480                Attribute::BuiltIn(builtin_attrib) => {
481                    let builtin = builtin_attrib.to_wgsl_if_implemented()?;
482                    write!(self.out, "@builtin({builtin}) ")?;
483                }
484                Attribute::Stage(shader_stage) => {
485                    let stage_str = match shader_stage {
486                        ShaderStage::Vertex => "vertex",
487                        ShaderStage::Fragment => "fragment",
488                        ShaderStage::Compute => "compute",
489                        ShaderStage::Task => "task",
490                        //Handled by another variant in the Attribute enum, so this code should never be hit.
491                        ShaderStage::Mesh => unreachable!(),
492                    };
493
494                    write!(self.out, "@{stage_str} ")?;
495                }
496                Attribute::WorkGroupSize(size) => {
497                    write!(
498                        self.out,
499                        "@workgroup_size({}, {}, {}) ",
500                        size[0], size[1], size[2]
501                    )?;
502                }
503                Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
504                Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
505                Attribute::Invariant => write!(self.out, "@invariant ")?,
506                Attribute::Interpolate(interpolation, sampling) => {
507                    if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
508                        let interpolation = interpolation
509                            .unwrap_or(crate::Interpolation::Perspective)
510                            .to_wgsl();
511                        let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
512                        write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
513                    } else if interpolation.is_some()
514                        && interpolation != Some(crate::Interpolation::Perspective)
515                    {
516                        let interpolation = interpolation
517                            .unwrap_or(crate::Interpolation::Perspective)
518                            .to_wgsl();
519                        write!(self.out, "@interpolate({interpolation}) ")?;
520                    }
521                }
522                Attribute::MeshStage(ref name) => {
523                    write!(self.out, "@mesh({name}) ")?;
524                }
525                Attribute::TaskPayload(ref payload_name) => {
526                    write!(self.out, "@payload({payload_name}) ")?;
527                }
528                Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
529            };
530        }
531        Ok(())
532    }
533
534    /// Helper method used to write structs
535    /// Write the full declaration of a struct type.
536    ///
537    /// Write out a definition of the struct type referred to by
538    /// `handle` in `module`. The output will be an instance of the
539    /// `struct_decl` production in the WGSL grammar.
540    ///
541    /// Use `members` as the list of `handle`'s members. (This
542    /// function is usually called after matching a `TypeInner`, so
543    /// the callers already have the members at hand.)
544    fn write_struct(
545        &mut self,
546        module: &Module,
547        handle: Handle<crate::Type>,
548        members: &[crate::StructMember],
549    ) -> BackendResult {
550        write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
551        write!(self.out, " {{")?;
552        writeln!(self.out)?;
553        for (index, member) in members.iter().enumerate() {
554            // The indentation is only for readability
555            write!(self.out, "{}", back::INDENT)?;
556            if let Some(ref binding) = member.binding {
557                self.write_attributes(&map_binding_to_attribute(binding))?;
558            }
559            // Write struct member name and type
560            let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
561            write!(self.out, "{member_name}: ")?;
562            self.write_type(module, member.ty)?;
563            write!(self.out, ",")?;
564            writeln!(self.out)?;
565        }
566
567        writeln!(self.out, "}}")?;
568
569        Ok(())
570    }
571
572    fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
573        // This actually can't be factored out into a nice constructor method,
574        // because the borrow checker needs to be able to see that the borrows
575        // of `self.names` and `self.out` are disjoint.
576        let type_context = WriterTypeContext {
577            module,
578            names: &self.names,
579        };
580        type_context.write_type(ty, &mut self.out)?;
581
582        Ok(())
583    }
584
585    fn write_type_resolution(
586        &mut self,
587        module: &Module,
588        resolution: &proc::TypeResolution,
589    ) -> BackendResult {
590        // This actually can't be factored out into a nice constructor method,
591        // because the borrow checker needs to be able to see that the borrows
592        // of `self.names` and `self.out` are disjoint.
593        let type_context = WriterTypeContext {
594            module,
595            names: &self.names,
596        };
597        type_context.write_type_resolution(resolution, &mut self.out)?;
598
599        Ok(())
600    }
601
602    /// Helper method used to write statements
603    ///
604    /// # Notes
605    /// Always adds a newline
606    fn write_stmt(
607        &mut self,
608        module: &Module,
609        stmt: &crate::Statement,
610        func_ctx: &back::FunctionCtx<'_>,
611        level: back::Level,
612    ) -> BackendResult {
613        use crate::{Expression, Statement};
614
615        match *stmt {
616            Statement::Emit(ref range) => {
617                for handle in range.clone() {
618                    let info = &func_ctx.info[handle];
619                    let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
620                        // Front end provides names for all variables at the start of writing.
621                        // But we write them to step by step. We need to recache them
622                        // Otherwise, we could accidentally write variable name instead of full expression.
623                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
624                        Some(self.namer.call(name))
625                    } else {
626                        let expr = &func_ctx.expressions[handle];
627                        let min_ref_count = expr.bake_ref_count();
628                        // Forcefully creating baking expressions in some cases to help with readability
629                        let required_baking_expr = match *expr {
630                            Expression::ImageLoad { .. }
631                            | Expression::ImageQuery { .. }
632                            | Expression::ImageSample { .. } => true,
633                            _ => false,
634                        };
635                        if min_ref_count <= info.ref_count || required_baking_expr {
636                            Some(Baked(handle).to_string())
637                        } else {
638                            None
639                        }
640                    };
641
642                    if let Some(name) = expr_name {
643                        write!(self.out, "{level}")?;
644                        self.start_named_expr(module, handle, func_ctx, &name)?;
645                        self.write_expr(module, handle, func_ctx)?;
646                        self.named_expressions.insert(handle, name);
647                        writeln!(self.out, ";")?;
648                    }
649                }
650            }
651            // TODO: copy-paste from glsl-out
652            Statement::If {
653                condition,
654                ref accept,
655                ref reject,
656            } => {
657                write!(self.out, "{level}")?;
658                write!(self.out, "if ")?;
659                self.write_expr(module, condition, func_ctx)?;
660                writeln!(self.out, " {{")?;
661
662                let l2 = level.next();
663                for sta in accept {
664                    // Increase indentation to help with readability
665                    self.write_stmt(module, sta, func_ctx, l2)?;
666                }
667
668                // If there are no statements in the reject block we skip writing it
669                // This is only for readability
670                if !reject.is_empty() {
671                    writeln!(self.out, "{level}}} else {{")?;
672
673                    for sta in reject {
674                        // Increase indentation to help with readability
675                        self.write_stmt(module, sta, func_ctx, l2)?;
676                    }
677                }
678
679                writeln!(self.out, "{level}}}")?
680            }
681            Statement::Return { value } => {
682                write!(self.out, "{level}")?;
683                write!(self.out, "return")?;
684                if let Some(return_value) = value {
685                    // The leading space is important
686                    write!(self.out, " ")?;
687                    self.write_expr(module, return_value, func_ctx)?;
688                }
689                writeln!(self.out, ";")?;
690            }
691            // TODO: copy-paste from glsl-out
692            Statement::Kill => {
693                write!(self.out, "{level}")?;
694                writeln!(self.out, "discard;")?
695            }
696            Statement::Store { pointer, value } => {
697                write!(self.out, "{level}")?;
698
699                let is_atomic_pointer = func_ctx
700                    .resolve_type(pointer, &module.types)
701                    .is_atomic_pointer(&module.types);
702
703                if is_atomic_pointer {
704                    write!(self.out, "atomicStore(")?;
705                    self.write_expr(module, pointer, func_ctx)?;
706                    write!(self.out, ", ")?;
707                    self.write_expr(module, value, func_ctx)?;
708                    write!(self.out, ")")?;
709                } else {
710                    self.write_expr_with_indirection(
711                        module,
712                        pointer,
713                        func_ctx,
714                        Indirection::Reference,
715                    )?;
716                    write!(self.out, " = ")?;
717                    self.write_expr(module, value, func_ctx)?;
718                }
719                writeln!(self.out, ";")?
720            }
721            Statement::Call {
722                function,
723                ref arguments,
724                result,
725            } => {
726                write!(self.out, "{level}")?;
727                if let Some(expr) = result {
728                    let name = Baked(expr).to_string();
729                    self.start_named_expr(module, expr, func_ctx, &name)?;
730                    self.named_expressions.insert(expr, name);
731                }
732                let func_name = &self.names[&NameKey::Function(function)];
733                write!(self.out, "{func_name}(")?;
734                for (index, &argument) in arguments.iter().enumerate() {
735                    if index != 0 {
736                        write!(self.out, ", ")?;
737                    }
738                    self.write_expr(module, argument, func_ctx)?;
739                }
740                writeln!(self.out, ");")?
741            }
742            Statement::Atomic {
743                pointer,
744                ref fun,
745                value,
746                result,
747            } => {
748                write!(self.out, "{level}")?;
749                if let Some(result) = result {
750                    let res_name = Baked(result).to_string();
751                    self.start_named_expr(module, result, func_ctx, &res_name)?;
752                    self.named_expressions.insert(result, res_name);
753                }
754
755                let fun_str = fun.to_wgsl();
756                write!(self.out, "atomic{fun_str}(")?;
757                self.write_expr(module, pointer, func_ctx)?;
758                if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
759                    write!(self.out, ", ")?;
760                    self.write_expr(module, cmp, func_ctx)?;
761                }
762                write!(self.out, ", ")?;
763                self.write_expr(module, value, func_ctx)?;
764                writeln!(self.out, ");")?
765            }
766            Statement::ImageAtomic {
767                image,
768                coordinate,
769                array_index,
770                ref fun,
771                value,
772            } => {
773                write!(self.out, "{level}")?;
774                let fun_str = fun.to_wgsl();
775                write!(self.out, "textureAtomic{fun_str}(")?;
776                self.write_expr(module, image, func_ctx)?;
777                write!(self.out, ", ")?;
778                self.write_expr(module, coordinate, func_ctx)?;
779                if let Some(array_index_expr) = array_index {
780                    write!(self.out, ", ")?;
781                    self.write_expr(module, array_index_expr, func_ctx)?;
782                }
783                write!(self.out, ", ")?;
784                self.write_expr(module, value, func_ctx)?;
785                writeln!(self.out, ");")?;
786            }
787            Statement::WorkGroupUniformLoad { pointer, result } => {
788                write!(self.out, "{level}")?;
789                // TODO: Obey named expressions here.
790                let res_name = Baked(result).to_string();
791                self.start_named_expr(module, result, func_ctx, &res_name)?;
792                self.named_expressions.insert(result, res_name);
793                write!(self.out, "workgroupUniformLoad(")?;
794                self.write_expr(module, pointer, func_ctx)?;
795                writeln!(self.out, ");")?;
796            }
797            Statement::ImageStore {
798                image,
799                coordinate,
800                array_index,
801                value,
802            } => {
803                write!(self.out, "{level}")?;
804                write!(self.out, "textureStore(")?;
805                self.write_expr(module, image, func_ctx)?;
806                write!(self.out, ", ")?;
807                self.write_expr(module, coordinate, func_ctx)?;
808                if let Some(array_index_expr) = array_index {
809                    write!(self.out, ", ")?;
810                    self.write_expr(module, array_index_expr, func_ctx)?;
811                }
812                write!(self.out, ", ")?;
813                self.write_expr(module, value, func_ctx)?;
814                writeln!(self.out, ");")?;
815            }
816            // TODO: copy-paste from glsl-out
817            Statement::Block(ref block) => {
818                write!(self.out, "{level}")?;
819                writeln!(self.out, "{{")?;
820                for sta in block.iter() {
821                    // Increase the indentation to help with readability
822                    self.write_stmt(module, sta, func_ctx, level.next())?
823                }
824                writeln!(self.out, "{level}}}")?
825            }
826            Statement::Switch {
827                selector,
828                ref cases,
829            } => {
830                // Start the switch
831                write!(self.out, "{level}")?;
832                write!(self.out, "switch ")?;
833                self.write_expr(module, selector, func_ctx)?;
834                writeln!(self.out, " {{")?;
835
836                let l2 = level.next();
837                let mut new_case = true;
838                for case in cases {
839                    if case.fall_through && !case.body.is_empty() {
840                        // TODO: we could do the same workaround as we did for the HLSL backend
841                        return Err(Error::Unimplemented(
842                            "fall-through switch case block".into(),
843                        ));
844                    }
845
846                    match case.value {
847                        crate::SwitchValue::I32(value) => {
848                            if new_case {
849                                write!(self.out, "{l2}case ")?;
850                            }
851                            write!(self.out, "{value}")?;
852                        }
853                        crate::SwitchValue::U32(value) => {
854                            if new_case {
855                                write!(self.out, "{l2}case ")?;
856                            }
857                            write!(self.out, "{value}u")?;
858                        }
859                        crate::SwitchValue::Default => {
860                            if new_case {
861                                if case.fall_through {
862                                    write!(self.out, "{l2}case ")?;
863                                } else {
864                                    write!(self.out, "{l2}")?;
865                                }
866                            }
867                            write!(self.out, "default")?;
868                        }
869                    }
870
871                    new_case = !case.fall_through;
872
873                    if case.fall_through {
874                        write!(self.out, ", ")?;
875                    } else {
876                        writeln!(self.out, ": {{")?;
877                    }
878
879                    for sta in case.body.iter() {
880                        self.write_stmt(module, sta, func_ctx, l2.next())?;
881                    }
882
883                    if !case.fall_through {
884                        writeln!(self.out, "{l2}}}")?;
885                    }
886                }
887
888                writeln!(self.out, "{level}}}")?
889            }
890            Statement::Loop {
891                ref body,
892                ref continuing,
893                break_if,
894            } => {
895                write!(self.out, "{level}")?;
896                writeln!(self.out, "loop {{")?;
897
898                let l2 = level.next();
899                for sta in body.iter() {
900                    self.write_stmt(module, sta, func_ctx, l2)?;
901                }
902
903                // The continuing is optional so we don't need to write it if
904                // it is empty, but the `break if` counts as a continuing statement
905                // so even if `continuing` is empty we must generate it if a
906                // `break if` exists
907                if !continuing.is_empty() || break_if.is_some() {
908                    writeln!(self.out, "{l2}continuing {{")?;
909                    for sta in continuing.iter() {
910                        self.write_stmt(module, sta, func_ctx, l2.next())?;
911                    }
912
913                    // The `break if` is always the last
914                    // statement of the `continuing` block
915                    if let Some(condition) = break_if {
916                        // The trailing space is important
917                        write!(self.out, "{}break if ", l2.next())?;
918                        self.write_expr(module, condition, func_ctx)?;
919                        // Close the `break if` statement
920                        writeln!(self.out, ";")?;
921                    }
922
923                    writeln!(self.out, "{l2}}}")?;
924                }
925
926                writeln!(self.out, "{level}}}")?
927            }
928            Statement::Break => {
929                writeln!(self.out, "{level}break;")?;
930            }
931            Statement::Continue => {
932                writeln!(self.out, "{level}continue;")?;
933            }
934            Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
935                if barrier.contains(crate::Barrier::STORAGE) {
936                    writeln!(self.out, "{level}storageBarrier();")?;
937                }
938
939                if barrier.contains(crate::Barrier::WORK_GROUP) {
940                    writeln!(self.out, "{level}workgroupBarrier();")?;
941                }
942
943                if barrier.contains(crate::Barrier::SUB_GROUP) {
944                    writeln!(self.out, "{level}subgroupBarrier();")?;
945                }
946
947                if barrier.contains(crate::Barrier::TEXTURE) {
948                    writeln!(self.out, "{level}textureBarrier();")?;
949                }
950            }
951            Statement::RayQuery { .. } => unreachable!(),
952            Statement::SubgroupBallot { result, predicate } => {
953                write!(self.out, "{level}")?;
954                let res_name = Baked(result).to_string();
955                self.start_named_expr(module, result, func_ctx, &res_name)?;
956                self.named_expressions.insert(result, res_name);
957
958                write!(self.out, "subgroupBallot(")?;
959                if let Some(predicate) = predicate {
960                    self.write_expr(module, predicate, func_ctx)?;
961                }
962                writeln!(self.out, ");")?;
963            }
964            Statement::SubgroupCollectiveOperation {
965                op,
966                collective_op,
967                argument,
968                result,
969            } => {
970                write!(self.out, "{level}")?;
971                let res_name = Baked(result).to_string();
972                self.start_named_expr(module, result, func_ctx, &res_name)?;
973                self.named_expressions.insert(result, res_name);
974
975                match (collective_op, op) {
976                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
977                        write!(self.out, "subgroupAll(")?
978                    }
979                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
980                        write!(self.out, "subgroupAny(")?
981                    }
982                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
983                        write!(self.out, "subgroupAdd(")?
984                    }
985                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
986                        write!(self.out, "subgroupMul(")?
987                    }
988                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
989                        write!(self.out, "subgroupMax(")?
990                    }
991                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
992                        write!(self.out, "subgroupMin(")?
993                    }
994                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
995                        write!(self.out, "subgroupAnd(")?
996                    }
997                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
998                        write!(self.out, "subgroupOr(")?
999                    }
1000                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1001                        write!(self.out, "subgroupXor(")?
1002                    }
1003                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1004                        write!(self.out, "subgroupExclusiveAdd(")?
1005                    }
1006                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1007                        write!(self.out, "subgroupExclusiveMul(")?
1008                    }
1009                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1010                        write!(self.out, "subgroupInclusiveAdd(")?
1011                    }
1012                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1013                        write!(self.out, "subgroupInclusiveMul(")?
1014                    }
1015                    _ => unimplemented!(),
1016                }
1017                self.write_expr(module, argument, func_ctx)?;
1018                writeln!(self.out, ");")?;
1019            }
1020            Statement::SubgroupGather {
1021                mode,
1022                argument,
1023                result,
1024            } => {
1025                write!(self.out, "{level}")?;
1026                let res_name = Baked(result).to_string();
1027                self.start_named_expr(module, result, func_ctx, &res_name)?;
1028                self.named_expressions.insert(result, res_name);
1029
1030                match mode {
1031                    crate::GatherMode::BroadcastFirst => {
1032                        write!(self.out, "subgroupBroadcastFirst(")?;
1033                    }
1034                    crate::GatherMode::Broadcast(_) => {
1035                        write!(self.out, "subgroupBroadcast(")?;
1036                    }
1037                    crate::GatherMode::Shuffle(_) => {
1038                        write!(self.out, "subgroupShuffle(")?;
1039                    }
1040                    crate::GatherMode::ShuffleDown(_) => {
1041                        write!(self.out, "subgroupShuffleDown(")?;
1042                    }
1043                    crate::GatherMode::ShuffleUp(_) => {
1044                        write!(self.out, "subgroupShuffleUp(")?;
1045                    }
1046                    crate::GatherMode::ShuffleXor(_) => {
1047                        write!(self.out, "subgroupShuffleXor(")?;
1048                    }
1049                    crate::GatherMode::QuadBroadcast(_) => {
1050                        write!(self.out, "quadBroadcast(")?;
1051                    }
1052                    crate::GatherMode::QuadSwap(direction) => match direction {
1053                        crate::Direction::X => {
1054                            write!(self.out, "quadSwapX(")?;
1055                        }
1056                        crate::Direction::Y => {
1057                            write!(self.out, "quadSwapY(")?;
1058                        }
1059                        crate::Direction::Diagonal => {
1060                            write!(self.out, "quadSwapDiagonal(")?;
1061                        }
1062                    },
1063                }
1064                self.write_expr(module, argument, func_ctx)?;
1065                match mode {
1066                    crate::GatherMode::BroadcastFirst => {}
1067                    crate::GatherMode::Broadcast(index)
1068                    | crate::GatherMode::Shuffle(index)
1069                    | crate::GatherMode::ShuffleDown(index)
1070                    | crate::GatherMode::ShuffleUp(index)
1071                    | crate::GatherMode::ShuffleXor(index)
1072                    | crate::GatherMode::QuadBroadcast(index) => {
1073                        write!(self.out, ", ")?;
1074                        self.write_expr(module, index, func_ctx)?;
1075                    }
1076                    crate::GatherMode::QuadSwap(_) => {}
1077                }
1078                writeln!(self.out, ");")?;
1079            }
1080            Statement::CooperativeStore { target, ref data } => {
1081                let suffix = if data.row_major { "T" } else { "" };
1082                write!(self.out, "{level}coopStore{suffix}(")?;
1083                self.write_expr(module, target, func_ctx)?;
1084                write!(self.out, ", ")?;
1085                self.write_expr(module, data.pointer, func_ctx)?;
1086                write!(self.out, ", ")?;
1087                self.write_expr(module, data.stride, func_ctx)?;
1088                writeln!(self.out, ");")?
1089            }
1090        }
1091
1092        Ok(())
1093    }
1094
1095    /// Return the sort of indirection that `expr`'s plain form evaluates to.
1096    ///
1097    /// An expression's 'plain form' is the most general rendition of that
1098    /// expression into WGSL, lacking `&` or `*` operators:
1099    ///
1100    /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
1101    ///   to the local variable's storage.
1102    ///
1103    /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
1104    ///   reference to the global variable's storage. However, globals in the
1105    ///   `Handle` address space are immutable, and `GlobalVariable` expressions for
1106    ///   those produce the value directly, not a pointer to it. Such
1107    ///   `GlobalVariable` expressions are `Ordinary`.
1108    ///
1109    /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
1110    ///   pointer. If they are applied directly to a composite value, they are
1111    ///   `Ordinary`.
1112    ///
1113    /// Note that `FunctionArgument` expressions are never `Reference`, even when
1114    /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
1115    /// argument's value directly, so any pointer it produces is merely the value
1116    /// passed by the caller.
1117    fn plain_form_indirection(
1118        &self,
1119        expr: Handle<crate::Expression>,
1120        module: &Module,
1121        func_ctx: &back::FunctionCtx<'_>,
1122    ) -> Indirection {
1123        use crate::Expression as Ex;
1124
1125        // Named expressions are `let` expressions, which apply the Load Rule,
1126        // so if their type is a Naga pointer, then that must be a WGSL pointer
1127        // as well.
1128        if self.named_expressions.contains_key(&expr) {
1129            return Indirection::Ordinary;
1130        }
1131
1132        match func_ctx.expressions[expr] {
1133            Ex::LocalVariable(_) => Indirection::Reference,
1134            Ex::GlobalVariable(handle) => {
1135                let global = &module.global_variables[handle];
1136                match global.space {
1137                    crate::AddressSpace::Handle => Indirection::Ordinary,
1138                    _ => Indirection::Reference,
1139                }
1140            }
1141            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1142                let base_ty = func_ctx.resolve_type(base, &module.types);
1143                match *base_ty {
1144                    TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1145                        Indirection::Reference
1146                    }
1147                    _ => Indirection::Ordinary,
1148                }
1149            }
1150            _ => Indirection::Ordinary,
1151        }
1152    }
1153
1154    fn start_named_expr(
1155        &mut self,
1156        module: &Module,
1157        handle: Handle<crate::Expression>,
1158        func_ctx: &back::FunctionCtx,
1159        name: &str,
1160    ) -> BackendResult {
1161        // Write variable name
1162        write!(self.out, "let {name}")?;
1163        if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1164            write!(self.out, ": ")?;
1165            // Write variable type
1166            self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1167        }
1168
1169        write!(self.out, " = ")?;
1170        Ok(())
1171    }
1172
1173    /// Write the ordinary WGSL form of `expr`.
1174    ///
1175    /// See `write_expr_with_indirection` for details.
1176    fn write_expr(
1177        &mut self,
1178        module: &Module,
1179        expr: Handle<crate::Expression>,
1180        func_ctx: &back::FunctionCtx<'_>,
1181    ) -> BackendResult {
1182        self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1183    }
1184
1185    /// Write `expr` as a WGSL expression with the requested indirection.
1186    ///
1187    /// In terms of the WGSL grammar, the resulting expression is a
1188    /// `singular_expression`. It may be parenthesized. This makes it suitable
1189    /// for use as the operand of a unary or binary operator without worrying
1190    /// about precedence.
1191    ///
1192    /// This does not produce newlines or indentation.
1193    ///
1194    /// The `requested` argument indicates (roughly) whether Naga
1195    /// `Pointer`-valued expressions represent WGSL references or pointers. See
1196    /// `Indirection` for details.
1197    fn write_expr_with_indirection(
1198        &mut self,
1199        module: &Module,
1200        expr: Handle<crate::Expression>,
1201        func_ctx: &back::FunctionCtx<'_>,
1202        requested: Indirection,
1203    ) -> BackendResult {
1204        // If the plain form of the expression is not what we need, emit the
1205        // operator necessary to correct that.
1206        let plain = self.plain_form_indirection(expr, module, func_ctx);
1207        log::trace!(
1208            "expression {:?}={:?} is {:?}, expected {:?}",
1209            expr,
1210            func_ctx.expressions[expr],
1211            plain,
1212            requested,
1213        );
1214        match (requested, plain) {
1215            (Indirection::Ordinary, Indirection::Reference) => {
1216                write!(self.out, "(&")?;
1217                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1218                write!(self.out, ")")?;
1219            }
1220            (Indirection::Reference, Indirection::Ordinary) => {
1221                write!(self.out, "(*")?;
1222                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1223                write!(self.out, ")")?;
1224            }
1225            (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1226        }
1227
1228        Ok(())
1229    }
1230
1231    fn write_const_expression(
1232        &mut self,
1233        module: &Module,
1234        expr: Handle<crate::Expression>,
1235        arena: &crate::Arena<crate::Expression>,
1236    ) -> BackendResult {
1237        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1238            writer.write_const_expression(module, expr, arena)
1239        })
1240    }
1241
1242    fn write_possibly_const_expression<E>(
1243        &mut self,
1244        module: &Module,
1245        expr: Handle<crate::Expression>,
1246        expressions: &crate::Arena<crate::Expression>,
1247        write_expression: E,
1248    ) -> BackendResult
1249    where
1250        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1251    {
1252        use crate::Expression;
1253
1254        match expressions[expr] {
1255            Expression::Literal(literal) => match literal {
1256                crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1257                crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1258                crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1259                crate::Literal::I32(value) => {
1260                    // `-2147483648i` is not valid WGSL. The most negative `i32`
1261                    // value can only be expressed in WGSL using AbstractInt and
1262                    // a unary negation operator.
1263                    if value == i32::MIN {
1264                        write!(self.out, "i32({value})")?;
1265                    } else {
1266                        write!(self.out, "{value}i")?;
1267                    }
1268                }
1269                crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1270                crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1271                crate::Literal::I64(value) => {
1272                    // `-9223372036854775808li` is not valid WGSL. Nor can we simply use the
1273                    // AbstractInt trick above, as AbstractInt also cannot represent
1274                    // `9223372036854775808`. Instead construct the second most negative
1275                    // AbstractInt, subtract one from it, then cast to i64.
1276                    if value == i64::MIN {
1277                        write!(self.out, "i64({} - 1)", value + 1)?;
1278                    } else {
1279                        write!(self.out, "{value}li")?;
1280                    }
1281                }
1282                crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1283                crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1284                    return Err(Error::Custom(
1285                        "Abstract types should not appear in IR presented to backends".into(),
1286                    ));
1287                }
1288            },
1289            Expression::Constant(handle) => {
1290                let constant = &module.constants[handle];
1291                if constant.name.is_some() {
1292                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1293                } else {
1294                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
1295                }
1296            }
1297            Expression::ZeroValue(ty) => {
1298                self.write_type(module, ty)?;
1299                write!(self.out, "()")?;
1300            }
1301            Expression::Compose { ty, ref components } => {
1302                self.write_type(module, ty)?;
1303                write!(self.out, "(")?;
1304                for (index, component) in components.iter().enumerate() {
1305                    if index != 0 {
1306                        write!(self.out, ", ")?;
1307                    }
1308                    write_expression(self, *component)?;
1309                }
1310                write!(self.out, ")")?
1311            }
1312            Expression::Splat { size, value } => {
1313                let size = common::vector_size_str(size);
1314                write!(self.out, "vec{size}(")?;
1315                write_expression(self, value)?;
1316                write!(self.out, ")")?;
1317            }
1318            Expression::Override(handle) => {
1319                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1320            }
1321            _ => unreachable!(),
1322        }
1323
1324        Ok(())
1325    }
1326
1327    /// Write the 'plain form' of `expr`.
1328    ///
1329    /// An expression's 'plain form' is the most general rendition of that
1330    /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
1331    /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
1332    /// Naga expressions represent both WGSL pointers and references; it's the
1333    /// caller's responsibility to distinguish those cases appropriately.
1334    fn write_expr_plain_form(
1335        &mut self,
1336        module: &Module,
1337        expr: Handle<crate::Expression>,
1338        func_ctx: &back::FunctionCtx<'_>,
1339        indirection: Indirection,
1340    ) -> BackendResult {
1341        use crate::Expression;
1342
1343        if let Some(name) = self.named_expressions.get(&expr) {
1344            write!(self.out, "{name}")?;
1345            return Ok(());
1346        }
1347
1348        let expression = &func_ctx.expressions[expr];
1349
1350        // Write the plain WGSL form of a Naga expression.
1351        //
1352        // The plain form of `LocalVariable` and `GlobalVariable` expressions is
1353        // simply the variable name; `*` and `&` operators are never emitted.
1354        //
1355        // The plain form of `Access` and `AccessIndex` expressions are WGSL
1356        // `postfix_expression` forms for member/component access and
1357        // subscripting.
1358        match *expression {
1359            Expression::Literal(_)
1360            | Expression::Constant(_)
1361            | Expression::ZeroValue(_)
1362            | Expression::Compose { .. }
1363            | Expression::Splat { .. } => {
1364                self.write_possibly_const_expression(
1365                    module,
1366                    expr,
1367                    func_ctx.expressions,
1368                    |writer, expr| writer.write_expr(module, expr, func_ctx),
1369                )?;
1370            }
1371            Expression::Override(handle) => {
1372                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1373            }
1374            Expression::FunctionArgument(pos) => {
1375                let name_key = func_ctx.argument_key(pos);
1376                let name = &self.names[&name_key];
1377                write!(self.out, "{name}")?;
1378            }
1379            Expression::Binary { op, left, right } => {
1380                write!(self.out, "(")?;
1381                self.write_expr(module, left, func_ctx)?;
1382                write!(self.out, " {} ", back::binary_operation_str(op))?;
1383                self.write_expr(module, right, func_ctx)?;
1384                write!(self.out, ")")?;
1385            }
1386            Expression::Access { base, index } => {
1387                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
1388                write!(self.out, "[")?;
1389                self.write_expr(module, index, func_ctx)?;
1390                write!(self.out, "]")?
1391            }
1392            Expression::AccessIndex { base, index } => {
1393                let base_ty_res = &func_ctx.info[base].ty;
1394                let mut resolved = base_ty_res.inner_with(&module.types);
1395
1396                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
1397
1398                let base_ty_handle = match *resolved {
1399                    TypeInner::Pointer { base, space: _ } => {
1400                        resolved = &module.types[base].inner;
1401                        Some(base)
1402                    }
1403                    _ => base_ty_res.handle(),
1404                };
1405
1406                match *resolved {
1407                    TypeInner::Vector { .. } => {
1408                        // Write vector access as a swizzle
1409                        write!(self.out, ".{}", back::COMPONENTS[index as usize])?
1410                    }
1411                    TypeInner::Matrix { .. }
1412                    | TypeInner::Array { .. }
1413                    | TypeInner::BindingArray { .. }
1414                    | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
1415                    TypeInner::Struct { .. } => {
1416                        // This will never panic in case the type is a `Struct`, this is not true
1417                        // for other types so we can only check while inside this match arm
1418                        let ty = base_ty_handle.unwrap();
1419
1420                        write!(
1421                            self.out,
1422                            ".{}",
1423                            &self.names[&NameKey::StructMember(ty, index)]
1424                        )?
1425                    }
1426                    ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
1427                }
1428            }
1429            Expression::ImageSample {
1430                image,
1431                sampler,
1432                gather: None,
1433                coordinate,
1434                array_index,
1435                offset,
1436                level,
1437                depth_ref,
1438                clamp_to_edge,
1439            } => {
1440                use crate::SampleLevel as Sl;
1441
1442                let suffix_cmp = match depth_ref {
1443                    Some(_) => "Compare",
1444                    None => "",
1445                };
1446                let suffix_level = match level {
1447                    Sl::Auto => "",
1448                    Sl::Zero if clamp_to_edge => "BaseClampToEdge",
1449                    Sl::Zero | Sl::Exact(_) => "Level",
1450                    Sl::Bias(_) => "Bias",
1451                    Sl::Gradient { .. } => "Grad",
1452                };
1453
1454                write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
1455                self.write_expr(module, image, func_ctx)?;
1456                write!(self.out, ", ")?;
1457                self.write_expr(module, sampler, func_ctx)?;
1458                write!(self.out, ", ")?;
1459                self.write_expr(module, coordinate, func_ctx)?;
1460
1461                if let Some(array_index) = array_index {
1462                    write!(self.out, ", ")?;
1463                    self.write_expr(module, array_index, func_ctx)?;
1464                }
1465
1466                if let Some(depth_ref) = depth_ref {
1467                    write!(self.out, ", ")?;
1468                    self.write_expr(module, depth_ref, func_ctx)?;
1469                }
1470
1471                match level {
1472                    Sl::Auto => {}
1473                    Sl::Zero => {
1474                        // Level 0 is implied for depth comparison and BaseClampToEdge
1475                        if depth_ref.is_none() && !clamp_to_edge {
1476                            write!(self.out, ", 0.0")?;
1477                        }
1478                    }
1479                    Sl::Exact(expr) => {
1480                        write!(self.out, ", ")?;
1481                        self.write_expr(module, expr, func_ctx)?;
1482                    }
1483                    Sl::Bias(expr) => {
1484                        write!(self.out, ", ")?;
1485                        self.write_expr(module, expr, func_ctx)?;
1486                    }
1487                    Sl::Gradient { x, y } => {
1488                        write!(self.out, ", ")?;
1489                        self.write_expr(module, x, func_ctx)?;
1490                        write!(self.out, ", ")?;
1491                        self.write_expr(module, y, func_ctx)?;
1492                    }
1493                }
1494
1495                if let Some(offset) = offset {
1496                    write!(self.out, ", ")?;
1497                    self.write_const_expression(module, offset, func_ctx.expressions)?;
1498                }
1499
1500                write!(self.out, ")")?;
1501            }
1502
1503            Expression::ImageSample {
1504                image,
1505                sampler,
1506                gather: Some(component),
1507                coordinate,
1508                array_index,
1509                offset,
1510                level: _,
1511                depth_ref,
1512                clamp_to_edge: _,
1513            } => {
1514                let suffix_cmp = match depth_ref {
1515                    Some(_) => "Compare",
1516                    None => "",
1517                };
1518
1519                write!(self.out, "textureGather{suffix_cmp}(")?;
1520                match *func_ctx.resolve_type(image, &module.types) {
1521                    TypeInner::Image {
1522                        class: crate::ImageClass::Depth { multi: _ },
1523                        ..
1524                    } => {}
1525                    _ => {
1526                        write!(self.out, "{}, ", component as u8)?;
1527                    }
1528                }
1529                self.write_expr(module, image, func_ctx)?;
1530                write!(self.out, ", ")?;
1531                self.write_expr(module, sampler, func_ctx)?;
1532                write!(self.out, ", ")?;
1533                self.write_expr(module, coordinate, func_ctx)?;
1534
1535                if let Some(array_index) = array_index {
1536                    write!(self.out, ", ")?;
1537                    self.write_expr(module, array_index, func_ctx)?;
1538                }
1539
1540                if let Some(depth_ref) = depth_ref {
1541                    write!(self.out, ", ")?;
1542                    self.write_expr(module, depth_ref, func_ctx)?;
1543                }
1544
1545                if let Some(offset) = offset {
1546                    write!(self.out, ", ")?;
1547                    self.write_const_expression(module, offset, func_ctx.expressions)?;
1548                }
1549
1550                write!(self.out, ")")?;
1551            }
1552            Expression::ImageQuery { image, query } => {
1553                use crate::ImageQuery as Iq;
1554
1555                let texture_function = match query {
1556                    Iq::Size { .. } => "textureDimensions",
1557                    Iq::NumLevels => "textureNumLevels",
1558                    Iq::NumLayers => "textureNumLayers",
1559                    Iq::NumSamples => "textureNumSamples",
1560                };
1561
1562                write!(self.out, "{texture_function}(")?;
1563                self.write_expr(module, image, func_ctx)?;
1564                if let Iq::Size { level: Some(level) } = query {
1565                    write!(self.out, ", ")?;
1566                    self.write_expr(module, level, func_ctx)?;
1567                };
1568                write!(self.out, ")")?;
1569            }
1570
1571            Expression::ImageLoad {
1572                image,
1573                coordinate,
1574                array_index,
1575                sample,
1576                level,
1577            } => {
1578                write!(self.out, "textureLoad(")?;
1579                self.write_expr(module, image, func_ctx)?;
1580                write!(self.out, ", ")?;
1581                self.write_expr(module, coordinate, func_ctx)?;
1582                if let Some(array_index) = array_index {
1583                    write!(self.out, ", ")?;
1584                    self.write_expr(module, array_index, func_ctx)?;
1585                }
1586                if let Some(index) = sample.or(level) {
1587                    write!(self.out, ", ")?;
1588                    self.write_expr(module, index, func_ctx)?;
1589                }
1590                write!(self.out, ")")?;
1591            }
1592            Expression::GlobalVariable(handle) => {
1593                let name = &self.names[&NameKey::GlobalVariable(handle)];
1594                write!(self.out, "{name}")?;
1595            }
1596
1597            Expression::As {
1598                expr,
1599                kind,
1600                convert,
1601            } => {
1602                let inner = func_ctx.resolve_type(expr, &module.types);
1603                match *inner {
1604                    TypeInner::Matrix {
1605                        columns,
1606                        rows,
1607                        scalar,
1608                    } => {
1609                        let scalar = crate::Scalar {
1610                            kind,
1611                            width: convert.unwrap_or(scalar.width),
1612                        };
1613                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1614                        write!(
1615                            self.out,
1616                            "mat{}x{}<{}>",
1617                            common::vector_size_str(columns),
1618                            common::vector_size_str(rows),
1619                            scalar_kind_str
1620                        )?;
1621                    }
1622                    TypeInner::Vector {
1623                        size,
1624                        scalar: crate::Scalar { width, .. },
1625                    } => {
1626                        let scalar = crate::Scalar {
1627                            kind,
1628                            width: convert.unwrap_or(width),
1629                        };
1630                        let vector_size_str = common::vector_size_str(size);
1631                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1632                        if convert.is_some() {
1633                            write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
1634                        } else {
1635                            write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
1636                        }
1637                    }
1638                    TypeInner::Scalar(crate::Scalar { width, .. }) => {
1639                        let scalar = crate::Scalar {
1640                            kind,
1641                            width: convert.unwrap_or(width),
1642                        };
1643                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
1644                        if convert.is_some() {
1645                            write!(self.out, "{scalar_kind_str}")?
1646                        } else {
1647                            write!(self.out, "bitcast<{scalar_kind_str}>")?
1648                        }
1649                    }
1650                    _ => {
1651                        return Err(Error::Unimplemented(format!(
1652                            "write_expr expression::as {inner:?}"
1653                        )));
1654                    }
1655                };
1656                write!(self.out, "(")?;
1657                self.write_expr(module, expr, func_ctx)?;
1658                write!(self.out, ")")?;
1659            }
1660            Expression::Load { pointer } => {
1661                let is_atomic_pointer = func_ctx
1662                    .resolve_type(pointer, &module.types)
1663                    .is_atomic_pointer(&module.types);
1664
1665                if is_atomic_pointer {
1666                    write!(self.out, "atomicLoad(")?;
1667                    self.write_expr(module, pointer, func_ctx)?;
1668                    write!(self.out, ")")?;
1669                } else {
1670                    self.write_expr_with_indirection(
1671                        module,
1672                        pointer,
1673                        func_ctx,
1674                        Indirection::Reference,
1675                    )?;
1676                }
1677            }
1678            Expression::LocalVariable(handle) => {
1679                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
1680            }
1681            Expression::ArrayLength(expr) => {
1682                write!(self.out, "arrayLength(")?;
1683                self.write_expr(module, expr, func_ctx)?;
1684                write!(self.out, ")")?;
1685            }
1686
1687            Expression::Math {
1688                fun,
1689                arg,
1690                arg1,
1691                arg2,
1692                arg3,
1693            } => {
1694                use crate::MathFunction as Mf;
1695
1696                enum Function {
1697                    Regular(&'static str),
1698                    InversePolyfill(InversePolyfill),
1699                }
1700
1701                let function = match fun.try_to_wgsl() {
1702                    Some(name) => Function::Regular(name),
1703                    None => match fun {
1704                        Mf::Inverse => {
1705                            let ty = func_ctx.resolve_type(arg, &module.types);
1706                            let Some(overload) = InversePolyfill::find_overload(ty) else {
1707                                return Err(Error::unsupported("math function", fun));
1708                            };
1709
1710                            Function::InversePolyfill(overload)
1711                        }
1712                        _ => return Err(Error::unsupported("math function", fun)),
1713                    },
1714                };
1715
1716                match function {
1717                    Function::Regular(fun_name) => {
1718                        write!(self.out, "{fun_name}(")?;
1719                        self.write_expr(module, arg, func_ctx)?;
1720                        for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
1721                            write!(self.out, ", ")?;
1722                            self.write_expr(module, arg, func_ctx)?;
1723                        }
1724                        write!(self.out, ")")?
1725                    }
1726                    Function::InversePolyfill(inverse) => {
1727                        write!(self.out, "{}(", inverse.fun_name)?;
1728                        self.write_expr(module, arg, func_ctx)?;
1729                        write!(self.out, ")")?;
1730                        self.required_polyfills.insert(inverse);
1731                    }
1732                }
1733            }
1734
1735            Expression::Swizzle {
1736                size,
1737                vector,
1738                pattern,
1739            } => {
1740                self.write_expr(module, vector, func_ctx)?;
1741                write!(self.out, ".")?;
1742                for &sc in pattern[..size as usize].iter() {
1743                    self.out.write_char(back::COMPONENTS[sc as usize])?;
1744                }
1745            }
1746            Expression::Unary { op, expr } => {
1747                let unary = match op {
1748                    crate::UnaryOperator::Negate => "-",
1749                    crate::UnaryOperator::LogicalNot => "!",
1750                    crate::UnaryOperator::BitwiseNot => "~",
1751                };
1752
1753                write!(self.out, "{unary}(")?;
1754                self.write_expr(module, expr, func_ctx)?;
1755
1756                write!(self.out, ")")?
1757            }
1758
1759            Expression::Select {
1760                condition,
1761                accept,
1762                reject,
1763            } => {
1764                write!(self.out, "select(")?;
1765                self.write_expr(module, reject, func_ctx)?;
1766                write!(self.out, ", ")?;
1767                self.write_expr(module, accept, func_ctx)?;
1768                write!(self.out, ", ")?;
1769                self.write_expr(module, condition, func_ctx)?;
1770                write!(self.out, ")")?
1771            }
1772            Expression::Derivative { axis, ctrl, expr } => {
1773                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1774                let op = match (axis, ctrl) {
1775                    (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
1776                    (Axis::X, Ctrl::Fine) => "dpdxFine",
1777                    (Axis::X, Ctrl::None) => "dpdx",
1778                    (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
1779                    (Axis::Y, Ctrl::Fine) => "dpdyFine",
1780                    (Axis::Y, Ctrl::None) => "dpdy",
1781                    (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
1782                    (Axis::Width, Ctrl::Fine) => "fwidthFine",
1783                    (Axis::Width, Ctrl::None) => "fwidth",
1784                };
1785                write!(self.out, "{op}(")?;
1786                self.write_expr(module, expr, func_ctx)?;
1787                write!(self.out, ")")?
1788            }
1789            Expression::Relational { fun, argument } => {
1790                use crate::RelationalFunction as Rf;
1791
1792                let fun_name = match fun {
1793                    Rf::All => "all",
1794                    Rf::Any => "any",
1795                    _ => return Err(Error::UnsupportedRelationalFunction(fun)),
1796                };
1797                write!(self.out, "{fun_name}(")?;
1798
1799                self.write_expr(module, argument, func_ctx)?;
1800
1801                write!(self.out, ")")?
1802            }
1803            // Not supported yet
1804            Expression::RayQueryGetIntersection { .. }
1805            | Expression::RayQueryVertexPositions { .. } => unreachable!(),
1806            // Nothing to do here, since call expression already cached
1807            Expression::CallResult(_)
1808            | Expression::AtomicResult { .. }
1809            | Expression::RayQueryProceedResult
1810            | Expression::SubgroupBallotResult
1811            | Expression::SubgroupOperationResult { .. }
1812            | Expression::WorkGroupUniformLoadResult { .. } => {}
1813            Expression::CooperativeLoad {
1814                columns,
1815                rows,
1816                role,
1817                ref data,
1818            } => {
1819                let suffix = if data.row_major { "T" } else { "" };
1820                let scalar = func_ctx.info[data.pointer]
1821                    .ty
1822                    .inner_with(&module.types)
1823                    .pointer_base_type()
1824                    .unwrap()
1825                    .inner_with(&module.types)
1826                    .scalar()
1827                    .unwrap();
1828                write!(
1829                    self.out,
1830                    "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
1831                    columns as u32,
1832                    rows as u32,
1833                    scalar.try_to_wgsl().unwrap(),
1834                    role,
1835                )?;
1836                self.write_expr(module, data.pointer, func_ctx)?;
1837                write!(self.out, ", ")?;
1838                self.write_expr(module, data.stride, func_ctx)?;
1839                write!(self.out, ")")?;
1840            }
1841            Expression::CooperativeMultiplyAdd { a, b, c } => {
1842                write!(self.out, "coopMultiplyAdd(")?;
1843                self.write_expr(module, a, func_ctx)?;
1844                write!(self.out, ", ")?;
1845                self.write_expr(module, b, func_ctx)?;
1846                write!(self.out, ", ")?;
1847                self.write_expr(module, c, func_ctx)?;
1848                write!(self.out, ")")?;
1849            }
1850        }
1851
1852        Ok(())
1853    }
1854
1855    /// Helper method used to write global variables
1856    /// # Notes
1857    /// Always adds a newline
1858    fn write_global(
1859        &mut self,
1860        module: &Module,
1861        global: &crate::GlobalVariable,
1862        handle: Handle<crate::GlobalVariable>,
1863    ) -> BackendResult {
1864        // Write group and binding attributes if present
1865        if let Some(ref binding) = global.binding {
1866            self.write_attributes(&[
1867                Attribute::Group(binding.group),
1868                Attribute::Binding(binding.binding),
1869            ])?;
1870            writeln!(self.out)?;
1871        }
1872
1873        // First write global name and address space if supported
1874        write!(self.out, "var")?;
1875        let (address, maybe_access) = address_space_str(global.space);
1876        if let Some(space) = address {
1877            write!(self.out, "<{space}")?;
1878            if let Some(access) = maybe_access {
1879                write!(self.out, ", {access}")?;
1880            }
1881            write!(self.out, ">")?;
1882        }
1883        write!(
1884            self.out,
1885            " {}: ",
1886            &self.names[&NameKey::GlobalVariable(handle)]
1887        )?;
1888
1889        // Write global type
1890        self.write_type(module, global.ty)?;
1891
1892        // Write initializer
1893        if let Some(init) = global.init {
1894            write!(self.out, " = ")?;
1895            self.write_const_expression(module, init, &module.global_expressions)?;
1896        }
1897
1898        // End with semicolon
1899        writeln!(self.out, ";")?;
1900
1901        Ok(())
1902    }
1903
1904    /// Helper method used to write global constants
1905    ///
1906    /// # Notes
1907    /// Ends in a newline
1908    fn write_global_constant(
1909        &mut self,
1910        module: &Module,
1911        handle: Handle<crate::Constant>,
1912    ) -> BackendResult {
1913        let name = &self.names[&NameKey::Constant(handle)];
1914        // First write only constant name
1915        write!(self.out, "const {name}: ")?;
1916        self.write_type(module, module.constants[handle].ty)?;
1917        write!(self.out, " = ")?;
1918        let init = module.constants[handle].init;
1919        self.write_const_expression(module, init, &module.global_expressions)?;
1920        writeln!(self.out, ";")?;
1921
1922        Ok(())
1923    }
1924
1925    /// Helper method used to write overrides
1926    ///
1927    /// # Notes
1928    /// Ends in a newline
1929    fn write_override(
1930        &mut self,
1931        module: &Module,
1932        handle: Handle<crate::Override>,
1933    ) -> BackendResult {
1934        let override_ = &module.overrides[handle];
1935        let name = &self.names[&NameKey::Override(handle)];
1936
1937        // Write @id attribute if present
1938        if let Some(id) = override_.id {
1939            write!(self.out, "@id({id}) ")?;
1940        }
1941
1942        // Write override declaration
1943        write!(self.out, "override {name}: ")?;
1944        self.write_type(module, override_.ty)?;
1945
1946        // Write initializer if present
1947        if let Some(init) = override_.init {
1948            write!(self.out, " = ")?;
1949            self.write_const_expression(module, init, &module.global_expressions)?;
1950        }
1951
1952        writeln!(self.out, ";")?;
1953
1954        Ok(())
1955    }
1956
1957    // See https://github.com/rust-lang/rust-clippy/issues/4979.
1958    #[allow(clippy::missing_const_for_fn)]
1959    pub fn finish(self) -> W {
1960        self.out
1961    }
1962}
1963
1964struct WriterTypeContext<'m> {
1965    module: &'m Module,
1966    names: &'m crate::FastHashMap<NameKey, String>,
1967}
1968
1969impl TypeContext for WriterTypeContext<'_> {
1970    fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
1971        &self.module.types[handle]
1972    }
1973
1974    fn type_name(&self, handle: Handle<crate::Type>) -> &str {
1975        self.names[&NameKey::Type(handle)].as_str()
1976    }
1977
1978    fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
1979        unreachable!("the WGSL back end should always provide type handles");
1980    }
1981
1982    fn write_override<W: Write>(
1983        &self,
1984        handle: Handle<crate::Override>,
1985        out: &mut W,
1986    ) -> core::fmt::Result {
1987        write!(out, "{}", self.names[&NameKey::Override(handle)])
1988    }
1989
1990    fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
1991        unreachable!("backends should only be passed validated modules");
1992    }
1993
1994    fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
1995        unreachable!("backends should only be passed validated modules");
1996    }
1997}
1998
1999fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2000    match *binding {
2001        crate::Binding::BuiltIn(built_in) => {
2002            if let crate::BuiltIn::Position { invariant: true } = built_in {
2003                vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2004            } else {
2005                vec![Attribute::BuiltIn(built_in)]
2006            }
2007        }
2008        crate::Binding::Location {
2009            location,
2010            interpolation,
2011            sampling,
2012            blend_src: None,
2013            per_primitive,
2014        } => {
2015            let mut attrs = vec![
2016                Attribute::Location(location),
2017                Attribute::Interpolate(interpolation, sampling),
2018            ];
2019            if per_primitive {
2020                attrs.push(Attribute::PerPrimitive);
2021            }
2022            attrs
2023        }
2024        crate::Binding::Location {
2025            location,
2026            interpolation,
2027            sampling,
2028            blend_src: Some(blend_src),
2029            per_primitive,
2030        } => {
2031            let mut attrs = vec![
2032                Attribute::Location(location),
2033                Attribute::BlendSrc(blend_src),
2034                Attribute::Interpolate(interpolation, sampling),
2035            ];
2036            if per_primitive {
2037                attrs.push(Attribute::PerPrimitive);
2038            }
2039            attrs
2040        }
2041    }
2042}