naga/back/wgsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::fmt::Write;
8
9use super::Error;
10use super::ToWgslIfImplemented as _;
11use crate::{back::wgsl::polyfill::InversePolyfill, common::wgsl::TypeContext};
12use crate::{
13    back::{self, Baked},
14    common::{
15        self,
16        wgsl::{address_space_str, ToWgsl, TryToWgsl},
17    },
18    proc::{self, NameKey},
19    valid, Handle, Module, ShaderStage, TypeInner,
20};
21
22/// Shorthand result used internally by the backend
23type BackendResult = Result<(), Error>;
24
25/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
26enum Attribute {
27    Binding(u32),
28    BuiltIn(crate::BuiltIn),
29    Group(u32),
30    Invariant,
31    Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
32    Location(u32),
33    BlendSrc(u32),
34    Stage(ShaderStage),
35    WorkGroupSize([u32; 3]),
36    MeshStage(String),
37    TaskPayload(String),
38    PerPrimitive,
39    IncomingRayPayload(String),
40}
41
/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
/// expression.
///
/// Sometimes a Naga `Expression` alone doesn't provide enough information to
/// choose the right rendering for it in WGSL. For example, one natural WGSL
/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
/// `LocalVariable` produces a pointer to the local variable's storage. But when
/// rendering a `Store` statement, the `pointer` operand must be the left hand
/// side of a WGSL assignment, so the proper rendering is `x`.
///
/// The caller of `write_expr_with_indirection` must provide an `Indirection`
/// value to indicate how ambiguous expressions should be rendered.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
    ///
    /// This is the right choice for most cases. Whenever a Naga pointer
    /// expression is not the `pointer` operand of a `Load` or `Store`, it
    /// must be a WGSL pointer expression.
    Ordinary,

    /// Render pointer-construction expressions as WGSL reference-typed
    /// expressions.
    ///
    /// For example, this is the right choice for the `pointer` operand when
    /// rendering a `Store` statement as a WGSL assignment.
    Reference,
}
70
71bitflags::bitflags! {
72    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
73    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
74    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
75    pub struct WriterFlags: u32 {
76        /// Always annotate the type information instead of inferring.
77        const EXPLICIT_TYPES = 0x1;
78    }
79}
80
81pub struct Writer<W> {
82    out: W,
83    flags: WriterFlags,
84    names: crate::FastHashMap<NameKey, String>,
85    namer: proc::Namer,
86    named_expressions: crate::NamedExpressions,
87    required_polyfills: crate::FastIndexSet<InversePolyfill>,
88}
89
90impl<W: Write> Writer<W> {
91    pub fn new(out: W, flags: WriterFlags) -> Self {
92        Writer {
93            out,
94            flags,
95            names: crate::FastHashMap::default(),
96            namer: proc::Namer::default(),
97            named_expressions: crate::NamedExpressions::default(),
98            required_polyfills: crate::FastIndexSet::default(),
99        }
100    }
101
102    fn reset(&mut self, module: &Module) {
103        self.names.clear();
104        self.namer.reset(
105            module,
106            &crate::keywords::wgsl::RESERVED_SET,
107            &crate::keywords::wgsl::BUILTIN_IDENTIFIER_SET,
108            // an identifier must not start with two underscore
109            proc::CaseInsensitiveKeywordSet::empty(),
110            &["__", "_naga"],
111            &mut self.names,
112        );
113        self.named_expressions.clear();
114        self.required_polyfills.clear();
115    }
116
117    /// Determine if `ty` is the Naga IR presentation of a WGSL builtin type.
118    ///
119    /// Return true if `ty` refers to the Naga IR form of a WGSL builtin type
120    /// like `__atomic_compare_exchange_result`.
121    ///
122    /// Even though the module may use the type, the WGSL backend should avoid
123    /// emitting a definition for it, since it is [predeclared] in WGSL.
124    ///
125    /// This also covers types like [`NagaExternalTextureParams`], which other
126    /// backends use to lower WGSL constructs like external textures to their
127    /// implementations. WGSL can express these directly, so the types need not
128    /// be emitted.
129    ///
130    /// [predeclared]: https://www.w3.org/TR/WGSL/#predeclared
131    /// [`NagaExternalTextureParams`]: crate::ir::SpecialTypes::external_texture_params
132    fn is_builtin_wgsl_struct(&self, module: &Module, ty: Handle<crate::Type>) -> bool {
133        module
134            .special_types
135            .predeclared_types
136            .values()
137            .any(|t| *t == ty)
138            || Some(ty) == module.special_types.external_texture_params
139            || Some(ty) == module.special_types.external_texture_transfer_function
140    }
141
142    pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
143        self.reset(module);
144
145        // Write all `enable` declarations
146        self.write_enable_declarations(module)?;
147
148        // Write all structs
149        for (handle, ty) in module.types.iter() {
150            if let TypeInner::Struct { ref members, .. } = ty.inner {
151                {
152                    if !self.is_builtin_wgsl_struct(module, handle) {
153                        self.write_struct(module, handle, members)?;
154                        writeln!(self.out)?;
155                    }
156                }
157            }
158        }
159
160        // Write all named constants
161        let mut constants = module
162            .constants
163            .iter()
164            .filter(|&(_, c)| c.name.is_some())
165            .peekable();
166        while let Some((handle, _)) = constants.next() {
167            self.write_global_constant(module, handle)?;
168            // Add extra newline for readability on last iteration
169            if constants.peek().is_none() {
170                writeln!(self.out)?;
171            }
172        }
173
174        // Write all overrides
175        let mut overrides = module.overrides.iter().peekable();
176        while let Some((handle, _)) = overrides.next() {
177            self.write_override(module, handle)?;
178            // Add extra newline for readability on last iteration
179            if overrides.peek().is_none() {
180                writeln!(self.out)?;
181            }
182        }
183
184        // Write all globals
185        for (ty, global) in module.global_variables.iter() {
186            self.write_global(module, global, ty)?;
187        }
188
189        if !module.global_variables.is_empty() {
190            // Add extra newline for readability
191            writeln!(self.out)?;
192        }
193
194        // Write all regular functions
195        for (handle, function) in module.functions.iter() {
196            let fun_info = &info[handle];
197
198            let func_ctx = back::FunctionCtx {
199                ty: back::FunctionType::Function(handle),
200                info: fun_info,
201                expressions: &function.expressions,
202                named_expressions: &function.named_expressions,
203            };
204
205            // Write the function
206            self.write_function(module, function, &func_ctx)?;
207
208            writeln!(self.out)?;
209        }
210
211        // Write all entry points
212        for (index, ep) in module.entry_points.iter().enumerate() {
213            let attributes = match ep.stage {
214                ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
215                ShaderStage::Compute => vec![
216                    Attribute::Stage(ShaderStage::Compute),
217                    Attribute::WorkGroupSize(ep.workgroup_size),
218                ],
219                ShaderStage::Mesh => {
220                    let mesh_output_name = module.global_variables
221                        [ep.mesh_info.as_ref().unwrap().output_variable]
222                        .name
223                        .clone()
224                        .unwrap();
225                    let mut mesh_attrs = vec![
226                        Attribute::MeshStage(mesh_output_name),
227                        Attribute::WorkGroupSize(ep.workgroup_size),
228                    ];
229                    if let Some(task_payload) = ep.task_payload {
230                        let payload_name =
231                            module.global_variables[task_payload].name.clone().unwrap();
232                        mesh_attrs.push(Attribute::TaskPayload(payload_name));
233                    }
234                    mesh_attrs
235                }
236                ShaderStage::Task => {
237                    let payload_name = module.global_variables[ep.task_payload.unwrap()]
238                        .name
239                        .clone()
240                        .unwrap();
241                    vec![
242                        Attribute::Stage(ShaderStage::Task),
243                        Attribute::TaskPayload(payload_name),
244                        Attribute::WorkGroupSize(ep.workgroup_size),
245                    ]
246                }
247                ShaderStage::RayGeneration => vec![Attribute::Stage(ShaderStage::RayGeneration)],
248                ShaderStage::AnyHit | ShaderStage::ClosestHit | ShaderStage::Miss => {
249                    let payload_name = module.global_variables[ep.incoming_ray_payload.unwrap()]
250                        .name
251                        .clone()
252                        .unwrap();
253                    vec![
254                        Attribute::Stage(ep.stage),
255                        Attribute::IncomingRayPayload(payload_name),
256                    ]
257                }
258            };
259            self.write_attributes(&attributes)?;
260            // Add a newline after attribute
261            writeln!(self.out)?;
262
263            let func_ctx = back::FunctionCtx {
264                ty: back::FunctionType::EntryPoint(index as u16),
265                info: info.get_entry_point(index),
266                expressions: &ep.function.expressions,
267                named_expressions: &ep.function.named_expressions,
268            };
269            self.write_function(module, &ep.function, &func_ctx)?;
270
271            if index < module.entry_points.len() - 1 {
272                writeln!(self.out)?;
273            }
274        }
275
276        // Write any polyfills that were required.
277        for polyfill in &self.required_polyfills {
278            writeln!(self.out)?;
279            write!(self.out, "{}", polyfill.source)?;
280            writeln!(self.out)?;
281        }
282
283        Ok(())
284    }
285
286    /// Helper method which writes all the `enable` declarations
287    /// needed for a module.
288    fn write_enable_declarations(&mut self, module: &Module) -> BackendResult {
289        #[derive(Default)]
290        struct RequiredEnabled {
291            f16: bool,
292            dual_source_blending: bool,
293            clip_distances: bool,
294            mesh_shaders: bool,
295            primitive_index: bool,
296            cooperative_matrix: bool,
297            draw_index: bool,
298            ray_tracing_pipeline: bool,
299            per_vertex: bool,
300            binding_array: bool,
301        }
302        let mut needed = RequiredEnabled {
303            mesh_shaders: module.uses_mesh_shaders(),
304            ..Default::default()
305        };
306
307        let check_binding = |binding: &crate::Binding, needed: &mut RequiredEnabled| match *binding
308        {
309            crate::Binding::Location {
310                blend_src: Some(_), ..
311            } => {
312                needed.dual_source_blending = true;
313            }
314            crate::Binding::BuiltIn(crate::BuiltIn::ClipDistances) => {
315                needed.clip_distances = true;
316            }
317            crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveIndex) => {
318                needed.primitive_index = true;
319            }
320            crate::Binding::Location {
321                per_primitive: true,
322                ..
323            } => {
324                needed.mesh_shaders = true;
325            }
326            crate::Binding::Location {
327                interpolation: Some(crate::Interpolation::PerVertex),
328                ..
329            } => {
330                needed.per_vertex = true;
331            }
332            crate::Binding::BuiltIn(crate::BuiltIn::DrawIndex) => needed.draw_index = true,
333            crate::Binding::BuiltIn(
334                crate::BuiltIn::RayInvocationId
335                | crate::BuiltIn::NumRayInvocations
336                | crate::BuiltIn::InstanceCustomData
337                | crate::BuiltIn::GeometryIndex
338                | crate::BuiltIn::WorldRayOrigin
339                | crate::BuiltIn::WorldRayDirection
340                | crate::BuiltIn::ObjectRayOrigin
341                | crate::BuiltIn::ObjectRayDirection
342                | crate::BuiltIn::RayTmin
343                | crate::BuiltIn::RayTCurrentMax
344                | crate::BuiltIn::ObjectToWorld
345                | crate::BuiltIn::WorldToObject,
346            ) => {
347                needed.ray_tracing_pipeline = true;
348            }
349            _ => {}
350        };
351
352        // Determine which `enable` declarations are needed
353        for (_, ty) in module.types.iter() {
354            match ty.inner {
355                TypeInner::Scalar(scalar)
356                | TypeInner::Vector { scalar, .. }
357                | TypeInner::Matrix { scalar, .. } => {
358                    needed.f16 |= scalar == crate::Scalar::F16;
359                }
360                TypeInner::Struct { ref members, .. } => {
361                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
362                        check_binding(binding, &mut needed);
363                    }
364                }
365                TypeInner::CooperativeMatrix { .. } => {
366                    needed.cooperative_matrix = true;
367                }
368                TypeInner::AccelerationStructure { .. } => {
369                    needed.ray_tracing_pipeline = true;
370                }
371                TypeInner::BindingArray { .. } => {
372                    needed.binding_array = true;
373                }
374                _ => {}
375            }
376        }
377
378        for ep in &module.entry_points {
379            if let Some(res) = ep.function.result.as_ref().and_then(|a| a.binding.as_ref()) {
380                check_binding(res, &mut needed);
381            }
382            for arg in ep
383                .function
384                .arguments
385                .iter()
386                .filter_map(|a| a.binding.as_ref())
387            {
388                check_binding(arg, &mut needed);
389            }
390        }
391
392        if module.global_variables.iter().any(|gv| {
393            gv.1.space == crate::AddressSpace::IncomingRayPayload
394                || gv.1.space == crate::AddressSpace::RayPayload
395        }) {
396            needed.ray_tracing_pipeline = true;
397        }
398
399        if module.entry_points.iter().any(|ep| {
400            matches!(
401                ep.stage,
402                ShaderStage::RayGeneration
403                    | ShaderStage::AnyHit
404                    | ShaderStage::ClosestHit
405                    | ShaderStage::Miss
406            )
407        }) {
408            needed.ray_tracing_pipeline = true;
409        }
410
411        if module.global_variables.iter().any(|gv| {
412            gv.1.space == crate::AddressSpace::IncomingRayPayload
413                || gv.1.space == crate::AddressSpace::RayPayload
414        }) {
415            needed.ray_tracing_pipeline = true;
416        }
417
418        if module.entry_points.iter().any(|ep| {
419            matches!(
420                ep.stage,
421                ShaderStage::RayGeneration
422                    | ShaderStage::AnyHit
423                    | ShaderStage::ClosestHit
424                    | ShaderStage::Miss
425            )
426        }) {
427            needed.ray_tracing_pipeline = true;
428        }
429
430        // Write required declarations
431        let mut any_written = false;
432        if needed.f16 {
433            writeln!(self.out, "enable f16;")?;
434            any_written = true;
435        }
436        if needed.dual_source_blending {
437            writeln!(self.out, "enable dual_source_blending;")?;
438            any_written = true;
439        }
440        if needed.clip_distances {
441            writeln!(self.out, "enable clip_distances;")?;
442            any_written = true;
443        }
444        if module.uses_mesh_shaders() {
445            writeln!(self.out, "enable wgpu_mesh_shader;")?;
446            any_written = true;
447        }
448        if needed.binding_array {
449            writeln!(self.out, "enable wgpu_binding_array;")?;
450            any_written = true;
451        }
452        if needed.draw_index {
453            writeln!(self.out, "enable draw_index;")?;
454            any_written = true;
455        }
456        if needed.primitive_index {
457            writeln!(self.out, "enable primitive_index;")?;
458            any_written = true;
459        }
460        if needed.cooperative_matrix {
461            writeln!(self.out, "enable wgpu_cooperative_matrix;")?;
462            any_written = true;
463        }
464        if needed.ray_tracing_pipeline {
465            writeln!(self.out, "enable wgpu_ray_tracing_pipeline;")?;
466            any_written = true;
467        }
468        if needed.per_vertex {
469            writeln!(self.out, "enable wgpu_per_vertex;")?;
470            any_written = true;
471        }
472        if any_written {
473            // Empty line for readability
474            writeln!(self.out)?;
475        }
476
477        Ok(())
478    }
479
480    /// Helper method used to write
481    /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
482    ///
483    /// # Notes
484    /// Ends in a newline
485    fn write_function(
486        &mut self,
487        module: &Module,
488        func: &crate::Function,
489        func_ctx: &back::FunctionCtx<'_>,
490    ) -> BackendResult {
491        let func_name = match func_ctx.ty {
492            back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
493            back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
494        };
495
496        // Write function name
497        write!(self.out, "fn {func_name}(")?;
498
499        // Write function arguments
500        for (index, arg) in func.arguments.iter().enumerate() {
501            // Write argument attribute if a binding is present
502            if let Some(ref binding) = arg.binding {
503                self.write_attributes(&map_binding_to_attribute(binding))?;
504            }
505            // Write argument name
506            let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
507
508            write!(self.out, "{argument_name}: ")?;
509            // Write argument type
510            self.write_type(module, arg.ty)?;
511            if index < func.arguments.len() - 1 {
512                // Add a separator between args
513                write!(self.out, ", ")?;
514            }
515        }
516
517        write!(self.out, ")")?;
518
519        // Write function return type
520        if let Some(ref result) = func.result {
521            write!(self.out, " -> ")?;
522            if let Some(ref binding) = result.binding {
523                self.write_attributes(&map_binding_to_attribute(binding))?;
524            }
525            self.write_type(module, result.ty)?;
526        }
527
528        write!(self.out, " {{")?;
529        writeln!(self.out)?;
530
531        // Write function local variables
532        for (handle, local) in func.local_variables.iter() {
533            // Write indentation (only for readability)
534            write!(self.out, "{}", back::INDENT)?;
535
536            // Write the local name
537            // The leading space is important
538            write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
539
540            // Write the local type
541            self.write_type(module, local.ty)?;
542
543            // Write the local initializer if needed
544            if let Some(init) = local.init {
545                // Put the equal signal only if there's a initializer
546                // The leading and trailing spaces aren't needed but help with readability
547                write!(self.out, " = ")?;
548
549                // Write the constant
550                // `write_constant` adds no trailing or leading space/newline
551                self.write_expr(module, init, func_ctx)?;
552            }
553
554            // Finish the local with `;` and add a newline (only for readability)
555            writeln!(self.out, ";")?
556        }
557
558        if !func.local_variables.is_empty() {
559            writeln!(self.out)?;
560        }
561
562        // Write the function body (statement list)
563        for sta in func.body.iter() {
564            // The indentation should always be 1 when writing the function body
565            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
566        }
567
568        writeln!(self.out, "}}")?;
569
570        self.named_expressions.clear();
571
572        Ok(())
573    }
574
575    /// Helper method to write a attribute
576    fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
577        for attribute in attributes {
578            match *attribute {
579                Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
580                Attribute::BlendSrc(blend_src) => write!(self.out, "@blend_src({blend_src}) ")?,
581                Attribute::BuiltIn(builtin_attrib) => {
582                    let builtin = builtin_attrib.to_wgsl_if_implemented()?;
583                    write!(self.out, "@builtin({builtin}) ")?;
584                }
585                Attribute::Stage(shader_stage) => {
586                    let stage_str = match shader_stage {
587                        ShaderStage::Vertex => "vertex",
588                        ShaderStage::Fragment => "fragment",
589                        ShaderStage::Compute => "compute",
590                        ShaderStage::Task => "task",
591                        //Handled by another variant in the Attribute enum, so this code should never be hit.
592                        ShaderStage::Mesh => unreachable!(),
593                        ShaderStage::RayGeneration => "ray_generation",
594                        ShaderStage::AnyHit => "any_hit",
595                        ShaderStage::ClosestHit => "closest_hit",
596                        ShaderStage::Miss => "miss",
597                    };
598
599                    write!(self.out, "@{stage_str} ")?;
600                }
601                Attribute::WorkGroupSize(size) => {
602                    write!(
603                        self.out,
604                        "@workgroup_size({}, {}, {}) ",
605                        size[0], size[1], size[2]
606                    )?;
607                }
608                Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
609                Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
610                Attribute::Invariant => write!(self.out, "@invariant ")?,
611                Attribute::Interpolate(interpolation, sampling) => {
612                    if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
613                        let interpolation = interpolation
614                            .unwrap_or(crate::Interpolation::Perspective)
615                            .to_wgsl();
616                        let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
617                        write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
618                    } else if interpolation.is_some()
619                        && interpolation != Some(crate::Interpolation::Perspective)
620                    {
621                        let interpolation = interpolation
622                            .unwrap_or(crate::Interpolation::Perspective)
623                            .to_wgsl();
624                        write!(self.out, "@interpolate({interpolation}) ")?;
625                    }
626                }
627                Attribute::MeshStage(ref name) => {
628                    write!(self.out, "@mesh({name}) ")?;
629                }
630                Attribute::TaskPayload(ref payload_name) => {
631                    write!(self.out, "@payload({payload_name}) ")?;
632                }
633                Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?,
634                Attribute::IncomingRayPayload(ref payload_name) => {
635                    write!(self.out, "@incoming_payload({payload_name}) ")?;
636                }
637            };
638        }
639        Ok(())
640    }
641
642    /// Helper method used to write structs
643    /// Write the full declaration of a struct type.
644    ///
645    /// Write out a definition of the struct type referred to by
646    /// `handle` in `module`. The output will be an instance of the
647    /// `struct_decl` production in the WGSL grammar.
648    ///
649    /// Use `members` as the list of `handle`'s members. (This
650    /// function is usually called after matching a `TypeInner`, so
651    /// the callers already have the members at hand.)
652    fn write_struct(
653        &mut self,
654        module: &Module,
655        handle: Handle<crate::Type>,
656        members: &[crate::StructMember],
657    ) -> BackendResult {
658        write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?;
659        write!(self.out, " {{")?;
660        writeln!(self.out)?;
661        for (index, member) in members.iter().enumerate() {
662            // The indentation is only for readability
663            write!(self.out, "{}", back::INDENT)?;
664            if let Some(ref binding) = member.binding {
665                self.write_attributes(&map_binding_to_attribute(binding))?;
666            }
667            // Write struct member name and type
668            let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
669            write!(self.out, "{member_name}: ")?;
670            self.write_type(module, member.ty)?;
671            write!(self.out, ",")?;
672            writeln!(self.out)?;
673        }
674
675        writeln!(self.out, "}}")?;
676
677        Ok(())
678    }
679
680    fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
681        // This actually can't be factored out into a nice constructor method,
682        // because the borrow checker needs to be able to see that the borrows
683        // of `self.names` and `self.out` are disjoint.
684        let type_context = WriterTypeContext {
685            module,
686            names: &self.names,
687        };
688        type_context.write_type(ty, &mut self.out)?;
689
690        Ok(())
691    }
692
693    fn write_type_resolution(
694        &mut self,
695        module: &Module,
696        resolution: &proc::TypeResolution,
697    ) -> BackendResult {
698        // This actually can't be factored out into a nice constructor method,
699        // because the borrow checker needs to be able to see that the borrows
700        // of `self.names` and `self.out` are disjoint.
701        let type_context = WriterTypeContext {
702            module,
703            names: &self.names,
704        };
705        type_context.write_type_resolution(resolution, &mut self.out)?;
706
707        Ok(())
708    }
709
710    /// Helper method used to write statements
711    ///
712    /// # Notes
713    /// Always adds a newline
714    fn write_stmt(
715        &mut self,
716        module: &Module,
717        stmt: &crate::Statement,
718        func_ctx: &back::FunctionCtx<'_>,
719        level: back::Level,
720    ) -> BackendResult {
721        use crate::{Expression, Statement};
722
723        match *stmt {
724            Statement::Emit(ref range) => {
725                for handle in range.clone() {
726                    let info = &func_ctx.info[handle];
727                    let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
728                        // Front end provides names for all variables at the start of writing.
729                        // But we write them to step by step. We need to recache them
730                        // Otherwise, we could accidentally write variable name instead of full expression.
731                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
732                        Some(self.namer.call(name))
733                    } else {
734                        let expr = &func_ctx.expressions[handle];
735                        let min_ref_count = expr.bake_ref_count();
736                        // Forcefully creating baking expressions in some cases to help with readability
737                        let required_baking_expr = match *expr {
738                            Expression::ImageLoad { .. }
739                            | Expression::ImageQuery { .. }
740                            | Expression::ImageSample { .. } => true,
741                            _ => false,
742                        };
743                        if min_ref_count <= info.ref_count || required_baking_expr {
744                            Some(Baked(handle).to_string())
745                        } else {
746                            None
747                        }
748                    };
749
750                    if let Some(name) = expr_name {
751                        write!(self.out, "{level}")?;
752                        self.start_named_expr(module, handle, func_ctx, &name)?;
753                        self.write_expr(module, handle, func_ctx)?;
754                        self.named_expressions.insert(handle, name);
755                        writeln!(self.out, ";")?;
756                    }
757                }
758            }
759            // TODO: copy-paste from glsl-out
760            Statement::If {
761                condition,
762                ref accept,
763                ref reject,
764            } => {
765                write!(self.out, "{level}")?;
766                write!(self.out, "if ")?;
767                self.write_expr(module, condition, func_ctx)?;
768                writeln!(self.out, " {{")?;
769
770                let l2 = level.next();
771                for sta in accept {
772                    // Increase indentation to help with readability
773                    self.write_stmt(module, sta, func_ctx, l2)?;
774                }
775
776                // If there are no statements in the reject block we skip writing it
777                // This is only for readability
778                if !reject.is_empty() {
779                    writeln!(self.out, "{level}}} else {{")?;
780
781                    for sta in reject {
782                        // Increase indentation to help with readability
783                        self.write_stmt(module, sta, func_ctx, l2)?;
784                    }
785                }
786
787                writeln!(self.out, "{level}}}")?
788            }
789            Statement::Return { value } => {
790                write!(self.out, "{level}")?;
791                write!(self.out, "return")?;
792                if let Some(return_value) = value {
793                    // The leading space is important
794                    write!(self.out, " ")?;
795                    self.write_expr(module, return_value, func_ctx)?;
796                }
797                writeln!(self.out, ";")?;
798            }
799            // TODO: copy-paste from glsl-out
800            Statement::Kill => {
801                write!(self.out, "{level}")?;
802                writeln!(self.out, "discard;")?
803            }
804            Statement::Store { pointer, value } => {
805                write!(self.out, "{level}")?;
806
807                let is_atomic_pointer = func_ctx
808                    .resolve_type(pointer, &module.types)
809                    .is_atomic_pointer(&module.types);
810
811                if is_atomic_pointer {
812                    write!(self.out, "atomicStore(")?;
813                    self.write_expr(module, pointer, func_ctx)?;
814                    write!(self.out, ", ")?;
815                    self.write_expr(module, value, func_ctx)?;
816                    write!(self.out, ")")?;
817                } else {
818                    self.write_expr_with_indirection(
819                        module,
820                        pointer,
821                        func_ctx,
822                        Indirection::Reference,
823                    )?;
824                    write!(self.out, " = ")?;
825                    self.write_expr(module, value, func_ctx)?;
826                }
827                writeln!(self.out, ";")?
828            }
829            Statement::Call {
830                function,
831                ref arguments,
832                result,
833            } => {
834                write!(self.out, "{level}")?;
835                if let Some(expr) = result {
836                    let name = Baked(expr).to_string();
837                    self.start_named_expr(module, expr, func_ctx, &name)?;
838                    self.named_expressions.insert(expr, name);
839                }
840                let func_name = &self.names[&NameKey::Function(function)];
841                write!(self.out, "{func_name}(")?;
842                for (index, &argument) in arguments.iter().enumerate() {
843                    if index != 0 {
844                        write!(self.out, ", ")?;
845                    }
846                    self.write_expr(module, argument, func_ctx)?;
847                }
848                writeln!(self.out, ");")?
849            }
850            Statement::Atomic {
851                pointer,
852                ref fun,
853                value,
854                result,
855            } => {
856                write!(self.out, "{level}")?;
857                if let Some(result) = result {
858                    let res_name = Baked(result).to_string();
859                    self.start_named_expr(module, result, func_ctx, &res_name)?;
860                    self.named_expressions.insert(result, res_name);
861                }
862
863                let fun_str = fun.to_wgsl();
864                write!(self.out, "atomic{fun_str}(")?;
865                self.write_expr(module, pointer, func_ctx)?;
866                if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
867                    write!(self.out, ", ")?;
868                    self.write_expr(module, cmp, func_ctx)?;
869                }
870                write!(self.out, ", ")?;
871                self.write_expr(module, value, func_ctx)?;
872                writeln!(self.out, ");")?
873            }
874            Statement::ImageAtomic {
875                image,
876                coordinate,
877                array_index,
878                ref fun,
879                value,
880            } => {
881                write!(self.out, "{level}")?;
882                let fun_str = fun.to_wgsl();
883                write!(self.out, "textureAtomic{fun_str}(")?;
884                self.write_expr(module, image, func_ctx)?;
885                write!(self.out, ", ")?;
886                self.write_expr(module, coordinate, func_ctx)?;
887                if let Some(array_index_expr) = array_index {
888                    write!(self.out, ", ")?;
889                    self.write_expr(module, array_index_expr, func_ctx)?;
890                }
891                write!(self.out, ", ")?;
892                self.write_expr(module, value, func_ctx)?;
893                writeln!(self.out, ");")?;
894            }
895            Statement::WorkGroupUniformLoad { pointer, result } => {
896                write!(self.out, "{level}")?;
897                // TODO: Obey named expressions here.
898                let res_name = Baked(result).to_string();
899                self.start_named_expr(module, result, func_ctx, &res_name)?;
900                self.named_expressions.insert(result, res_name);
901                write!(self.out, "workgroupUniformLoad(")?;
902                self.write_expr(module, pointer, func_ctx)?;
903                writeln!(self.out, ");")?;
904            }
905            Statement::ImageStore {
906                image,
907                coordinate,
908                array_index,
909                value,
910            } => {
911                write!(self.out, "{level}")?;
912                write!(self.out, "textureStore(")?;
913                self.write_expr(module, image, func_ctx)?;
914                write!(self.out, ", ")?;
915                self.write_expr(module, coordinate, func_ctx)?;
916                if let Some(array_index_expr) = array_index {
917                    write!(self.out, ", ")?;
918                    self.write_expr(module, array_index_expr, func_ctx)?;
919                }
920                write!(self.out, ", ")?;
921                self.write_expr(module, value, func_ctx)?;
922                writeln!(self.out, ");")?;
923            }
924            // TODO: copy-paste from glsl-out
925            Statement::Block(ref block) => {
926                write!(self.out, "{level}")?;
927                writeln!(self.out, "{{")?;
928                for sta in block.iter() {
929                    // Increase the indentation to help with readability
930                    self.write_stmt(module, sta, func_ctx, level.next())?
931                }
932                writeln!(self.out, "{level}}}")?
933            }
934            Statement::Switch {
935                selector,
936                ref cases,
937            } => {
938                // Start the switch
939                write!(self.out, "{level}")?;
940                write!(self.out, "switch ")?;
941                self.write_expr(module, selector, func_ctx)?;
942                writeln!(self.out, " {{")?;
943
944                let l2 = level.next();
945                let mut new_case = true;
946                for case in cases {
947                    if case.fall_through && !case.body.is_empty() {
948                        // TODO: we could do the same workaround as we did for the HLSL backend
949                        return Err(Error::Unimplemented(
950                            "fall-through switch case block".into(),
951                        ));
952                    }
953
954                    match case.value {
955                        crate::SwitchValue::I32(value) => {
956                            if new_case {
957                                write!(self.out, "{l2}case ")?;
958                            }
959                            write!(self.out, "{value}")?;
960                        }
961                        crate::SwitchValue::U32(value) => {
962                            if new_case {
963                                write!(self.out, "{l2}case ")?;
964                            }
965                            write!(self.out, "{value}u")?;
966                        }
967                        crate::SwitchValue::Default => {
968                            if new_case {
969                                if case.fall_through {
970                                    write!(self.out, "{l2}case ")?;
971                                } else {
972                                    write!(self.out, "{l2}")?;
973                                }
974                            }
975                            write!(self.out, "default")?;
976                        }
977                    }
978
979                    new_case = !case.fall_through;
980
981                    if case.fall_through {
982                        write!(self.out, ", ")?;
983                    } else {
984                        writeln!(self.out, ": {{")?;
985                    }
986
987                    for sta in case.body.iter() {
988                        self.write_stmt(module, sta, func_ctx, l2.next())?;
989                    }
990
991                    if !case.fall_through {
992                        writeln!(self.out, "{l2}}}")?;
993                    }
994                }
995
996                writeln!(self.out, "{level}}}")?
997            }
998            Statement::Loop {
999                ref body,
1000                ref continuing,
1001                break_if,
1002            } => {
1003                write!(self.out, "{level}")?;
1004                writeln!(self.out, "loop {{")?;
1005
1006                let l2 = level.next();
1007                for sta in body.iter() {
1008                    self.write_stmt(module, sta, func_ctx, l2)?;
1009                }
1010
1011                // The continuing is optional so we don't need to write it if
1012                // it is empty, but the `break if` counts as a continuing statement
1013                // so even if `continuing` is empty we must generate it if a
1014                // `break if` exists
1015                if !continuing.is_empty() || break_if.is_some() {
1016                    writeln!(self.out, "{l2}continuing {{")?;
1017                    for sta in continuing.iter() {
1018                        self.write_stmt(module, sta, func_ctx, l2.next())?;
1019                    }
1020
1021                    // The `break if` is always the last
1022                    // statement of the `continuing` block
1023                    if let Some(condition) = break_if {
1024                        // The trailing space is important
1025                        write!(self.out, "{}break if ", l2.next())?;
1026                        self.write_expr(module, condition, func_ctx)?;
1027                        // Close the `break if` statement
1028                        writeln!(self.out, ";")?;
1029                    }
1030
1031                    writeln!(self.out, "{l2}}}")?;
1032                }
1033
1034                writeln!(self.out, "{level}}}")?
1035            }
1036            Statement::Break => {
1037                writeln!(self.out, "{level}break;")?;
1038            }
1039            Statement::Continue => {
1040                writeln!(self.out, "{level}continue;")?;
1041            }
1042            Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
1043                if barrier.contains(crate::Barrier::STORAGE) {
1044                    writeln!(self.out, "{level}storageBarrier();")?;
1045                }
1046
1047                if barrier.contains(crate::Barrier::WORK_GROUP) {
1048                    writeln!(self.out, "{level}workgroupBarrier();")?;
1049                }
1050
1051                if barrier.contains(crate::Barrier::SUB_GROUP) {
1052                    writeln!(self.out, "{level}subgroupBarrier();")?;
1053                }
1054
1055                if barrier.contains(crate::Barrier::TEXTURE) {
1056                    writeln!(self.out, "{level}textureBarrier();")?;
1057                }
1058            }
1059            Statement::RayQuery { .. } => unreachable!(),
1060            Statement::SubgroupBallot { result, predicate } => {
1061                write!(self.out, "{level}")?;
1062                let res_name = Baked(result).to_string();
1063                self.start_named_expr(module, result, func_ctx, &res_name)?;
1064                self.named_expressions.insert(result, res_name);
1065
1066                write!(self.out, "subgroupBallot(")?;
1067                if let Some(predicate) = predicate {
1068                    self.write_expr(module, predicate, func_ctx)?;
1069                }
1070                writeln!(self.out, ");")?;
1071            }
1072            Statement::SubgroupCollectiveOperation {
1073                op,
1074                collective_op,
1075                argument,
1076                result,
1077            } => {
1078                write!(self.out, "{level}")?;
1079                let res_name = Baked(result).to_string();
1080                self.start_named_expr(module, result, func_ctx, &res_name)?;
1081                self.named_expressions.insert(result, res_name);
1082
1083                match (collective_op, op) {
1084                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
1085                        write!(self.out, "subgroupAll(")?
1086                    }
1087                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
1088                        write!(self.out, "subgroupAny(")?
1089                    }
1090                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
1091                        write!(self.out, "subgroupAdd(")?
1092                    }
1093                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
1094                        write!(self.out, "subgroupMul(")?
1095                    }
1096                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
1097                        write!(self.out, "subgroupMax(")?
1098                    }
1099                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
1100                        write!(self.out, "subgroupMin(")?
1101                    }
1102                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
1103                        write!(self.out, "subgroupAnd(")?
1104                    }
1105                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
1106                        write!(self.out, "subgroupOr(")?
1107                    }
1108                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
1109                        write!(self.out, "subgroupXor(")?
1110                    }
1111                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
1112                        write!(self.out, "subgroupExclusiveAdd(")?
1113                    }
1114                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
1115                        write!(self.out, "subgroupExclusiveMul(")?
1116                    }
1117                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
1118                        write!(self.out, "subgroupInclusiveAdd(")?
1119                    }
1120                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
1121                        write!(self.out, "subgroupInclusiveMul(")?
1122                    }
1123                    _ => unimplemented!(),
1124                }
1125                self.write_expr(module, argument, func_ctx)?;
1126                writeln!(self.out, ");")?;
1127            }
1128            Statement::SubgroupGather {
1129                mode,
1130                argument,
1131                result,
1132            } => {
1133                write!(self.out, "{level}")?;
1134                let res_name = Baked(result).to_string();
1135                self.start_named_expr(module, result, func_ctx, &res_name)?;
1136                self.named_expressions.insert(result, res_name);
1137
1138                match mode {
1139                    crate::GatherMode::BroadcastFirst => {
1140                        write!(self.out, "subgroupBroadcastFirst(")?;
1141                    }
1142                    crate::GatherMode::Broadcast(_) => {
1143                        write!(self.out, "subgroupBroadcast(")?;
1144                    }
1145                    crate::GatherMode::Shuffle(_) => {
1146                        write!(self.out, "subgroupShuffle(")?;
1147                    }
1148                    crate::GatherMode::ShuffleDown(_) => {
1149                        write!(self.out, "subgroupShuffleDown(")?;
1150                    }
1151                    crate::GatherMode::ShuffleUp(_) => {
1152                        write!(self.out, "subgroupShuffleUp(")?;
1153                    }
1154                    crate::GatherMode::ShuffleXor(_) => {
1155                        write!(self.out, "subgroupShuffleXor(")?;
1156                    }
1157                    crate::GatherMode::QuadBroadcast(_) => {
1158                        write!(self.out, "quadBroadcast(")?;
1159                    }
1160                    crate::GatherMode::QuadSwap(direction) => match direction {
1161                        crate::Direction::X => {
1162                            write!(self.out, "quadSwapX(")?;
1163                        }
1164                        crate::Direction::Y => {
1165                            write!(self.out, "quadSwapY(")?;
1166                        }
1167                        crate::Direction::Diagonal => {
1168                            write!(self.out, "quadSwapDiagonal(")?;
1169                        }
1170                    },
1171                }
1172                self.write_expr(module, argument, func_ctx)?;
1173                match mode {
1174                    crate::GatherMode::BroadcastFirst => {}
1175                    crate::GatherMode::Broadcast(index)
1176                    | crate::GatherMode::Shuffle(index)
1177                    | crate::GatherMode::ShuffleDown(index)
1178                    | crate::GatherMode::ShuffleUp(index)
1179                    | crate::GatherMode::ShuffleXor(index)
1180                    | crate::GatherMode::QuadBroadcast(index) => {
1181                        write!(self.out, ", ")?;
1182                        self.write_expr(module, index, func_ctx)?;
1183                    }
1184                    crate::GatherMode::QuadSwap(_) => {}
1185                }
1186                writeln!(self.out, ");")?;
1187            }
1188            Statement::CooperativeStore { target, ref data } => {
1189                let suffix = if data.row_major { "T" } else { "" };
1190                write!(self.out, "{level}coopStore{suffix}(")?;
1191                self.write_expr(module, target, func_ctx)?;
1192                write!(self.out, ", ")?;
1193                self.write_expr(module, data.pointer, func_ctx)?;
1194                write!(self.out, ", ")?;
1195                self.write_expr(module, data.stride, func_ctx)?;
1196                writeln!(self.out, ");")?
1197            }
1198            Statement::RayPipelineFunction(fun) => match fun {
1199                crate::RayPipelineFunction::TraceRay {
1200                    acceleration_structure,
1201                    descriptor,
1202                    payload,
1203                } => {
1204                    write!(self.out, "{level}traceRay(")?;
1205                    self.write_expr(module, acceleration_structure, func_ctx)?;
1206                    write!(self.out, ", ")?;
1207                    self.write_expr(module, descriptor, func_ctx)?;
1208                    write!(self.out, ", ")?;
1209                    self.write_expr(module, payload, func_ctx)?;
1210                    writeln!(self.out, ");")?
1211                }
1212            },
1213        }
1214
1215        Ok(())
1216    }
1217
1218    /// Return the sort of indirection that `expr`'s plain form evaluates to.
1219    ///
1220    /// An expression's 'plain form' is the most general rendition of that
1221    /// expression into WGSL, lacking `&` or `*` operators:
1222    ///
1223    /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
1224    ///   to the local variable's storage.
1225    ///
1226    /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
1227    ///   reference to the global variable's storage. However, globals in the
1228    ///   `Handle` address space are immutable, and `GlobalVariable` expressions for
1229    ///   those produce the value directly, not a pointer to it. Such
1230    ///   `GlobalVariable` expressions are `Ordinary`.
1231    ///
1232    /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
1233    ///   pointer. If they are applied directly to a composite value, they are
1234    ///   `Ordinary`.
1235    ///
1236    /// Note that `FunctionArgument` expressions are never `Reference`, even when
1237    /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
1238    /// argument's value directly, so any pointer it produces is merely the value
1239    /// passed by the caller.
1240    fn plain_form_indirection(
1241        &self,
1242        expr: Handle<crate::Expression>,
1243        module: &Module,
1244        func_ctx: &back::FunctionCtx<'_>,
1245    ) -> Indirection {
1246        use crate::Expression as Ex;
1247
1248        // Named expressions are `let` expressions, which apply the Load Rule,
1249        // so if their type is a Naga pointer, then that must be a WGSL pointer
1250        // as well.
1251        if self.named_expressions.contains_key(&expr) {
1252            return Indirection::Ordinary;
1253        }
1254
1255        match func_ctx.expressions[expr] {
1256            Ex::LocalVariable(_) => Indirection::Reference,
1257            Ex::GlobalVariable(handle) => {
1258                let global = &module.global_variables[handle];
1259                match global.space {
1260                    crate::AddressSpace::Handle => Indirection::Ordinary,
1261                    _ => Indirection::Reference,
1262                }
1263            }
1264            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1265                let base_ty = func_ctx.resolve_type(base, &module.types);
1266                match *base_ty {
1267                    TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
1268                        Indirection::Reference
1269                    }
1270                    _ => Indirection::Ordinary,
1271                }
1272            }
1273            _ => Indirection::Ordinary,
1274        }
1275    }
1276
1277    fn start_named_expr(
1278        &mut self,
1279        module: &Module,
1280        handle: Handle<crate::Expression>,
1281        func_ctx: &back::FunctionCtx,
1282        name: &str,
1283    ) -> BackendResult {
1284        // Write variable name
1285        write!(self.out, "let {name}")?;
1286        if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
1287            write!(self.out, ": ")?;
1288            // Write variable type
1289            self.write_type_resolution(module, &func_ctx.info[handle].ty)?;
1290        }
1291
1292        write!(self.out, " = ")?;
1293        Ok(())
1294    }
1295
1296    /// Write the ordinary WGSL form of `expr`.
1297    ///
1298    /// See `write_expr_with_indirection` for details.
1299    fn write_expr(
1300        &mut self,
1301        module: &Module,
1302        expr: Handle<crate::Expression>,
1303        func_ctx: &back::FunctionCtx<'_>,
1304    ) -> BackendResult {
1305        self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
1306    }
1307
1308    /// Write `expr` as a WGSL expression with the requested indirection.
1309    ///
1310    /// In terms of the WGSL grammar, the resulting expression is a
1311    /// `singular_expression`. It may be parenthesized. This makes it suitable
1312    /// for use as the operand of a unary or binary operator without worrying
1313    /// about precedence.
1314    ///
1315    /// This does not produce newlines or indentation.
1316    ///
1317    /// The `requested` argument indicates (roughly) whether Naga
1318    /// `Pointer`-valued expressions represent WGSL references or pointers. See
1319    /// `Indirection` for details.
1320    fn write_expr_with_indirection(
1321        &mut self,
1322        module: &Module,
1323        expr: Handle<crate::Expression>,
1324        func_ctx: &back::FunctionCtx<'_>,
1325        requested: Indirection,
1326    ) -> BackendResult {
1327        // If the plain form of the expression is not what we need, emit the
1328        // operator necessary to correct that.
1329        let plain = self.plain_form_indirection(expr, module, func_ctx);
1330        log::trace!(
1331            "expression {:?}={:?} is {:?}, expected {:?}",
1332            expr,
1333            func_ctx.expressions[expr],
1334            plain,
1335            requested,
1336        );
1337        match (requested, plain) {
1338            (Indirection::Ordinary, Indirection::Reference) => {
1339                write!(self.out, "(&")?;
1340                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1341                write!(self.out, ")")?;
1342            }
1343            (Indirection::Reference, Indirection::Ordinary) => {
1344                write!(self.out, "(*")?;
1345                self.write_expr_plain_form(module, expr, func_ctx, plain)?;
1346                write!(self.out, ")")?;
1347            }
1348            (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
1349        }
1350
1351        Ok(())
1352    }
1353
1354    fn write_const_expression(
1355        &mut self,
1356        module: &Module,
1357        expr: Handle<crate::Expression>,
1358        arena: &crate::Arena<crate::Expression>,
1359    ) -> BackendResult {
1360        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
1361            writer.write_const_expression(module, expr, arena)
1362        })
1363    }
1364
1365    fn write_possibly_const_expression<E>(
1366        &mut self,
1367        module: &Module,
1368        expr: Handle<crate::Expression>,
1369        expressions: &crate::Arena<crate::Expression>,
1370        write_expression: E,
1371    ) -> BackendResult
1372    where
1373        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
1374    {
1375        use crate::Expression;
1376
1377        match expressions[expr] {
1378            Expression::Literal(literal) => match literal {
1379                crate::Literal::F16(value) => write!(self.out, "{value}h")?,
1380                crate::Literal::F32(value) => write!(self.out, "{value}f")?,
1381                crate::Literal::U32(value) => write!(self.out, "{value}u")?,
1382                crate::Literal::I32(value) => {
1383                    // `-2147483648i` is not valid WGSL. The most negative `i32`
1384                    // value can only be expressed in WGSL using AbstractInt and
1385                    // a unary negation operator.
1386                    if value == i32::MIN {
1387                        write!(self.out, "i32({value})")?;
1388                    } else {
1389                        write!(self.out, "{value}i")?;
1390                    }
1391                }
1392                crate::Literal::Bool(value) => write!(self.out, "{value}")?,
1393                crate::Literal::F64(value) => write!(self.out, "{value:?}lf")?,
1394                crate::Literal::I64(value) => {
1395                    // `-9223372036854775808li` is not valid WGSL. Nor can we simply use the
1396                    // AbstractInt trick above, as AbstractInt also cannot represent
1397                    // `9223372036854775808`. Instead construct the second most negative
1398                    // AbstractInt, subtract one from it, then cast to i64.
1399                    if value == i64::MIN {
1400                        write!(self.out, "i64({} - 1)", value + 1)?;
1401                    } else {
1402                        write!(self.out, "{value}li")?;
1403                    }
1404                }
1405                crate::Literal::U64(value) => write!(self.out, "{value:?}lu")?,
1406                crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1407                    return Err(Error::Custom(
1408                        "Abstract types should not appear in IR presented to backends".into(),
1409                    ));
1410                }
1411            },
1412            Expression::Constant(handle) => {
1413                let constant = &module.constants[handle];
1414                if constant.name.is_some() {
1415                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1416                } else {
1417                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
1418                }
1419            }
1420            Expression::ZeroValue(ty) => {
1421                self.write_type(module, ty)?;
1422                write!(self.out, "()")?;
1423            }
1424            Expression::Compose { ty, ref components } => {
1425                self.write_type(module, ty)?;
1426                write!(self.out, "(")?;
1427                for (index, component) in components.iter().enumerate() {
1428                    if index != 0 {
1429                        write!(self.out, ", ")?;
1430                    }
1431                    write_expression(self, *component)?;
1432                }
1433                write!(self.out, ")")?
1434            }
1435            Expression::Splat { size, value } => {
1436                let size = common::vector_size_str(size);
1437                write!(self.out, "vec{size}(")?;
1438                write_expression(self, value)?;
1439                write!(self.out, ")")?;
1440            }
1441            Expression::Override(handle) => {
1442                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
1443            }
1444            _ => unreachable!(),
1445        }
1446
1447        Ok(())
1448    }
1449
    /// Write the 'plain form' of `expr`.
    ///
    /// An expression's 'plain form' is the most general rendition of that
    /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
    /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
    /// Naga expressions represent both WGSL pointers and references; it's the
    /// caller's responsibility to distinguish those cases appropriately.
    ///
    /// # Parameters
    /// - `module`: the module owning the types, constants and overrides that
    ///   `expr` may refer to.
    /// - `expr`: the expression to write, a handle into `func_ctx.expressions`.
    /// - `func_ctx`: supplies the expression arena, per-expression type info
    ///   (`info`, `resolve_type`) and argument/local-variable name keys.
    /// - `indirection`: only used for the `base` operand of `Access` and
    ///   `AccessIndex`, where it is forwarded to `write_expr_with_indirection`.
    ///
    /// # Errors
    /// Returns an error if writing to the output fails, or if `expr` has no
    /// WGSL rendering in this backend: indexing an unsupported type, an `As`
    /// on a non-scalar/vector/matrix type, an unsupported math function, or a
    /// relational function other than `all`/`any`.
    ///
    /// # Panics
    /// `RayQueryGetIntersection` and `RayQueryVertexPositions` hit
    /// `unreachable!()`. The struct case of `AccessIndex` and the
    /// `CooperativeLoad` case `unwrap` type information; this presumably relies
    /// on the module having been validated (TODO confirm).
    fn write_expr_plain_form(
        &mut self,
        module: &Module,
        expr: Handle<crate::Expression>,
        func_ctx: &back::FunctionCtx<'_>,
        indirection: Indirection,
    ) -> BackendResult {
        use crate::Expression;

        // An expression that already has an entry in `named_expressions` is
        // referred to by that name instead of being written out again.
        if let Some(name) = self.named_expressions.get(&expr) {
            write!(self.out, "{name}")?;
            return Ok(());
        }

        let expression = &func_ctx.expressions[expr];

        // Write the plain WGSL form of a Naga expression.
        //
        // The plain form of `LocalVariable` and `GlobalVariable` expressions is
        // simply the variable name; `*` and `&` operators are never emitted.
        //
        // The plain form of `Access` and `AccessIndex` expressions are WGSL
        // `postfix_expression` forms for member/component access and
        // subscripting.
        match *expression {
            // Constant-like expressions go through the shared
            // `write_possibly_const_expression` path; any sub-expressions are
            // written back through `write_expr` via the closure.
            Expression::Literal(_)
            | Expression::Constant(_)
            | Expression::ZeroValue(_)
            | Expression::Compose { .. }
            | Expression::Splat { .. } => {
                self.write_possibly_const_expression(
                    module,
                    expr,
                    func_ctx.expressions,
                    |writer, expr| writer.write_expr(module, expr, func_ctx),
                )?;
            }
            Expression::Override(handle) => {
                write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
            }
            Expression::FunctionArgument(pos) => {
                let name_key = func_ctx.argument_key(pos);
                let name = &self.names[&name_key];
                write!(self.out, "{name}")?;
            }
            // Binary operations are always wrapped in parentheses, so the
            // emitted text does not depend on precedence at the use site.
            Expression::Binary { op, left, right } => {
                write!(self.out, "(")?;
                self.write_expr(module, left, func_ctx)?;
                write!(self.out, " {} ", back::binary_operation_str(op))?;
                self.write_expr(module, right, func_ctx)?;
                write!(self.out, ")")?;
            }
            // Dynamic subscript `base[index]`; only `base` honours the
            // caller's `indirection`.
            Expression::Access { base, index } => {
                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
                write!(self.out, "[")?;
                self.write_expr(module, index, func_ctx)?;
                write!(self.out, "]")?
            }
            Expression::AccessIndex { base, index } => {
                let base_ty_res = &func_ctx.info[base].ty;
                let mut resolved = base_ty_res.inner_with(&module.types);

                self.write_expr_with_indirection(module, base, func_ctx, indirection)?;

                // If `base` is a pointer, choose the access syntax from the
                // pointee type, and remember the pointee's handle for struct
                // member name lookup.
                let base_ty_handle = match *resolved {
                    TypeInner::Pointer { base, space: _ } => {
                        resolved = &module.types[base].inner;
                        Some(base)
                    }
                    _ => base_ty_res.handle(),
                };

                match *resolved {
                    TypeInner::Vector { .. } => {
                        // Write vector access as a swizzle
                        write!(self.out, ".{}", back::COMPONENTS[index as usize])?
                    }
                    TypeInner::Matrix { .. }
                    | TypeInner::Array { .. }
                    | TypeInner::BindingArray { .. }
                    | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
                    TypeInner::Struct { .. } => {
                        // This will never panic in case the type is a `Struct`, this is not true
                        // for other types so we can only check while inside this match arm
                        let ty = base_ty_handle.unwrap();

                        write!(
                            self.out,
                            ".{}",
                            &self.names[&NameKey::StructMember(ty, index)]
                        )?
                    }
                    ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
                }
            }
            Expression::ImageSample {
                image,
                sampler,
                gather: None,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => {
                use crate::SampleLevel as Sl;

                // The builtin name is `textureSample` + an optional `Compare`
                // suffix (when there is a depth reference) + a suffix chosen
                // by the sample level.
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };
                let suffix_level = match level {
                    Sl::Auto => "",
                    Sl::Zero if clamp_to_edge => "BaseClampToEdge",
                    Sl::Zero | Sl::Exact(_) => "Level",
                    Sl::Bias(_) => "Bias",
                    Sl::Gradient { .. } => "Grad",
                };

                // Argument order: image, sampler, coordinate, then (each only
                // if present) array index, depth reference, level arguments,
                // and offset.
                write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                match level {
                    Sl::Auto => {}
                    Sl::Zero => {
                        // Level 0 is implied for depth comparison and BaseClampToEdge
                        if depth_ref.is_none() && !clamp_to_edge {
                            write!(self.out, ", 0.0")?;
                        }
                    }
                    Sl::Exact(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Bias(expr) => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, expr, func_ctx)?;
                    }
                    Sl::Gradient { x, y } => {
                        write!(self.out, ", ")?;
                        self.write_expr(module, x, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, y, func_ctx)?;
                    }
                }

                // The offset is written as a constant expression, not via
                // `write_expr`.
                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }

            Expression::ImageSample {
                image,
                sampler,
                gather: Some(component),
                coordinate,
                array_index,
                offset,
                level: _,
                depth_ref,
                clamp_to_edge: _,
            } => {
                let suffix_cmp = match depth_ref {
                    Some(_) => "Compare",
                    None => "",
                };

                write!(self.out, "textureGather{suffix_cmp}(")?;
                // The component index is the first argument, but it is only
                // emitted for non-depth images.
                match *func_ctx.resolve_type(image, &module.types) {
                    TypeInner::Image {
                        class: crate::ImageClass::Depth { multi: _ },
                        ..
                    } => {}
                    _ => {
                        write!(self.out, "{}, ", component as u8)?;
                    }
                }
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, sampler, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;

                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }

                if let Some(depth_ref) = depth_ref {
                    write!(self.out, ", ")?;
                    self.write_expr(module, depth_ref, func_ctx)?;
                }

                if let Some(offset) = offset {
                    write!(self.out, ", ")?;
                    self.write_const_expression(module, offset, func_ctx.expressions)?;
                }

                write!(self.out, ")")?;
            }
            Expression::ImageQuery { image, query } => {
                use crate::ImageQuery as Iq;

                let texture_function = match query {
                    Iq::Size { .. } => "textureDimensions",
                    Iq::NumLevels => "textureNumLevels",
                    Iq::NumLayers => "textureNumLayers",
                    Iq::NumSamples => "textureNumSamples",
                };

                write!(self.out, "{texture_function}(")?;
                self.write_expr(module, image, func_ctx)?;
                // Only a size query may carry an explicit mip level argument.
                if let Iq::Size { level: Some(level) } = query {
                    write!(self.out, ", ")?;
                    self.write_expr(module, level, func_ctx)?;
                };
                write!(self.out, ")")?;
            }

            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                write!(self.out, "textureLoad(")?;
                self.write_expr(module, image, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, coordinate, func_ctx)?;
                if let Some(array_index) = array_index {
                    write!(self.out, ", ")?;
                    self.write_expr(module, array_index, func_ctx)?;
                }
                // Whichever of `sample` / `level` is set becomes the last
                // argument; `sample` wins if both are somehow present.
                if let Some(index) = sample.or(level) {
                    write!(self.out, ", ")?;
                    self.write_expr(module, index, func_ctx)?;
                }
                write!(self.out, ")")?;
            }
            Expression::GlobalVariable(handle) => {
                let name = &self.names[&NameKey::GlobalVariable(handle)];
                write!(self.out, "{name}")?;
            }

            // First emit the callee: a type constructor when `convert` is
            // `Some` (value conversion, optionally changing width), otherwise
            // `bitcast<T>`. Matrices always use the constructor form here.
            // The target scalar keeps the source width unless `convert`
            // supplies a new one. The operand follows in parentheses.
            Expression::As {
                expr,
                kind,
                convert,
            } => {
                let inner = func_ctx.resolve_type(expr, &module.types);
                match *inner {
                    TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(scalar.width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        write!(
                            self.out,
                            "mat{}x{}<{}>",
                            common::vector_size_str(columns),
                            common::vector_size_str(rows),
                            scalar_kind_str
                        )?;
                    }
                    TypeInner::Vector {
                        size,
                        scalar: crate::Scalar { width, .. },
                    } => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let vector_size_str = common::vector_size_str(size);
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
                        } else {
                            write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
                        }
                    }
                    TypeInner::Scalar(crate::Scalar { width, .. }) => {
                        let scalar = crate::Scalar {
                            kind,
                            width: convert.unwrap_or(width),
                        };
                        let scalar_kind_str = scalar.to_wgsl_if_implemented()?;
                        if convert.is_some() {
                            write!(self.out, "{scalar_kind_str}")?
                        } else {
                            write!(self.out, "bitcast<{scalar_kind_str}>")?
                        }
                    }
                    _ => {
                        return Err(Error::Unimplemented(format!(
                            "write_expr expression::as {inner:?}"
                        )));
                    }
                };
                write!(self.out, "(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }
            // Atomic pointers are read with `atomicLoad(ptr)`; any other
            // pointer is loaded by writing it in reference form.
            Expression::Load { pointer } => {
                let is_atomic_pointer = func_ctx
                    .resolve_type(pointer, &module.types)
                    .is_atomic_pointer(&module.types);

                if is_atomic_pointer {
                    write!(self.out, "atomicLoad(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                    write!(self.out, ")")?;
                } else {
                    self.write_expr_with_indirection(
                        module,
                        pointer,
                        func_ctx,
                        Indirection::Reference,
                    )?;
                }
            }
            Expression::LocalVariable(handle) => {
                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
            }
            Expression::ArrayLength(expr) => {
                write!(self.out, "arrayLength(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?;
            }

            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                use crate::MathFunction as Mf;

                // How the math function will be emitted: as a direct WGSL
                // builtin call, or as a call to a polyfill function.
                enum Function {
                    Regular(&'static str),
                    InversePolyfill(InversePolyfill),
                }

                // Functions without a WGSL name (`try_to_wgsl` is `None`) are
                // unsupported, except `Inverse` when a polyfill overload
                // exists for the argument type.
                let function = match fun.try_to_wgsl() {
                    Some(name) => Function::Regular(name),
                    None => match fun {
                        Mf::Inverse => {
                            let ty = func_ctx.resolve_type(arg, &module.types);
                            let Some(overload) = InversePolyfill::find_overload(ty) else {
                                return Err(Error::unsupported("math function", fun));
                            };

                            Function::InversePolyfill(overload)
                        }
                        _ => return Err(Error::unsupported("math function", fun)),
                    },
                };

                match function {
                    Function::Regular(fun_name) => {
                        write!(self.out, "{fun_name}(")?;
                        self.write_expr(module, arg, func_ctx)?;
                        // Append whichever optional extra arguments are present.
                        for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
                            write!(self.out, ", ")?;
                            self.write_expr(module, arg, func_ctx)?;
                        }
                        write!(self.out, ")")?
                    }
                    Function::InversePolyfill(inverse) => {
                        write!(self.out, "{}(", inverse.fun_name)?;
                        self.write_expr(module, arg, func_ctx)?;
                        write!(self.out, ")")?;
                        // Record the overload so its definition can be emitted;
                        // presumably done elsewhere from `required_polyfills`
                        // (TODO confirm).
                        self.required_polyfills.insert(inverse);
                    }
                }
            }

            // `vector.` followed by one component letter per selected lane.
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                self.write_expr(module, vector, func_ctx)?;
                write!(self.out, ".")?;
                for &sc in pattern[..size as usize].iter() {
                    self.out.write_char(back::COMPONENTS[sc as usize])?;
                }
            }
            // The operand is always parenthesized after the operator.
            Expression::Unary { op, expr } => {
                let unary = match op {
                    crate::UnaryOperator::Negate => "-",
                    crate::UnaryOperator::LogicalNot => "!",
                    crate::UnaryOperator::BitwiseNot => "~",
                };

                write!(self.out, "{unary}(")?;
                self.write_expr(module, expr, func_ctx)?;

                write!(self.out, ")")?
            }

            // Arguments are emitted as (reject, accept, condition), i.e. the
            // value for a false condition comes first.
            Expression::Select {
                condition,
                accept,
                reject,
            } => {
                write!(self.out, "select(")?;
                self.write_expr(module, reject, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, accept, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, condition, func_ctx)?;
                write!(self.out, ")")?
            }
            Expression::Derivative { axis, ctrl, expr } => {
                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
                let op = match (axis, ctrl) {
                    (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
                    (Axis::X, Ctrl::Fine) => "dpdxFine",
                    (Axis::X, Ctrl::None) => "dpdx",
                    (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
                    (Axis::Y, Ctrl::Fine) => "dpdyFine",
                    (Axis::Y, Ctrl::None) => "dpdy",
                    (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
                    (Axis::Width, Ctrl::Fine) => "fwidthFine",
                    (Axis::Width, Ctrl::None) => "fwidth",
                };
                write!(self.out, "{op}(")?;
                self.write_expr(module, expr, func_ctx)?;
                write!(self.out, ")")?
            }
            // Only `all` and `any` are written; any other relational
            // function is reported as an error.
            Expression::Relational { fun, argument } => {
                use crate::RelationalFunction as Rf;

                let fun_name = match fun {
                    Rf::All => "all",
                    Rf::Any => "any",
                    _ => return Err(Error::UnsupportedRelationalFunction(fun)),
                };
                write!(self.out, "{fun_name}(")?;

                self.write_expr(module, argument, func_ctx)?;

                write!(self.out, ")")?
            }
            // Not supported yet
            Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => unreachable!(),
            // Nothing to do here, since call expression already cached
            Expression::CallResult(_)
            | Expression::AtomicResult { .. }
            | Expression::RayQueryProceedResult
            | Expression::SubgroupBallotResult
            | Expression::SubgroupOperationResult { .. }
            | Expression::WorkGroupUniformLoadResult { .. } => {}
            // `coopLoad` / `coopLoadT` (when `data.row_major`), with the matrix
            // element scalar taken from the pointee of `data.pointer`. The
            // `unwrap`s assume that pointee has a scalar type (TODO confirm
            // this is guaranteed by validation).
            Expression::CooperativeLoad {
                columns,
                rows,
                role,
                ref data,
            } => {
                let suffix = if data.row_major { "T" } else { "" };
                let scalar = func_ctx.info[data.pointer]
                    .ty
                    .inner_with(&module.types)
                    .pointer_base_type()
                    .unwrap()
                    .inner_with(&module.types)
                    .scalar()
                    .unwrap();
                write!(
                    self.out,
                    "coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
                    columns as u32,
                    rows as u32,
                    scalar.try_to_wgsl().unwrap(),
                    role,
                )?;
                self.write_expr(module, data.pointer, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, data.stride, func_ctx)?;
                write!(self.out, ")")?;
            }
            Expression::CooperativeMultiplyAdd { a, b, c } => {
                write!(self.out, "coopMultiplyAdd(")?;
                self.write_expr(module, a, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, b, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, c, func_ctx)?;
                write!(self.out, ")")?;
            }
        }

        Ok(())
    }
1977
1978    /// Helper method used to write global variables
1979    /// # Notes
1980    /// Always adds a newline
1981    fn write_global(
1982        &mut self,
1983        module: &Module,
1984        global: &crate::GlobalVariable,
1985        handle: Handle<crate::GlobalVariable>,
1986    ) -> BackendResult {
1987        // Write group and binding attributes if present
1988        if let Some(ref binding) = global.binding {
1989            self.write_attributes(&[
1990                Attribute::Group(binding.group),
1991                Attribute::Binding(binding.binding),
1992            ])?;
1993            writeln!(self.out)?;
1994        }
1995
1996        if global
1997            .memory_decorations
1998            .contains(crate::MemoryDecorations::COHERENT)
1999        {
2000            write!(self.out, "@coherent ")?;
2001        }
2002        if global
2003            .memory_decorations
2004            .contains(crate::MemoryDecorations::VOLATILE)
2005        {
2006            write!(self.out, "@volatile ")?;
2007        }
2008
2009        // First write global name and address space if supported
2010        write!(self.out, "var")?;
2011        let (address, maybe_access) = address_space_str(global.space);
2012        if let Some(space) = address {
2013            write!(self.out, "<{space}")?;
2014            if let Some(access) = maybe_access {
2015                write!(self.out, ", {access}")?;
2016            }
2017            write!(self.out, ">")?;
2018        }
2019        write!(
2020            self.out,
2021            " {}: ",
2022            &self.names[&NameKey::GlobalVariable(handle)]
2023        )?;
2024
2025        // Write global type
2026        self.write_type(module, global.ty)?;
2027
2028        // Write initializer
2029        if let Some(init) = global.init {
2030            write!(self.out, " = ")?;
2031            self.write_const_expression(module, init, &module.global_expressions)?;
2032        }
2033
2034        // End with semicolon
2035        writeln!(self.out, ";")?;
2036
2037        Ok(())
2038    }
2039
2040    /// Helper method used to write global constants
2041    ///
2042    /// # Notes
2043    /// Ends in a newline
2044    fn write_global_constant(
2045        &mut self,
2046        module: &Module,
2047        handle: Handle<crate::Constant>,
2048    ) -> BackendResult {
2049        let name = &self.names[&NameKey::Constant(handle)];
2050        // First write only constant name
2051        write!(self.out, "const {name}: ")?;
2052        self.write_type(module, module.constants[handle].ty)?;
2053        write!(self.out, " = ")?;
2054        let init = module.constants[handle].init;
2055        self.write_const_expression(module, init, &module.global_expressions)?;
2056        writeln!(self.out, ";")?;
2057
2058        Ok(())
2059    }
2060
2061    /// Helper method used to write overrides
2062    ///
2063    /// # Notes
2064    /// Ends in a newline
2065    fn write_override(
2066        &mut self,
2067        module: &Module,
2068        handle: Handle<crate::Override>,
2069    ) -> BackendResult {
2070        let override_ = &module.overrides[handle];
2071        let name = &self.names[&NameKey::Override(handle)];
2072
2073        // Write @id attribute if present
2074        if let Some(id) = override_.id {
2075            write!(self.out, "@id({id}) ")?;
2076        }
2077
2078        // Write override declaration
2079        write!(self.out, "override {name}: ")?;
2080        self.write_type(module, override_.ty)?;
2081
2082        // Write initializer if present
2083        if let Some(init) = override_.init {
2084            write!(self.out, " = ")?;
2085            self.write_const_expression(module, init, &module.global_expressions)?;
2086        }
2087
2088        writeln!(self.out, ";")?;
2089
2090        Ok(())
2091    }
2092
    /// Consume the writer and return its output sink `out`, which holds
    /// everything written so far.
    // See https://github.com/rust-lang/rust-clippy/issues/4979.
    pub fn finish(self) -> W {
        self.out
    }
2097}
2098
2099struct WriterTypeContext<'m> {
2100    module: &'m Module,
2101    names: &'m crate::FastHashMap<NameKey, String>,
2102}
2103
2104impl TypeContext for WriterTypeContext<'_> {
2105    fn lookup_type(&self, handle: Handle<crate::Type>) -> &crate::Type {
2106        &self.module.types[handle]
2107    }
2108
2109    fn type_name(&self, handle: Handle<crate::Type>) -> &str {
2110        self.names[&NameKey::Type(handle)].as_str()
2111    }
2112
2113    fn write_unnamed_struct<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2114        unreachable!("the WGSL back end should always provide type handles");
2115    }
2116
2117    fn write_override<W: Write>(
2118        &self,
2119        handle: Handle<crate::Override>,
2120        out: &mut W,
2121    ) -> core::fmt::Result {
2122        write!(out, "{}", self.names[&NameKey::Override(handle)])
2123    }
2124
2125    fn write_non_wgsl_inner<W: Write>(&self, _: &TypeInner, _: &mut W) -> core::fmt::Result {
2126        unreachable!("backends should only be passed validated modules");
2127    }
2128
2129    fn write_non_wgsl_scalar<W: Write>(&self, _: crate::Scalar, _: &mut W) -> core::fmt::Result {
2130        unreachable!("backends should only be passed validated modules");
2131    }
2132}
2133
2134fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
2135    match *binding {
2136        crate::Binding::BuiltIn(built_in) => {
2137            if let crate::BuiltIn::Position { invariant: true } = built_in {
2138                vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
2139            } else {
2140                vec![Attribute::BuiltIn(built_in)]
2141            }
2142        }
2143        crate::Binding::Location {
2144            location,
2145            interpolation,
2146            sampling,
2147            blend_src: None,
2148            per_primitive,
2149        } => {
2150            let mut attrs = vec![
2151                Attribute::Location(location),
2152                Attribute::Interpolate(interpolation, sampling),
2153            ];
2154            if per_primitive {
2155                attrs.push(Attribute::PerPrimitive);
2156            }
2157            attrs
2158        }
2159        crate::Binding::Location {
2160            location,
2161            interpolation,
2162            sampling,
2163            blend_src: Some(blend_src),
2164            per_primitive,
2165        } => {
2166            let mut attrs = vec![
2167                Attribute::Location(location),
2168                Attribute::BlendSrc(blend_src),
2169                Attribute::Interpolate(interpolation, sampling),
2170            ];
2171            if per_primitive {
2172                attrs.push(Attribute::PerPrimitive);
2173            }
2174            attrs
2175        }
2176    }
2177}