naga/back/hlsl/
help.rs

/*!
Helpers for the hlsl backend

Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and the hlsl backend:

Because the hlsl `GetDimensions` function (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>)
returns its results through `out` parameters, the backend can't use it as an expression.
Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function.
See the `WrappedImageQuery` struct, which represents one unique function; these functions are generated before any statements and expressions are written.
This allows `Expression::ImageQuery` to be written as an ordinary expression: a call to the wrapped function.

For example:
```wgsl
let dim_1d = textureDimensions(image_1d);
```

```hlsl
uint NagaDimensions1D(Texture1D<float4> tex)
{
    uint4 ret;
    tex.GetDimensions(0, ret.x, ret.y);
    return ret.x;
}

uint dim_1d = NagaDimensions1D(image_1d);
```
*/
28
29use alloc::format;
30use core::fmt::Write;
31
32use super::{
33    super::FunctionCtx,
34    writer::{
35        ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, F2I32_FUNCTION, F2I64_FUNCTION,
36        F2U32_FUNCTION, F2U64_FUNCTION, IMAGE_LOAD_EXTERNAL_FUNCTION,
37        IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION, NEG_FUNCTION,
38    },
39    BackendResult, WrappedType,
40};
41use crate::{arena::Handle, proc::NameKey, ScalarKind};
42
43#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
44pub(super) struct WrappedArrayLength {
45    pub(super) writable: bool,
46}
47
48#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
49pub(super) struct WrappedImageLoad {
50    pub(super) class: crate::ImageClass,
51}
52
53#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
54pub(super) struct WrappedImageSample {
55    pub(super) class: crate::ImageClass,
56    pub(super) clamp_to_edge: bool,
57}
58
59#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
60pub(super) struct WrappedImageQuery {
61    pub(super) dim: crate::ImageDimension,
62    pub(super) arrayed: bool,
63    pub(super) class: crate::ImageClass,
64    pub(super) query: ImageQuery,
65}
66
67#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
68pub(super) struct WrappedConstructor {
69    pub(super) ty: Handle<crate::Type>,
70}
71
72#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
73pub(super) struct WrappedStructMatrixAccess {
74    pub(super) ty: Handle<crate::Type>,
75    pub(super) index: u32,
76}
77
78#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
79pub(super) struct WrappedMatCx2 {
80    pub(super) columns: crate::VectorSize,
81}
82
83#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
84pub(super) struct WrappedMath {
85    pub(super) fun: crate::MathFunction,
86    pub(super) scalar: crate::Scalar,
87    pub(super) components: Option<u32>,
88}
89
90#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
91pub(super) struct WrappedZeroValue {
92    pub(super) ty: Handle<crate::Type>,
93}
94
95#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
96pub(super) struct WrappedUnaryOp {
97    pub(super) op: crate::UnaryOperator,
98    // This can only represent scalar or vector types. If we ever need to wrap
99    // unary ops with other types, we'll need a better representation.
100    pub(super) ty: (Option<crate::VectorSize>, crate::Scalar),
101}
102
103#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
104pub(super) struct WrappedBinaryOp {
105    pub(super) op: crate::BinaryOperator,
106    // This can only represent scalar or vector types. If we ever need to wrap
107    // binary ops with other types, we'll need a better representation.
108    pub(super) left_ty: (Option<crate::VectorSize>, crate::Scalar),
109    pub(super) right_ty: (Option<crate::VectorSize>, crate::Scalar),
110}
111
112#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
113pub(super) struct WrappedCast {
114    // This can only represent scalar or vector types. If we ever need to wrap
115    // casts with other types, we'll need a better representation.
116    pub(super) vector_size: Option<crate::VectorSize>,
117    pub(super) src_scalar: crate::Scalar,
118    pub(super) dst_scalar: crate::Scalar,
119}
120
121/// HLSL backend requires its own `ImageQuery` enum.
122///
123/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
124/// IR version can't be unique per function, because it's store mipmap level as an expression.
125///
126/// For example:
127/// ```wgsl
128/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
129/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
130/// ```
131///
132/// ```ir
133/// ImageQuery {
134///  image: [1],
135///  query: Size {
136///      level: Some(
137///          [1],
138///      ),
139///  },
140/// },
141/// ImageQuery {
142///  image: [1],
143///  query: Size {
144///      level: Some(
145///          [2],
146///      ),
147///  },
148/// },
149/// ```
150///
151/// HLSL should generate only 1 function for this case.
152#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
153pub(super) enum ImageQuery {
154    Size,
155    SizeLevel,
156    NumLevels,
157    NumLayers,
158    NumSamples,
159}
160
161impl From<crate::ImageQuery> for ImageQuery {
162    fn from(q: crate::ImageQuery) -> Self {
163        use crate::ImageQuery as Iq;
164        match q {
165            Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
166            Iq::Size { level: None } => ImageQuery::Size,
167            Iq::NumLevels => ImageQuery::NumLevels,
168            Iq::NumLayers => ImageQuery::NumLayers,
169            Iq::NumSamples => ImageQuery::NumSamples,
170        }
171    }
172}
173
174pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";
175
176impl<W: Write> super::Writer<'_, W> {
177    pub(super) fn write_image_type(
178        &mut self,
179        dim: crate::ImageDimension,
180        arrayed: bool,
181        class: crate::ImageClass,
182    ) -> BackendResult {
183        let access_str = match class {
184            crate::ImageClass::Storage { .. } => "RW",
185            _ => "",
186        };
187        let dim_str = dim.to_hlsl_str();
188        let arrayed_str = if arrayed { "Array" } else { "" };
189        write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
190        match class {
191            crate::ImageClass::Depth { multi } => {
192                let multi_str = if multi { "MS" } else { "" };
193                write!(self.out, "{multi_str}<float>")?
194            }
195            crate::ImageClass::Sampled { kind, multi } => {
196                let multi_str = if multi { "MS" } else { "" };
197                let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
198                write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
199            }
200            crate::ImageClass::Storage { format, .. } => {
201                let storage_format_str = format.to_hlsl_str();
202                write!(self.out, "<{storage_format_str}>")?
203            }
204            crate::ImageClass::External => {
205                unreachable!(
206                    "external images should be handled by `write_global_external_texture`"
207                );
208            }
209        }
210        Ok(())
211    }
212
213    pub(super) fn write_wrapped_array_length_function_name(
214        &mut self,
215        query: WrappedArrayLength,
216    ) -> BackendResult {
217        let access_str = if query.writable { "RW" } else { "" };
218        write!(self.out, "NagaBufferLength{access_str}",)?;
219
220        Ok(())
221    }
222
223    /// Helper function that write wrapped function for `Expression::ArrayLength`
224    ///
225    /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
226    pub(super) fn write_wrapped_array_length_function(
227        &mut self,
228        wal: WrappedArrayLength,
229    ) -> BackendResult {
230        use crate::back::INDENT;
231
232        const ARGUMENT_VARIABLE_NAME: &str = "buffer";
233        const RETURN_VARIABLE_NAME: &str = "ret";
234
235        // Write function return type and name
236        write!(self.out, "uint ")?;
237        self.write_wrapped_array_length_function_name(wal)?;
238
239        // Write function parameters
240        write!(self.out, "(")?;
241        let access_str = if wal.writable { "RW" } else { "" };
242        writeln!(
243            self.out,
244            "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
245        )?;
246        // Write function body
247        writeln!(self.out, "{{")?;
248
249        // Write `GetDimensions` function.
250        writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
251        writeln!(
252            self.out,
253            "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
254        )?;
255
256        // Write return value
257        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
258
259        // End of function body
260        writeln!(self.out, "}}")?;
261        // Write extra new line
262        writeln!(self.out)?;
263
264        Ok(())
265    }
266
267    /// Helper function used by [`Self::write_wrapped_image_load_function`] and
268    /// [`Self::write_wrapped_image_sample_function`] to write the shared YUV
269    /// to RGB conversion code for external textures. Expects the preceding
270    /// code to declare the Y component as a `float` variable of name `y`, the
271    /// UV components as a `float2` variable of name `uv`, and the external
272    /// texture params as a variable of name `params`. The emitted code will
273    /// return the result.
274    fn write_convert_yuv_to_rgb_and_return(
275        &mut self,
276        level: crate::back::Level,
277        y: &str,
278        uv: &str,
279        params: &str,
280    ) -> BackendResult {
281        let l1 = level;
282        let l2 = l1.next();
283
284        // Convert from YUV to non-linear RGB in the source color space. We
285        // declare our matrices as row_major in HLSL, therefore we must reverse
286        // the order of this multiplication
287        writeln!(
288            self.out,
289            "{l1}float3 srcGammaRgb = mul(float4({y}, {uv}, 1.0), {params}.yuv_conversion_matrix).rgb;"
290        )?;
291
292        // Apply the inverse of the source transfer function to convert to
293        // linear RGB in the source color space.
294        writeln!(
295            self.out,
296            "{l1}float3 srcLinearRgb = srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b ?"
297        )?;
298        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k :")?;
299        writeln!(self.out, "{l2}pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g);")?;
300
301        // Multiply by the gamut conversion matrix to convert to linear RGB in
302        // the destination color space. We declare our matrices as row_major in
303        // HLSL, therefore we must reverse the order of this multiplication.
304        writeln!(
305            self.out,
306            "{l1}float3 dstLinearRgb = mul(srcLinearRgb, {params}.gamut_conversion_matrix);"
307        )?;
308
309        // Finally, apply the dest transfer function to convert to non-linear
310        // RGB in the destination color space, and return the result.
311        writeln!(
312            self.out,
313            "{l1}float3 dstGammaRgb = dstLinearRgb < {params}.dst_tf.b ?"
314        )?;
315        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb :")?;
316        writeln!(self.out, "{l2}{params}.dst_tf.a * pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1);")?;
317
318        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
319        Ok(())
320    }
321
322    pub(super) fn write_wrapped_image_load_function(
323        &mut self,
324        module: &crate::Module,
325        load: WrappedImageLoad,
326    ) -> BackendResult {
327        match load {
328            WrappedImageLoad {
329                class: crate::ImageClass::External,
330            } => {
331                let l1 = crate::back::Level(1);
332                let l2 = l1.next();
333                let l3 = l2.next();
334                let params_ty_name = &self.names
335                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
336                writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
337                writeln!(self.out, "{l1}Texture2D<float4> plane0,")?;
338                writeln!(self.out, "{l1}Texture2D<float4> plane1,")?;
339                writeln!(self.out, "{l1}Texture2D<float4> plane2,")?;
340                writeln!(self.out, "{l1}{params_ty_name} params,")?;
341                writeln!(self.out, "{l1}uint2 coords)")?;
342                writeln!(self.out, "{{")?;
343                writeln!(self.out, "{l1}uint2 plane0_size;")?;
344                writeln!(
345                    self.out,
346                    "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);"
347                )?;
348                // Clamp coords to provided size of external texture to prevent OOB read.
349                // If params.size is zero then clamp to the actual size of the texture.
350                writeln!(
351                    self.out,
352                    "{l1}uint2 cropped_size = any(params.size) ? params.size : plane0_size;"
353                )?;
354                writeln!(self.out, "{l1}coords = min(coords, cropped_size - 1);")?;
355
356                // Apply load transformation. We declare our matrices as row_major in
357                // HLSL, therefore we must reverse the order of this multiplication
358                writeln!(self.out, "{l1}float3x2 load_transform = float3x2(")?;
359                writeln!(self.out, "{l2}params.load_transform_0,")?;
360                writeln!(self.out, "{l2}params.load_transform_1,")?;
361                writeln!(self.out, "{l2}params.load_transform_2")?;
362                writeln!(self.out, "{l1});")?;
363                writeln!(self.out, "{l1}uint2 plane0_coords = uint2(round(mul(float3(coords, 1.0), load_transform)));")?;
364                writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?;
365                // For single plane, simply read from plane0
366                writeln!(
367                    self.out,
368                    "{l2}return plane0.Load(uint3(plane0_coords, 0u));"
369                )?;
370                writeln!(self.out, "{l1}}} else {{")?;
371
372                // Chroma planes may be subsampled so we must scale the coords accordingly.
373                writeln!(self.out, "{l2}uint2 plane1_size;")?;
374                writeln!(
375                    self.out,
376                    "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);"
377                )?;
378                writeln!(self.out, "{l2}uint2 plane1_coords = uint2(floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
379
380                // For multi-plane, read the Y value from plane 0
381                writeln!(
382                    self.out,
383                    "{l2}float y = plane0.Load(uint3(plane0_coords, 0u)).x;"
384                )?;
385
386                writeln!(self.out, "{l2}float2 uv;")?;
387                writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?;
388                // Read UV from interleaved plane 1
389                writeln!(
390                    self.out,
391                    "{l3}uv = plane1.Load(uint3(plane1_coords, 0u)).xy;"
392                )?;
393                writeln!(self.out, "{l2}}} else {{")?;
394                // Read U and V from planes 1 and 2 respectively
395                writeln!(self.out, "{l3}uint2 plane2_size;")?;
396                writeln!(
397                    self.out,
398                    "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);"
399                )?;
400                writeln!(self.out, "{l3}uint2 plane2_coords = uint2(floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
401                writeln!(self.out, "{l3}uv = float2(plane1.Load(uint3(plane1_coords, 0u)).x, plane2.Load(uint3(plane2_coords, 0u)).x);")?;
402                writeln!(self.out, "{l2}}}")?;
403
404                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?;
405
406                writeln!(self.out, "{l1}}}")?;
407                writeln!(self.out, "}}")?;
408                writeln!(self.out)?;
409            }
410            _ => {}
411        }
412
413        Ok(())
414    }
415
416    pub(super) fn write_wrapped_image_sample_function(
417        &mut self,
418        module: &crate::Module,
419        sample: WrappedImageSample,
420    ) -> BackendResult {
421        match sample {
422            WrappedImageSample {
423                class: crate::ImageClass::External,
424                clamp_to_edge: true,
425            } => {
426                let l1 = crate::back::Level(1);
427                let l2 = l1.next();
428                let l3 = l2.next();
429                let params_ty_name = &self.names
430                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
431                writeln!(
432                    self.out,
433                    "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}("
434                )?;
435                writeln!(self.out, "{l1}Texture2D<float4> plane0,")?;
436                writeln!(self.out, "{l1}Texture2D<float4> plane1,")?;
437                writeln!(self.out, "{l1}Texture2D<float4> plane2,")?;
438                writeln!(self.out, "{l1}{params_ty_name} params,")?;
439                writeln!(self.out, "{l1}SamplerState samp,")?;
440                writeln!(self.out, "{l1}float2 coords)")?;
441                writeln!(self.out, "{{")?;
442                writeln!(self.out, "{l1}float2 plane0_size;")?;
443                writeln!(
444                    self.out,
445                    "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);"
446                )?;
447                writeln!(self.out, "{l1}float3x2 sample_transform = float3x2(")?;
448                writeln!(self.out, "{l2}params.sample_transform_0,")?;
449                writeln!(self.out, "{l2}params.sample_transform_1,")?;
450                writeln!(self.out, "{l2}params.sample_transform_2")?;
451                writeln!(self.out, "{l1});")?;
452                // Apply sample transformation. We declare our matrices as row_major in
453                // HLSL, therefore we must reverse the order of this multiplication
454                writeln!(
455                    self.out,
456                    "{l1}coords = mul(float3(coords, 1.0), sample_transform);"
457                )?;
458                // Calculate the sample bounds. The purported size of the texture
459                // (params.size) is irrelevant here as we are dealing with normalized
460                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
461                // apply the sample transformation to that, also bearing in mind that it
462                // may contain a flip on either axis. We calculate and adjust for the
463                // half-texel separately for each plane as it depends on the actual
464                // texture size which may vary between planes.
465                writeln!(
466                    self.out,
467                    "{l1}float2 bounds_min = mul(float3(0.0, 0.0, 1.0), sample_transform);"
468                )?;
469                writeln!(
470                    self.out,
471                    "{l1}float2 bounds_max = mul(float3(1.0, 1.0, 1.0), sample_transform);"
472                )?;
473                writeln!(self.out, "{l1}float4 bounds = float4(min(bounds_min, bounds_max), max(bounds_min, bounds_max));")?;
474                writeln!(
475                    self.out,
476                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / plane0_size;"
477                )?;
478                writeln!(
479                    self.out,
480                    "{l1}float2 plane0_coords = clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
481                )?;
482                writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?;
483                // For single plane, simply sample from plane0
484                writeln!(
485                    self.out,
486                    "{l2}return plane0.SampleLevel(samp, plane0_coords, 0.0f);"
487                )?;
488                writeln!(self.out, "{l1}}} else {{")?;
489
490                writeln!(self.out, "{l2}float2 plane1_size;")?;
491                writeln!(
492                    self.out,
493                    "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);"
494                )?;
495                writeln!(
496                    self.out,
497                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / plane1_size;"
498                )?;
499                writeln!(
500                    self.out,
501                    "{l2}float2 plane1_coords = clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
502                )?;
503
504                // For multi-plane, sample the Y value from plane 0
505                writeln!(
506                    self.out,
507                    "{l2}float y = plane0.SampleLevel(samp, plane0_coords, 0.0f).x;"
508                )?;
509                writeln!(self.out, "{l2}float2 uv;")?;
510                writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?;
511                // Sample UV from interleaved plane 1
512                writeln!(
513                    self.out,
514                    "{l3}uv = plane1.SampleLevel(samp, plane1_coords, 0.0f).xy;"
515                )?;
516                writeln!(self.out, "{l2}}} else {{")?;
517                // Sample U and V from planes 1 and 2 respectively
518                writeln!(self.out, "{l3}float2 plane2_size;")?;
519                writeln!(
520                    self.out,
521                    "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);"
522                )?;
523                writeln!(
524                    self.out,
525                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / plane2_size;"
526                )?;
527                writeln!(self.out, "{l3}float2 plane2_coords = clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane2_half_texel);")?;
528                writeln!(self.out, "{l3}uv = float2(plane1.SampleLevel(samp, plane1_coords, 0.0f).x, plane2.SampleLevel(samp, plane2_coords, 0.0f).x);")?;
529                writeln!(self.out, "{l2}}}")?;
530
531                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?;
532
533                writeln!(self.out, "{l1}}}")?;
534                writeln!(self.out, "}}")?;
535                writeln!(self.out)?;
536            }
537            WrappedImageSample {
538                class:
539                    crate::ImageClass::Sampled {
540                        kind: ScalarKind::Float,
541                        multi: false,
542                    },
543                clamp_to_edge: true,
544            } => {
545                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(Texture2D<float4> tex, SamplerState samp, float2 coords) {{")?;
546                let l1 = crate::back::Level(1);
547                writeln!(self.out, "{l1}float2 size;")?;
548                writeln!(self.out, "{l1}tex.GetDimensions(size.x, size.y);")?;
549                writeln!(self.out, "{l1}float2 half_texel = float2(0.5, 0.5) / size;")?;
550                writeln!(
551                    self.out,
552                    "{l1}return tex.SampleLevel(samp, clamp(coords, half_texel, 1.0 - half_texel), 0.0);"
553                )?;
554                writeln!(self.out, "}}")?;
555                writeln!(self.out)?;
556            }
557            _ => {}
558        }
559
560        Ok(())
561    }
562
563    pub(super) fn write_wrapped_image_query_function_name(
564        &mut self,
565        query: WrappedImageQuery,
566    ) -> BackendResult {
567        let dim_str = query.dim.to_hlsl_str();
568        let class_str = match query.class {
569            crate::ImageClass::Sampled { multi: true, .. } => "MS",
570            crate::ImageClass::Depth { multi: true } => "DepthMS",
571            crate::ImageClass::Depth { multi: false } => "Depth",
572            crate::ImageClass::Sampled { multi: false, .. } => "",
573            crate::ImageClass::Storage { .. } => "RW",
574            crate::ImageClass::External => "External",
575        };
576        let arrayed_str = if query.arrayed { "Array" } else { "" };
577        let query_str = match query.query {
578            ImageQuery::Size => "Dimensions",
579            ImageQuery::SizeLevel => "MipDimensions",
580            ImageQuery::NumLevels => "NumLevels",
581            ImageQuery::NumLayers => "NumLayers",
582            ImageQuery::NumSamples => "NumSamples",
583        };
584
585        write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
586
587        Ok(())
588    }
589
590    /// Helper function that write wrapped function for `Expression::ImageQuery`
591    ///
592    /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
593    pub(super) fn write_wrapped_image_query_function(
594        &mut self,
595        module: &crate::Module,
596        wiq: WrappedImageQuery,
597        expr_handle: Handle<crate::Expression>,
598        func_ctx: &FunctionCtx,
599    ) -> BackendResult {
600        use crate::{
601            back::{COMPONENTS, INDENT},
602            ImageDimension as IDim,
603        };
604
605        match wiq.class {
606            crate::ImageClass::External => {
607                if wiq.query != ImageQuery::Size {
608                    return Err(super::Error::Custom(
609                        "External images only support `Size` queries".into(),
610                    ));
611                }
612
613                write!(self.out, "uint2 ")?;
614                self.write_wrapped_image_query_function_name(wiq)?;
615                let params_name = &self.names
616                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
617                // Only plane0 and params are used by this implementation, but it's easier to
618                // always take all of them as arguments so that we can unconditionally expand an
619                // external texture expression each of its parts.
620                writeln!(self.out, "(Texture2D<float4> plane0, Texture2D<float4> plane1, Texture2D<float4> plane2, {params_name} params) {{")?;
621                let l1 = crate::back::Level(1);
622                let l2 = l1.next();
623                writeln!(self.out, "{l1}if (any(params.size)) {{")?;
624                writeln!(self.out, "{l2}return params.size;")?;
625                writeln!(self.out, "{l1}}} else {{")?;
626                // params.size == (0, 0) indicates to query and return plane 0's actual size
627                writeln!(self.out, "{l2}uint2 ret;")?;
628                writeln!(self.out, "{l2}plane0.GetDimensions(ret.x, ret.y);")?;
629                writeln!(self.out, "{l2}return ret;")?;
630                writeln!(self.out, "{l1}}}")?;
631                writeln!(self.out, "}}")?;
632                writeln!(self.out)?;
633            }
634            _ => {
635                const ARGUMENT_VARIABLE_NAME: &str = "tex";
636                const RETURN_VARIABLE_NAME: &str = "ret";
637                const MIP_LEVEL_PARAM: &str = "mip_level";
638
639                // Write function return type and name
640                let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
641                self.write_value_type(module, ret_ty)?;
642                write!(self.out, " ")?;
643                self.write_wrapped_image_query_function_name(wiq)?;
644
645                // Write function parameters
646                write!(self.out, "(")?;
647                // Texture always first parameter
648                self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
649                write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
650                // Mipmap is a second parameter if exists
651                if let ImageQuery::SizeLevel = wiq.query {
652                    write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
653                }
654                writeln!(self.out, ")")?;
655
656                // Write function body
657                writeln!(self.out, "{{")?;
658
659                let array_coords = usize::from(wiq.arrayed);
660                // extra parameter is the mip level count or the sample count
661                let extra_coords = match wiq.class {
662                    crate::ImageClass::Storage { .. } => 0,
663                    crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
664                    crate::ImageClass::External => unreachable!(),
665                };
666
667                // GetDimensions Overloaded Methods
668                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
669                let (ret_swizzle, number_of_params) = match wiq.query {
670                    ImageQuery::Size | ImageQuery::SizeLevel => {
671                        let ret = match wiq.dim {
672                            IDim::D1 => "x",
673                            IDim::D2 => "xy",
674                            IDim::D3 => "xyz",
675                            IDim::Cube => "xy",
676                        };
677                        (ret, ret.len() + array_coords + extra_coords)
678                    }
679                    ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
680                        if wiq.arrayed || wiq.dim == IDim::D3 {
681                            ("w", 4)
682                        } else {
683                            ("z", 3)
684                        }
685                    }
686                };
687
688                // Write `GetDimensions` function.
689                writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
690                write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
691                match wiq.query {
692                    ImageQuery::SizeLevel => {
693                        write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
694                    }
695                    _ => match wiq.class {
696                        crate::ImageClass::Sampled { multi: true, .. }
697                        | crate::ImageClass::Depth { multi: true }
698                        | crate::ImageClass::Storage { .. } => {}
699                        _ => {
700                            // Write zero mipmap level for supported types
701                            write!(self.out, "0, ")?;
702                        }
703                    },
704                }
705
706                for component in COMPONENTS[..number_of_params - 1].iter() {
707                    write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
708                }
709
710                // write last parameter without comma and space for last parameter
711                write!(
712                    self.out,
713                    "{}.{}",
714                    RETURN_VARIABLE_NAME,
715                    COMPONENTS[number_of_params - 1]
716                )?;
717
718                writeln!(self.out, ");")?;
719
720                // Write return value
721                writeln!(
722                    self.out,
723                    "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
724                )?;
725
726                // End of function body
727                writeln!(self.out, "}}")?;
728                // Write extra new line
729                writeln!(self.out)?;
730            }
731        }
732        Ok(())
733    }
734
735    pub(super) fn write_wrapped_constructor_function_name(
736        &mut self,
737        module: &crate::Module,
738        constructor: WrappedConstructor,
739    ) -> BackendResult {
740        let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
741        write!(self.out, "Construct{name}")?;
742        Ok(())
743    }
744
745    /// Helper function that write wrapped function for `Expression::Compose` for structures.
746    fn write_wrapped_constructor_function(
747        &mut self,
748        module: &crate::Module,
749        constructor: WrappedConstructor,
750    ) -> BackendResult {
751        use crate::back::INDENT;
752
753        const ARGUMENT_VARIABLE_NAME: &str = "arg";
754        const RETURN_VARIABLE_NAME: &str = "ret";
755
756        // Write function return type and name
757        if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
758            write!(self.out, "typedef ")?;
759            self.write_type(module, constructor.ty)?;
760            write!(self.out, " ret_")?;
761            self.write_wrapped_constructor_function_name(module, constructor)?;
762            self.write_array_size(module, base, size)?;
763            writeln!(self.out, ";")?;
764
765            write!(self.out, "ret_")?;
766            self.write_wrapped_constructor_function_name(module, constructor)?;
767        } else {
768            self.write_type(module, constructor.ty)?;
769        }
770        write!(self.out, " ")?;
771        self.write_wrapped_constructor_function_name(module, constructor)?;
772
773        // Write function parameters
774        write!(self.out, "(")?;
775
776        let mut write_arg = |i, ty| -> BackendResult {
777            if i != 0 {
778                write!(self.out, ", ")?;
779            }
780            self.write_type(module, ty)?;
781            write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
782            if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
783                self.write_array_size(module, base, size)?;
784            }
785            Ok(())
786        };
787
788        match module.types[constructor.ty].inner {
789            crate::TypeInner::Struct { ref members, .. } => {
790                for (i, member) in members.iter().enumerate() {
791                    write_arg(i, member.ty)?;
792                }
793            }
794            crate::TypeInner::Array {
795                base,
796                size: crate::ArraySize::Constant(size),
797                ..
798            } => {
799                for i in 0..size.get() as usize {
800                    write_arg(i, base)?;
801                }
802            }
803            _ => unreachable!(),
804        };
805
806        write!(self.out, ")")?;
807
808        // Write function body
809        writeln!(self.out, " {{")?;
810
811        match module.types[constructor.ty].inner {
812            crate::TypeInner::Struct { ref members, .. } => {
813                let struct_name = &self.names[&NameKey::Type(constructor.ty)];
814                writeln!(
815                    self.out,
816                    "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
817                )?;
818                for (i, member) in members.iter().enumerate() {
819                    let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
820
821                    match module.types[member.ty].inner {
822                        crate::TypeInner::Matrix {
823                            columns,
824                            rows: crate::VectorSize::Bi,
825                            ..
826                        } if member.binding.is_none() => {
827                            for j in 0..columns as u8 {
828                                writeln!(
829                                    self.out,
830                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
831                                )?;
832                            }
833                        }
834                        ref other => {
835                            // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
836                            // (where the inner matrix is represented by a struct with C `float2` members).
837                            // See the module-level block comment in mod.rs for details.
838                            if let Some(super::writer::MatrixType {
839                                columns,
840                                rows: crate::VectorSize::Bi,
841                                width: 4,
842                            }) = super::writer::get_inner_matrix_data(module, member.ty)
843                            {
844                                write!(
845                                    self.out,
846                                    "{}{}.{} = (__mat{}x2",
847                                    INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
848                                )?;
849                                if let crate::TypeInner::Array { base, size, .. } = *other {
850                                    self.write_array_size(module, base, size)?;
851                                }
852                                writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
853                            } else {
854                                writeln!(
855                                    self.out,
856                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
857                                )?;
858                            }
859                        }
860                    }
861                }
862            }
863            crate::TypeInner::Array {
864                base,
865                size: crate::ArraySize::Constant(size),
866                ..
867            } => {
868                write!(self.out, "{INDENT}")?;
869                self.write_type(module, base)?;
870                write!(self.out, " {RETURN_VARIABLE_NAME}")?;
871                self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
872                write!(self.out, " = {{ ")?;
873                for i in 0..size.get() {
874                    if i != 0 {
875                        write!(self.out, ", ")?;
876                    }
877                    write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
878                }
879                writeln!(self.out, " }};",)?;
880            }
881            _ => unreachable!(),
882        }
883
884        // Write return value
885        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
886
887        // End of function body
888        writeln!(self.out, "}}")?;
889        // Write extra new line
890        writeln!(self.out)?;
891
892        Ok(())
893    }
894
895    /// Writes the conversion from a single length storage texture load to a vec4 with the loaded
896    /// scalar in its `x` component, 1 in its `a` component and 0 everywhere else.
897    fn write_loaded_scalar_to_storage_loaded_value(
898        &mut self,
899        scalar_type: crate::Scalar,
900    ) -> BackendResult {
901        const ARGUMENT_VARIABLE_NAME: &str = "arg";
902        const RETURN_VARIABLE_NAME: &str = "ret";
903
904        let zero;
905        let one;
906        match scalar_type.kind {
907            ScalarKind::Sint => {
908                assert_eq!(
909                    scalar_type.width, 4,
910                    "Scalar {scalar_type:?} is not a result from any storage format"
911                );
912                zero = "0";
913                one = "1";
914            }
915            ScalarKind::Uint => match scalar_type.width {
916                4 => {
917                    zero = "0u";
918                    one = "1u";
919                }
920                8 => {
921                    zero = "0uL";
922                    one = "1uL"
923                }
924                _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
925            },
926            ScalarKind::Float => {
927                assert_eq!(
928                    scalar_type.width, 4,
929                    "Scalar {scalar_type:?} is not a result from any storage format"
930                );
931                zero = "0.0";
932                one = "1.0";
933            }
934            _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
935        }
936
937        let ty = scalar_type.to_hlsl_str()?;
938        writeln!(
939            self.out,
940            "{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARGUMENT_VARIABLE_NAME}) {{\
941    {ty}4 {RETURN_VARIABLE_NAME} = {ty}4({ARGUMENT_VARIABLE_NAME}, {zero}, {zero}, {one});\
942    return {RETURN_VARIABLE_NAME};\
943}}"
944        )?;
945
946        Ok(())
947    }
948
949    pub(super) fn write_wrapped_struct_matrix_get_function_name(
950        &mut self,
951        access: WrappedStructMatrixAccess,
952    ) -> BackendResult {
953        let name = &self.names[&NameKey::Type(access.ty)];
954        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
955        write!(self.out, "GetMat{field_name}On{name}")?;
956        Ok(())
957    }
958
959    /// Writes a function used to get a matCx2 from within a structure.
960    pub(super) fn write_wrapped_struct_matrix_get_function(
961        &mut self,
962        module: &crate::Module,
963        access: WrappedStructMatrixAccess,
964    ) -> BackendResult {
965        use crate::back::INDENT;
966
967        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
968
969        // Write function return type and name
970        let member = match module.types[access.ty].inner {
971            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
972            _ => unreachable!(),
973        };
974        let ret_ty = &module.types[member.ty].inner;
975        self.write_value_type(module, ret_ty)?;
976        write!(self.out, " ")?;
977        self.write_wrapped_struct_matrix_get_function_name(access)?;
978
979        // Write function parameters
980        write!(self.out, "(")?;
981        let struct_name = &self.names[&NameKey::Type(access.ty)];
982        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
983
984        // Write function body
985        writeln!(self.out, ") {{")?;
986
987        // Write return value
988        write!(self.out, "{INDENT}return ")?;
989        self.write_value_type(module, ret_ty)?;
990        write!(self.out, "(")?;
991        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
992        match module.types[member.ty].inner {
993            crate::TypeInner::Matrix { columns, .. } => {
994                for i in 0..columns as u8 {
995                    if i != 0 {
996                        write!(self.out, ", ")?;
997                    }
998                    write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
999                }
1000            }
1001            _ => unreachable!(),
1002        }
1003        writeln!(self.out, ");")?;
1004
1005        // End of function body
1006        writeln!(self.out, "}}")?;
1007        // Write extra new line
1008        writeln!(self.out)?;
1009
1010        Ok(())
1011    }
1012
1013    pub(super) fn write_wrapped_struct_matrix_set_function_name(
1014        &mut self,
1015        access: WrappedStructMatrixAccess,
1016    ) -> BackendResult {
1017        let name = &self.names[&NameKey::Type(access.ty)];
1018        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1019        write!(self.out, "SetMat{field_name}On{name}")?;
1020        Ok(())
1021    }
1022
1023    /// Writes a function used to set a matCx2 from within a structure.
1024    pub(super) fn write_wrapped_struct_matrix_set_function(
1025        &mut self,
1026        module: &crate::Module,
1027        access: WrappedStructMatrixAccess,
1028    ) -> BackendResult {
1029        use crate::back::INDENT;
1030
1031        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1032        const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
1033
1034        // Write function return type and name
1035        write!(self.out, "void ")?;
1036        self.write_wrapped_struct_matrix_set_function_name(access)?;
1037
1038        // Write function parameters
1039        write!(self.out, "(")?;
1040        let struct_name = &self.names[&NameKey::Type(access.ty)];
1041        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1042        let member = match module.types[access.ty].inner {
1043            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1044            _ => unreachable!(),
1045        };
1046        self.write_type(module, member.ty)?;
1047        write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
1048        // Write function body
1049        writeln!(self.out, ") {{")?;
1050
1051        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1052
1053        match module.types[member.ty].inner {
1054            crate::TypeInner::Matrix { columns, .. } => {
1055                for i in 0..columns as u8 {
1056                    writeln!(
1057                        self.out,
1058                        "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
1059                    )?;
1060                }
1061            }
1062            _ => unreachable!(),
1063        }
1064
1065        // End of function body
1066        writeln!(self.out, "}}")?;
1067        // Write extra new line
1068        writeln!(self.out)?;
1069
1070        Ok(())
1071    }
1072
1073    pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
1074        &mut self,
1075        access: WrappedStructMatrixAccess,
1076    ) -> BackendResult {
1077        let name = &self.names[&NameKey::Type(access.ty)];
1078        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1079        write!(self.out, "SetMatVec{field_name}On{name}")?;
1080        Ok(())
1081    }
1082
1083    /// Writes a function used to set a vec2 on a matCx2 from within a structure.
1084    pub(super) fn write_wrapped_struct_matrix_set_vec_function(
1085        &mut self,
1086        module: &crate::Module,
1087        access: WrappedStructMatrixAccess,
1088    ) -> BackendResult {
1089        use crate::back::INDENT;
1090
1091        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1092        const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
1093        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
1094
1095        // Write function return type and name
1096        write!(self.out, "void ")?;
1097        self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
1098
1099        // Write function parameters
1100        write!(self.out, "(")?;
1101        let struct_name = &self.names[&NameKey::Type(access.ty)];
1102        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1103        let member = match module.types[access.ty].inner {
1104            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1105            _ => unreachable!(),
1106        };
1107        let vec_ty = match module.types[member.ty].inner {
1108            crate::TypeInner::Matrix { rows, scalar, .. } => {
1109                crate::TypeInner::Vector { size: rows, scalar }
1110            }
1111            _ => unreachable!(),
1112        };
1113        self.write_value_type(module, &vec_ty)?;
1114        write!(
1115            self.out,
1116            " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
1117        )?;
1118
1119        // Write function body
1120        writeln!(self.out, ") {{")?;
1121
1122        writeln!(
1123            self.out,
1124            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
1125        )?;
1126
1127        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1128
1129        match module.types[member.ty].inner {
1130            crate::TypeInner::Matrix { columns, .. } => {
1131                for i in 0..columns as u8 {
1132                    writeln!(
1133                        self.out,
1134                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
1135                    )?;
1136                }
1137            }
1138            _ => unreachable!(),
1139        }
1140
1141        writeln!(self.out, "{INDENT}}}")?;
1142
1143        // End of function body
1144        writeln!(self.out, "}}")?;
1145        // Write extra new line
1146        writeln!(self.out)?;
1147
1148        Ok(())
1149    }
1150
1151    pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
1152        &mut self,
1153        access: WrappedStructMatrixAccess,
1154    ) -> BackendResult {
1155        let name = &self.names[&NameKey::Type(access.ty)];
1156        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1157        write!(self.out, "SetMatScalar{field_name}On{name}")?;
1158        Ok(())
1159    }
1160
1161    /// Writes a function used to set a float on a matCx2 from within a structure.
1162    pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
1163        &mut self,
1164        module: &crate::Module,
1165        access: WrappedStructMatrixAccess,
1166    ) -> BackendResult {
1167        use crate::back::INDENT;
1168
1169        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1170        const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
1171        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
1172        const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
1173
1174        // Write function return type and name
1175        write!(self.out, "void ")?;
1176        self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
1177
1178        // Write function parameters
1179        write!(self.out, "(")?;
1180        let struct_name = &self.names[&NameKey::Type(access.ty)];
1181        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1182        let member = match module.types[access.ty].inner {
1183            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1184            _ => unreachable!(),
1185        };
1186        let scalar_ty = match module.types[member.ty].inner {
1187            crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
1188            _ => unreachable!(),
1189        };
1190        self.write_value_type(module, &scalar_ty)?;
1191        write!(
1192            self.out,
1193            " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
1194        )?;
1195
1196        // Write function body
1197        writeln!(self.out, ") {{")?;
1198
1199        writeln!(
1200            self.out,
1201            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
1202        )?;
1203
1204        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1205
1206        match module.types[member.ty].inner {
1207            crate::TypeInner::Matrix { columns, .. } => {
1208                for i in 0..columns as u8 {
1209                    writeln!(
1210                        self.out,
1211                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
1212                    )?;
1213                }
1214            }
1215            _ => unreachable!(),
1216        }
1217
1218        writeln!(self.out, "{INDENT}}}")?;
1219
1220        // End of function body
1221        writeln!(self.out, "}}")?;
1222        // Write extra new line
1223        writeln!(self.out)?;
1224
1225        Ok(())
1226    }
1227
1228    /// Write functions to create special types.
1229    pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
1230        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
1231            match type_key {
1232                &crate::PredeclaredType::ModfResult { size, scalar }
1233                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
1234                    let arg_type_name_owner;
1235                    let arg_type_name = if let Some(size) = size {
1236                        arg_type_name_owner = format!(
1237                            "{}{}",
1238                            if scalar.width == 8 { "double" } else { "float" },
1239                            size as u8
1240                        );
1241                        &arg_type_name_owner
1242                    } else if scalar.width == 8 {
1243                        "double"
1244                    } else {
1245                        "float"
1246                    };
1247
1248                    let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
1249                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
1250                            (super::writer::MODF_FUNCTION, "modf", "whole", "")
1251                        } else {
1252                            (
1253                                super::writer::FREXP_FUNCTION,
1254                                "frexp",
1255                                "exp_",
1256                                "sign(arg) * ",
1257                            )
1258                        };
1259
1260                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
1261
1262                    writeln!(
1263                        self.out,
1264                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
1265    {arg_type_name} other;
1266    {struct_name} result;
1267    result.fract = {sign_multiplier}{called_func_name}(arg, other);
1268    result.{second_field_name} = other;
1269    return result;
1270}}"
1271                    )?;
1272                    writeln!(self.out)?;
1273                }
1274                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
1275            }
1276        }
1277        if module.special_types.ray_desc.is_some() {
1278            self.write_ray_desc_from_ray_desc_constructor_function(module)?;
1279        }
1280
1281        Ok(())
1282    }
1283
1284    /// Helper function that writes wrapped functions for expressions in a function
1285    pub(super) fn write_wrapped_expression_functions(
1286        &mut self,
1287        module: &crate::Module,
1288        expressions: &crate::Arena<crate::Expression>,
1289        context: Option<&FunctionCtx>,
1290    ) -> BackendResult {
1291        for (handle, _) in expressions.iter() {
1292            match expressions[handle] {
1293                crate::Expression::Compose { ty, .. } => {
1294                    match module.types[ty].inner {
1295                        crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
1296                            let constructor = WrappedConstructor { ty };
1297                            if self.wrapped.insert(WrappedType::Constructor(constructor)) {
1298                                self.write_wrapped_constructor_function(module, constructor)?;
1299                            }
1300                        }
1301                        _ => {}
1302                    };
1303                }
1304                crate::Expression::ImageLoad { image, .. } => {
1305                    // This can only happen in a function as this is not a valid const expression
1306                    match *context.as_ref().unwrap().resolve_type(image, &module.types) {
1307                        crate::TypeInner::Image {
1308                            class: crate::ImageClass::Storage { format, .. },
1309                            ..
1310                        } => {
1311                            if format.single_component() {
1312                                let scalar: crate::Scalar = format.into();
1313                                if self.wrapped.insert(WrappedType::ImageLoadScalar(scalar)) {
1314                                    self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
1315                                }
1316                            }
1317                        }
1318                        _ => {}
1319                    }
1320                }
1321                crate::Expression::RayQueryGetIntersection { committed, .. } => {
1322                    if committed {
1323                        if !self.written_committed_intersection {
1324                            self.write_committed_intersection_function(module)?;
1325                            self.written_committed_intersection = true;
1326                        }
1327                    } else if !self.written_candidate_intersection {
1328                        self.write_candidate_intersection_function(module)?;
1329                        self.written_candidate_intersection = true;
1330                    }
1331                }
1332                _ => {}
1333            }
1334        }
1335        Ok(())
1336    }
1337
1338    // TODO: we could merge this with iteration in write_wrapped_expression_functions...
1339    //
1340    /// Helper function that writes zero value wrapped functions
1341    pub(super) fn write_wrapped_zero_value_functions(
1342        &mut self,
1343        module: &crate::Module,
1344        expressions: &crate::Arena<crate::Expression>,
1345    ) -> BackendResult {
1346        for (handle, _) in expressions.iter() {
1347            if let crate::Expression::ZeroValue(ty) = expressions[handle] {
1348                let zero_value = WrappedZeroValue { ty };
1349                if self.wrapped.insert(WrappedType::ZeroValue(zero_value)) {
1350                    self.write_wrapped_zero_value_function(module, zero_value)?;
1351                }
1352            }
1353        }
1354        Ok(())
1355    }
1356
1357    pub(super) fn write_wrapped_math_functions(
1358        &mut self,
1359        module: &crate::Module,
1360        func_ctx: &FunctionCtx,
1361    ) -> BackendResult {
1362        for (_, expression) in func_ctx.expressions.iter() {
1363            if let crate::Expression::Math {
1364                fun,
1365                arg,
1366                arg1: _arg1,
1367                arg2: _arg2,
1368                arg3: _arg3,
1369            } = *expression
1370            {
1371                let arg_ty = func_ctx.resolve_type(arg, &module.types);
1372
1373                match fun {
1374                    crate::MathFunction::ExtractBits => {
1375                        // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need
1376                        // to first sanitize the offset and count first. If we don't do this, we will get out-of-spec
1377                        // values if the extracted range is not within the bit width.
1378                        //
1379                        // This encodes the exact formula specified by the wgsl spec:
1380                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1381                        //
1382                        // w = sizeof(x) * 8
1383                        // o = min(offset, w)
1384                        // c = min(count, w - o)
1385                        //
1386                        // bitfieldExtract(x, o, c)
1387                        let scalar = arg_ty.scalar().unwrap();
1388                        let components = arg_ty.components();
1389
1390                        let wrapped = WrappedMath {
1391                            fun,
1392                            scalar,
1393                            components,
1394                        };
1395
1396                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1397                            continue;
1398                        }
1399
1400                        // Write return type
1401                        self.write_value_type(module, arg_ty)?;
1402
1403                        let scalar_width: u8 = scalar.width * 8;
1404
1405                        // Write function name and parameters
1406                        writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?;
1407                        write!(self.out, "    ")?;
1408                        self.write_value_type(module, arg_ty)?;
1409                        writeln!(self.out, " e,")?;
1410                        writeln!(self.out, "    uint offset,")?;
1411                        writeln!(self.out, "    uint count")?;
1412                        writeln!(self.out, ") {{")?;
1413
1414                        // Write function body
1415                        writeln!(self.out, "    uint w = {scalar_width};")?;
1416                        writeln!(self.out, "    uint o = min(offset, w);")?;
1417                        writeln!(self.out, "    uint c = min(count, w - o);")?;
1418                        writeln!(
1419                            self.out,
1420                            "    return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));"
1421                        )?;
1422
1423                        // End of function body
1424                        writeln!(self.out, "}}")?;
1425                    }
1426                    crate::MathFunction::InsertBits => {
1427                        // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill.
1428
1429                        let scalar = arg_ty.scalar().unwrap();
1430                        let components = arg_ty.components();
1431
1432                        let wrapped = WrappedMath {
1433                            fun,
1434                            scalar,
1435                            components,
1436                        };
1437
1438                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1439                            continue;
1440                        }
1441
1442                        // Write return type
1443                        self.write_value_type(module, arg_ty)?;
1444
1445                        let scalar_width: u8 = scalar.width * 8;
1446                        let scalar_max: u64 = match scalar.width {
1447                            1 => 0xFF,
1448                            2 => 0xFFFF,
1449                            4 => 0xFFFFFFFF,
1450                            8 => 0xFFFFFFFFFFFFFFFF,
1451                            _ => unreachable!(),
1452                        };
1453
1454                        // Write function name and parameters
1455                        writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?;
1456                        write!(self.out, "    ")?;
1457                        self.write_value_type(module, arg_ty)?;
1458                        writeln!(self.out, " e,")?;
1459                        write!(self.out, "    ")?;
1460                        self.write_value_type(module, arg_ty)?;
1461                        writeln!(self.out, " newbits,")?;
1462                        writeln!(self.out, "    uint offset,")?;
1463                        writeln!(self.out, "    uint count")?;
1464                        writeln!(self.out, ") {{")?;
1465
1466                        // Write function body
1467                        writeln!(self.out, "    uint w = {scalar_width}u;")?;
1468                        writeln!(self.out, "    uint o = min(offset, w);")?;
1469                        writeln!(self.out, "    uint c = min(count, w - o);")?;
1470
1471                        // The `u` suffix on the literals is _extremely_ important. Otherwise it will use
1472                        // i32 shifting instead of the intended u32 shifting.
1473                        writeln!(
1474                            self.out,
1475                            "    uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);"
1476                        )?;
1477                        writeln!(
1478                            self.out,
1479                            "    return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));"
1480                        )?;
1481
1482                        // End of function body
1483                        writeln!(self.out, "}}")?;
1484                    }
1485                    // Taking the absolute value of the minimum value of a two's
1486                    // complement signed integer type causes overflow, which is
1487                    // undefined behaviour in HLSL. To avoid this, when the value is
1488                    // negative we bitcast the value to unsigned and negate it, then
1489                    // bitcast back to signed.
1490                    // This adheres to the WGSL spec in that the absolute of the type's
1491                    // minimum value should equal to the minimum value.
1492                    //
1493                    // TODO(#7109): asint()/asuint() only support 32-bit integers, so we
1494                    // must find another solution for different bit-widths.
1495                    crate::MathFunction::Abs
1496                        if matches!(arg_ty.scalar(), Some(crate::Scalar::I32)) =>
1497                    {
1498                        let scalar = arg_ty.scalar().unwrap();
1499                        let components = arg_ty.components();
1500
1501                        let wrapped = WrappedMath {
1502                            fun,
1503                            scalar,
1504                            components,
1505                        };
1506
1507                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1508                            continue;
1509                        }
1510
1511                        self.write_value_type(module, arg_ty)?;
1512                        write!(self.out, " {ABS_FUNCTION}(")?;
1513                        self.write_value_type(module, arg_ty)?;
1514                        writeln!(self.out, " val) {{")?;
1515
1516                        let level = crate::back::Level(1);
1517                        writeln!(
1518                            self.out,
1519                            "{level}return val >= 0 ? val : asint(-asuint(val));"
1520                        )?;
1521                        writeln!(self.out, "}}")?;
1522                        writeln!(self.out)?;
1523                    }
1524                    _ => {}
1525                }
1526            }
1527        }
1528
1529        Ok(())
1530    }
1531
1532    pub(super) fn write_wrapped_unary_ops(
1533        &mut self,
1534        module: &crate::Module,
1535        func_ctx: &FunctionCtx,
1536    ) -> BackendResult {
1537        for (_, expression) in func_ctx.expressions.iter() {
1538            if let crate::Expression::Unary { op, expr } = *expression {
1539                let expr_ty = func_ctx.resolve_type(expr, &module.types);
1540                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
1541                    continue;
1542                };
1543                let wrapped = WrappedUnaryOp {
1544                    op,
1545                    ty: (vector_size, scalar),
1546                };
1547
1548                // Negating the minimum value of a two's complement signed integer type
1549                // causes overflow, which is undefined behaviour in HLSL. To avoid this
1550                // we bitcast the value to unsigned and negate it, then bitcast back to
1551                // signed. This adheres to the WGSL spec in that the negative of the
1552                // type's minimum value should equal to the minimum value.
1553                //
1554                // TODO(#7109): asint()/asuint() only support 32-bit integers, so we must
1555                // find another solution for different bit-widths.
1556                match (op, scalar) {
1557                    (crate::UnaryOperator::Negate, crate::Scalar::I32) => {
1558                        if !self.wrapped.insert(WrappedType::UnaryOp(wrapped)) {
1559                            continue;
1560                        }
1561
1562                        self.write_value_type(module, expr_ty)?;
1563                        write!(self.out, " {NEG_FUNCTION}(")?;
1564                        self.write_value_type(module, expr_ty)?;
1565                        writeln!(self.out, " val) {{")?;
1566
1567                        let level = crate::back::Level(1);
1568                        writeln!(self.out, "{level}return asint(-asuint(val));",)?;
1569                        writeln!(self.out, "}}")?;
1570                        writeln!(self.out)?;
1571                    }
1572                    _ => {}
1573                }
1574            }
1575        }
1576
1577        Ok(())
1578    }
1579
1580    pub(super) fn write_wrapped_binary_ops(
1581        &mut self,
1582        module: &crate::Module,
1583        func_ctx: &FunctionCtx,
1584    ) -> BackendResult {
1585        for (expr_handle, expression) in func_ctx.expressions.iter() {
1586            if let crate::Expression::Binary { op, left, right } = *expression {
1587                let expr_ty = func_ctx.resolve_type(expr_handle, &module.types);
1588                let left_ty = func_ctx.resolve_type(left, &module.types);
1589                let right_ty = func_ctx.resolve_type(right, &module.types);
1590
1591                match (op, expr_ty.scalar()) {
1592                    // Signed integer division of the type's minimum representable value
1593                    // divided by -1, or signed or unsigned division by zero, is
1594                    // undefined behaviour in HLSL. We override the divisor to 1 in these
1595                    // cases.
1596                    // This adheres to the WGSL spec in that:
1597                    // * TYPE_MIN / -1 == TYPE_MIN
1598                    // * x / 0 == x
1599                    (
1600                        crate::BinaryOperator::Divide,
1601                        Some(
1602                            scalar @ crate::Scalar {
1603                                kind: ScalarKind::Sint | ScalarKind::Uint,
1604                                ..
1605                            },
1606                        ),
1607                    ) => {
1608                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
1609                            continue;
1610                        };
1611                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
1612                            continue;
1613                        };
1614                        let wrapped = WrappedBinaryOp {
1615                            op,
1616                            left_ty: left_wrapped_ty,
1617                            right_ty: right_wrapped_ty,
1618                        };
1619                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
1620                            continue;
1621                        }
1622
1623                        self.write_value_type(module, expr_ty)?;
1624                        write!(self.out, " {DIV_FUNCTION}(")?;
1625                        self.write_value_type(module, left_ty)?;
1626                        write!(self.out, " lhs, ")?;
1627                        self.write_value_type(module, right_ty)?;
1628                        writeln!(self.out, " rhs) {{")?;
1629                        let level = crate::back::Level(1);
1630                        match scalar.kind {
1631                            ScalarKind::Sint => {
1632                                let min_val = match scalar.width {
1633                                    4 => crate::Literal::I32(i32::MIN),
1634                                    8 => crate::Literal::I64(i64::MIN),
1635                                    _ => {
1636                                        return Err(super::Error::UnsupportedScalar(scalar));
1637                                    }
1638                                };
1639                                write!(self.out, "{level}return lhs / (((lhs == ")?;
1640                                self.write_literal(min_val)?;
1641                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs);")?
1642                            }
1643                            ScalarKind::Uint => {
1644                                writeln!(self.out, "{level}return lhs / (rhs == 0u ? 1u : rhs);")?
1645                            }
1646                            _ => unreachable!(),
1647                        }
1648                        writeln!(self.out, "}}")?;
1649                        writeln!(self.out)?;
1650                    }
1651                    // The modulus operator is only defined for integers in HLSL when
1652                    // either both sides are positive or both sides are negative. To
1653                    // avoid this undefined behaviour we use the following equation:
1654                    //
1655                    // dividend - (dividend / divisor) * divisor
1656                    //
1657                    // overriding the divisor to 1 if either it is 0, or it is -1
1658                    // and the dividend is the minimum representable value.
1659                    //
1660                    // This adheres to the WGSL spec in that:
1661                    // * min_value % -1 == 0
1662                    // * x % 0 == 0
1663                    (
1664                        crate::BinaryOperator::Modulo,
1665                        Some(
1666                            scalar @ crate::Scalar {
1667                                kind: ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float,
1668                                ..
1669                            },
1670                        ),
1671                    ) => {
1672                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
1673                            continue;
1674                        };
1675                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
1676                            continue;
1677                        };
1678                        let wrapped = WrappedBinaryOp {
1679                            op,
1680                            left_ty: left_wrapped_ty,
1681                            right_ty: right_wrapped_ty,
1682                        };
1683                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
1684                            continue;
1685                        }
1686
1687                        self.write_value_type(module, expr_ty)?;
1688                        write!(self.out, " {MOD_FUNCTION}(")?;
1689                        self.write_value_type(module, left_ty)?;
1690                        write!(self.out, " lhs, ")?;
1691                        self.write_value_type(module, right_ty)?;
1692                        writeln!(self.out, " rhs) {{")?;
1693                        let level = crate::back::Level(1);
1694                        match scalar.kind {
1695                            ScalarKind::Sint => {
1696                                let min_val = match scalar.width {
1697                                    4 => crate::Literal::I32(i32::MIN),
1698                                    8 => crate::Literal::I64(i64::MIN),
1699                                    _ => {
1700                                        return Err(super::Error::UnsupportedScalar(scalar));
1701                                    }
1702                                };
1703                                write!(self.out, "{level}")?;
1704                                self.write_value_type(module, right_ty)?;
1705                                write!(self.out, " divisor = ((lhs == ")?;
1706                                self.write_literal(min_val)?;
1707                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs;")?;
1708                                writeln!(
1709                                    self.out,
1710                                    "{level}return lhs - (lhs / divisor) * divisor;"
1711                                )?
1712                            }
1713                            ScalarKind::Uint => {
1714                                writeln!(self.out, "{level}return lhs % (rhs == 0u ? 1u : rhs);")?
1715                            }
1716                            // HLSL's fmod has the same definition as WGSL's % operator but due
1717                            // to its implementation in DXC it is not as accurate as the WGSL spec
1718                            // requires it to be. See:
1719                            // - https://shader-playground.timjones.io/0c8572816dbb6fc4435cc5d016a978a7
1720                            // - https://github.com/llvm/llvm-project/blob/50f9b8acafdca48e87e6b8e393c1f116a2d193ee/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h#L78-L81
1721                            ScalarKind::Float => {
1722                                writeln!(self.out, "{level}return lhs - rhs * trunc(lhs / rhs);")?
1723                            }
1724                            _ => unreachable!(),
1725                        }
1726                        writeln!(self.out, "}}")?;
1727                        writeln!(self.out)?;
1728                    }
1729                    _ => {}
1730                }
1731            }
1732        }
1733
1734        Ok(())
1735    }
1736
1737    fn write_wrapped_cast_functions(
1738        &mut self,
1739        module: &crate::Module,
1740        func_ctx: &FunctionCtx,
1741    ) -> BackendResult {
1742        for (_, expression) in func_ctx.expressions.iter() {
1743            if let crate::Expression::As {
1744                expr,
1745                kind,
1746                convert: Some(width),
1747            } = *expression
1748            {
1749                // Avoid undefined behaviour when casting from a float to integer
1750                // when the value is out of range for the target type. Additionally
1751                // ensure we clamp to the correct value as per the WGSL spec.
1752                //
1753                // https://www.w3.org/TR/WGSL/#floating-point-conversion:
1754                // * If X is exactly representable in the target type T, then the
1755                //   result is that value.
1756                // * Otherwise, the result is the value in T closest to
1757                //   truncate(X) and also exactly representable in the original
1758                //   floating point type.
1759                let src_ty = func_ctx.resolve_type(expr, &module.types);
1760                let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
1761                    continue;
1762                };
1763                let dst_scalar = crate::Scalar { kind, width };
1764                if src_scalar.kind != ScalarKind::Float
1765                    || (dst_scalar.kind != ScalarKind::Sint && dst_scalar.kind != ScalarKind::Uint)
1766                {
1767                    continue;
1768                }
1769
1770                let wrapped = WrappedCast {
1771                    src_scalar,
1772                    vector_size,
1773                    dst_scalar,
1774                };
1775                if !self.wrapped.insert(WrappedType::Cast(wrapped)) {
1776                    continue;
1777                }
1778
1779                let (src_ty, dst_ty) = match vector_size {
1780                    None => (
1781                        crate::TypeInner::Scalar(src_scalar),
1782                        crate::TypeInner::Scalar(dst_scalar),
1783                    ),
1784                    Some(vector_size) => (
1785                        crate::TypeInner::Vector {
1786                            scalar: src_scalar,
1787                            size: vector_size,
1788                        },
1789                        crate::TypeInner::Vector {
1790                            scalar: dst_scalar,
1791                            size: vector_size,
1792                        },
1793                    ),
1794                };
1795                let (min, max) =
1796                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
1797                let cast_str = format!(
1798                    "{}{}",
1799                    dst_scalar.to_hlsl_str()?,
1800                    vector_size
1801                        .map(crate::common::vector_size_str)
1802                        .unwrap_or(""),
1803                );
1804                let fun_name = match dst_scalar {
1805                    crate::Scalar::I32 => F2I32_FUNCTION,
1806                    crate::Scalar::U32 => F2U32_FUNCTION,
1807                    crate::Scalar::I64 => F2I64_FUNCTION,
1808                    crate::Scalar::U64 => F2U64_FUNCTION,
1809                    _ => unreachable!(),
1810                };
1811                self.write_value_type(module, &dst_ty)?;
1812                write!(self.out, " {fun_name}(")?;
1813                self.write_value_type(module, &src_ty)?;
1814                writeln!(self.out, " value) {{")?;
1815                let level = crate::back::Level(1);
1816                write!(self.out, "{level}return {cast_str}(clamp(value, ")?;
1817                self.write_literal(min)?;
1818                write!(self.out, ", ")?;
1819                self.write_literal(max)?;
1820                writeln!(self.out, "));",)?;
1821                writeln!(self.out, "}}")?;
1822                writeln!(self.out)?;
1823            }
1824        }
1825        Ok(())
1826    }
1827
1828    /// Helper function that writes various wrapped functions
1829    pub(super) fn write_wrapped_functions(
1830        &mut self,
1831        module: &crate::Module,
1832        func_ctx: &FunctionCtx,
1833    ) -> BackendResult {
1834        self.write_wrapped_math_functions(module, func_ctx)?;
1835        self.write_wrapped_unary_ops(module, func_ctx)?;
1836        self.write_wrapped_binary_ops(module, func_ctx)?;
1837        self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
1838        self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
1839        self.write_wrapped_cast_functions(module, func_ctx)?;
1840
1841        for (handle, _) in func_ctx.expressions.iter() {
1842            match func_ctx.expressions[handle] {
1843                crate::Expression::ArrayLength(expr) => {
1844                    let global_expr = match func_ctx.expressions[expr] {
1845                        crate::Expression::GlobalVariable(_) => expr,
1846                        crate::Expression::AccessIndex { base, index: _ } => base,
1847                        ref other => unreachable!("Array length of {:?}", other),
1848                    };
1849                    let global_var = match func_ctx.expressions[global_expr] {
1850                        crate::Expression::GlobalVariable(var_handle) => {
1851                            &module.global_variables[var_handle]
1852                        }
1853                        ref other => {
1854                            return Err(super::Error::Unimplemented(format!(
1855                                "Array length of base {other:?}"
1856                            )))
1857                        }
1858                    };
1859                    let storage_access = match global_var.space {
1860                        crate::AddressSpace::Storage { access } => access,
1861                        _ => crate::StorageAccess::default(),
1862                    };
1863                    let wal = WrappedArrayLength {
1864                        writable: storage_access.contains(crate::StorageAccess::STORE),
1865                    };
1866
1867                    if self.wrapped.insert(WrappedType::ArrayLength(wal)) {
1868                        self.write_wrapped_array_length_function(wal)?;
1869                    }
1870                }
1871                crate::Expression::ImageLoad { image, .. } => {
1872                    let class = match *func_ctx.resolve_type(image, &module.types) {
1873                        crate::TypeInner::Image { class, .. } => class,
1874                        _ => unreachable!(),
1875                    };
1876                    let wrapped = WrappedImageLoad { class };
1877                    if self.wrapped.insert(WrappedType::ImageLoad(wrapped)) {
1878                        self.write_wrapped_image_load_function(module, wrapped)?;
1879                    }
1880                }
1881                crate::Expression::ImageSample {
1882                    image,
1883                    clamp_to_edge,
1884                    ..
1885                } => {
1886                    let class = match *func_ctx.resolve_type(image, &module.types) {
1887                        crate::TypeInner::Image { class, .. } => class,
1888                        _ => unreachable!(),
1889                    };
1890                    let wrapped = WrappedImageSample {
1891                        class,
1892                        clamp_to_edge,
1893                    };
1894                    if self.wrapped.insert(WrappedType::ImageSample(wrapped)) {
1895                        self.write_wrapped_image_sample_function(module, wrapped)?;
1896                    }
1897                }
1898                crate::Expression::ImageQuery { image, query } => {
1899                    let wiq = match *func_ctx.resolve_type(image, &module.types) {
1900                        crate::TypeInner::Image {
1901                            dim,
1902                            arrayed,
1903                            class,
1904                        } => WrappedImageQuery {
1905                            dim,
1906                            arrayed,
1907                            class,
1908                            query: query.into(),
1909                        },
1910                        _ => unreachable!("we only query images"),
1911                    };
1912
1913                    if self.wrapped.insert(WrappedType::ImageQuery(wiq)) {
1914                        self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
1915                    }
1916                }
1917                // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
1918                // since they will later be used by the fn `write_storage_load`
1919                crate::Expression::Load { pointer } => {
1920                    let pointer_space = func_ctx
1921                        .resolve_type(pointer, &module.types)
1922                        .pointer_space();
1923
1924                    if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
1925                        if let Some(ty) = func_ctx.info[handle].ty.handle() {
1926                            write_wrapped_constructor(self, ty, module)?;
1927                        }
1928                    }
1929
1930                    fn write_wrapped_constructor<W: Write>(
1931                        writer: &mut super::Writer<'_, W>,
1932                        ty: Handle<crate::Type>,
1933                        module: &crate::Module,
1934                    ) -> BackendResult {
1935                        match module.types[ty].inner {
1936                            crate::TypeInner::Struct { ref members, .. } => {
1937                                for member in members {
1938                                    write_wrapped_constructor(writer, member.ty, module)?;
1939                                }
1940
1941                                let constructor = WrappedConstructor { ty };
1942                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
1943                                    writer
1944                                        .write_wrapped_constructor_function(module, constructor)?;
1945                                }
1946                            }
1947                            crate::TypeInner::Array { base, .. } => {
1948                                write_wrapped_constructor(writer, base, module)?;
1949
1950                                let constructor = WrappedConstructor { ty };
1951                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
1952                                    writer
1953                                        .write_wrapped_constructor_function(module, constructor)?;
1954                                }
1955                            }
1956                            _ => {}
1957                        };
1958
1959                        Ok(())
1960                    }
1961                }
1962                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
1963                // (see top level module docs for details).
1964                //
1965                // The functions injected here are required to get the matrix accesses working.
1966                crate::Expression::AccessIndex { base, index } => {
1967                    let base_ty_res = &func_ctx.info[base].ty;
1968                    let mut resolved = base_ty_res.inner_with(&module.types);
1969                    let base_ty_handle = match *resolved {
1970                        crate::TypeInner::Pointer { base, .. } => {
1971                            resolved = &module.types[base].inner;
1972                            Some(base)
1973                        }
1974                        _ => base_ty_res.handle(),
1975                    };
1976                    if let crate::TypeInner::Struct { ref members, .. } = *resolved {
1977                        let member = &members[index as usize];
1978
1979                        match module.types[member.ty].inner {
1980                            crate::TypeInner::Matrix {
1981                                rows: crate::VectorSize::Bi,
1982                                ..
1983                            } if member.binding.is_none() => {
1984                                let ty = base_ty_handle.unwrap();
1985                                let access = WrappedStructMatrixAccess { ty, index };
1986
1987                                if self.wrapped.insert(WrappedType::StructMatrixAccess(access)) {
1988                                    self.write_wrapped_struct_matrix_get_function(module, access)?;
1989                                    self.write_wrapped_struct_matrix_set_function(module, access)?;
1990                                    self.write_wrapped_struct_matrix_set_vec_function(
1991                                        module, access,
1992                                    )?;
1993                                    self.write_wrapped_struct_matrix_set_scalar_function(
1994                                        module, access,
1995                                    )?;
1996                                }
1997                            }
1998                            _ => {}
1999                        }
2000                    }
2001                }
2002                _ => {}
2003            };
2004        }
2005
2006        Ok(())
2007    }
2008
2009    /// Writes out the sampler heap declarations if they haven't been written yet.
2010    pub(super) fn write_sampler_heaps(&mut self) -> BackendResult {
2011        if self.wrapped.sampler_heaps {
2012            return Ok(());
2013        }
2014
2015        writeln!(
2016            self.out,
2017            "SamplerState {}[2048]: register(s{}, space{});",
2018            super::writer::SAMPLER_HEAP_VAR,
2019            self.options.sampler_heap_target.standard_samplers.register,
2020            self.options.sampler_heap_target.standard_samplers.space
2021        )?;
2022        writeln!(
2023            self.out,
2024            "SamplerComparisonState {}[2048]: register(s{}, space{});",
2025            super::writer::COMPARISON_SAMPLER_HEAP_VAR,
2026            self.options
2027                .sampler_heap_target
2028                .comparison_samplers
2029                .register,
2030            self.options.sampler_heap_target.comparison_samplers.space
2031        )?;
2032
2033        self.wrapped.sampler_heaps = true;
2034
2035        Ok(())
2036    }
2037
2038    /// Writes out the sampler index buffer declaration if it hasn't been written yet.
2039    pub(super) fn write_wrapped_sampler_buffer(
2040        &mut self,
2041        key: super::SamplerIndexBufferKey,
2042    ) -> BackendResult {
2043        // The astute will notice that we do a double hash lookup, but we do this to avoid
2044        // holding a mutable reference to `self` while trying to call `write_sampler_heaps`.
2045        //
2046        // We only pay this double lookup cost when we actually need to write out the sampler
2047        // buffer, which should be not be common.
2048
2049        if self.wrapped.sampler_index_buffers.contains_key(&key) {
2050            return Ok(());
2051        };
2052
2053        self.write_sampler_heaps()?;
2054
2055        // Because the group number can be arbitrary, we use the namer to generate a unique name
2056        // instead of adding it to the reserved name list.
2057        let sampler_array_name = self
2058            .namer
2059            .call(&format!("nagaGroup{}SamplerIndexArray", key.group));
2060
2061        let bind_target = match self.options.sampler_buffer_binding_map.get(&key) {
2062            Some(&bind_target) => bind_target,
2063            None if self.options.fake_missing_bindings => super::BindTarget {
2064                space: u8::MAX,
2065                register: key.group,
2066                binding_array_size: None,
2067                dynamic_storage_buffer_offsets_index: None,
2068                restrict_indexing: false,
2069            },
2070            None => {
2071                unreachable!("Sampler buffer of group {key:?} not bound to a register");
2072            }
2073        };
2074
2075        writeln!(
2076            self.out,
2077            "StructuredBuffer<uint> {sampler_array_name} : register(t{}, space{});",
2078            bind_target.register, bind_target.space
2079        )?;
2080
2081        self.wrapped
2082            .sampler_index_buffers
2083            .insert(key, sampler_array_name);
2084
2085        Ok(())
2086    }
2087
2088    pub(super) fn write_texture_coordinates(
2089        &mut self,
2090        kind: &str,
2091        coordinate: Handle<crate::Expression>,
2092        array_index: Option<Handle<crate::Expression>>,
2093        mip_level: Option<Handle<crate::Expression>>,
2094        module: &crate::Module,
2095        func_ctx: &FunctionCtx,
2096    ) -> BackendResult {
2097        // HLSL expects the array index to be merged with the coordinate
2098        let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
2099        if extra == 0 {
2100            self.write_expr(module, coordinate, func_ctx)?;
2101        } else {
2102            let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
2103                crate::TypeInner::Scalar { .. } => 1,
2104                crate::TypeInner::Vector { size, .. } => size as usize,
2105                _ => unreachable!(),
2106            };
2107            write!(self.out, "{}{}(", kind, num_coords + extra)?;
2108            self.write_expr(module, coordinate, func_ctx)?;
2109            if let Some(expr) = array_index {
2110                write!(self.out, ", ")?;
2111                self.write_expr(module, expr, func_ctx)?;
2112            }
2113            if let Some(expr) = mip_level {
2114                // Explicit cast if needed
2115                let cast_to_int = matches!(
2116                    *func_ctx.resolve_type(expr, &module.types),
2117                    crate::TypeInner::Scalar(crate::Scalar {
2118                        kind: ScalarKind::Uint,
2119                        ..
2120                    })
2121                );
2122
2123                write!(self.out, ", ")?;
2124
2125                if cast_to_int {
2126                    write!(self.out, "int(")?;
2127                }
2128
2129                self.write_expr(module, expr, func_ctx)?;
2130
2131                if cast_to_int {
2132                    write!(self.out, ")")?;
2133                }
2134            }
2135            write!(self.out, ")")?;
2136        }
2137        Ok(())
2138    }
2139
2140    pub(super) fn write_mat_cx2_typedef_and_functions(
2141        &mut self,
2142        WrappedMatCx2 { columns }: WrappedMatCx2,
2143    ) -> BackendResult {
2144        use crate::back::INDENT;
2145
2146        // typedef
2147        write!(self.out, "typedef struct {{ ")?;
2148        for i in 0..columns as u8 {
2149            write!(self.out, "float2 _{i}; ")?;
2150        }
2151        writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
2152
2153        // __get_col_of_mat
2154        writeln!(
2155            self.out,
2156            "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
2157            columns as u8, columns as u8
2158        )?;
2159        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2160        for i in 0..columns as u8 {
2161            writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
2162        }
2163        writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
2164        writeln!(self.out, "{INDENT}}}")?;
2165        writeln!(self.out, "}}")?;
2166
2167        // __set_col_of_mat
2168        writeln!(
2169            self.out,
2170            "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
2171            columns as u8, columns as u8
2172        )?;
2173        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2174        for i in 0..columns as u8 {
2175            writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
2176        }
2177        writeln!(self.out, "{INDENT}}}")?;
2178        writeln!(self.out, "}}")?;
2179
2180        // __set_el_of_mat
2181        writeln!(
2182            self.out,
2183            "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
2184            columns as u8, columns as u8
2185        )?;
2186        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2187        for i in 0..columns as u8 {
2188            writeln!(
2189                self.out,
2190                "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
2191            )?;
2192        }
2193        writeln!(self.out, "{INDENT}}}")?;
2194        writeln!(self.out, "}}")?;
2195
2196        writeln!(self.out)?;
2197
2198        Ok(())
2199    }
2200
2201    pub(super) fn write_all_mat_cx2_typedefs_and_functions(
2202        &mut self,
2203        module: &crate::Module,
2204    ) -> BackendResult {
2205        for (handle, _) in module.global_variables.iter() {
2206            let global = &module.global_variables[handle];
2207
2208            if global.space == crate::AddressSpace::Uniform {
2209                if let Some(super::writer::MatrixType {
2210                    columns,
2211                    rows: crate::VectorSize::Bi,
2212                    width: 4,
2213                }) = super::writer::get_inner_matrix_data(module, global.ty)
2214                {
2215                    let entry = WrappedMatCx2 { columns };
2216                    if self.wrapped.insert(WrappedType::MatCx2(entry)) {
2217                        self.write_mat_cx2_typedef_and_functions(entry)?;
2218                    }
2219                }
2220            }
2221        }
2222
2223        for (_, ty) in module.types.iter() {
2224            if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
2225                for member in members.iter() {
2226                    if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
2227                        if let Some(super::writer::MatrixType {
2228                            columns,
2229                            rows: crate::VectorSize::Bi,
2230                            width: 4,
2231                        }) = super::writer::get_inner_matrix_data(module, member.ty)
2232                        {
2233                            let entry = WrappedMatCx2 { columns };
2234                            if self.wrapped.insert(WrappedType::MatCx2(entry)) {
2235                                self.write_mat_cx2_typedef_and_functions(entry)?;
2236                            }
2237                        }
2238                    }
2239                }
2240            }
2241        }
2242
2243        Ok(())
2244    }
2245
2246    pub(super) fn write_wrapped_zero_value_function_name(
2247        &mut self,
2248        module: &crate::Module,
2249        zero_value: WrappedZeroValue,
2250    ) -> BackendResult {
2251        let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
2252        write!(self.out, "ZeroValue{name}")?;
2253        Ok(())
2254    }
2255
2256    /// Helper function that write wrapped function for `Expression::ZeroValue`
2257    ///
2258    /// This is necessary since we might have a member access after the zero value expression, e.g.
2259    /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
2260    ///
2261    /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
2262    ///
2263    /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
2264    /// cases like:
2265    ///
2266    /// ```text
2267    /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
2268    ///     t_1.am = (__mat4x2[2])((float4x2[2])0);
2269    ///                                         ^
2270    /// ```
2271    fn write_wrapped_zero_value_function(
2272        &mut self,
2273        module: &crate::Module,
2274        zero_value: WrappedZeroValue,
2275    ) -> BackendResult {
2276        use crate::back::INDENT;
2277
2278        // Write function return type and name
2279        if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner {
2280            write!(self.out, "typedef ")?;
2281            self.write_type(module, zero_value.ty)?;
2282            write!(self.out, " ret_")?;
2283            self.write_wrapped_zero_value_function_name(module, zero_value)?;
2284            self.write_array_size(module, base, size)?;
2285            writeln!(self.out, ";")?;
2286
2287            write!(self.out, "ret_")?;
2288            self.write_wrapped_zero_value_function_name(module, zero_value)?;
2289        } else {
2290            self.write_type(module, zero_value.ty)?;
2291        }
2292        write!(self.out, " ")?;
2293        self.write_wrapped_zero_value_function_name(module, zero_value)?;
2294
2295        // Write function parameters (none) and start function body
2296        writeln!(self.out, "() {{")?;
2297
2298        // Write `ZeroValue` function.
2299        write!(self.out, "{INDENT}return ")?;
2300        self.write_default_init(module, zero_value.ty)?;
2301        writeln!(self.out, ";")?;
2302
2303        // End of function body
2304        writeln!(self.out, "}}")?;
2305        // Write extra new line
2306        writeln!(self.out)?;
2307
2308        Ok(())
2309    }
2310}
2311
2312impl crate::StorageFormat {
2313    /// Returns `true` if there is just one component, otherwise `false`
2314    pub(super) const fn single_component(&self) -> bool {
2315        match *self {
2316            crate::StorageFormat::R16Float
2317            | crate::StorageFormat::R32Float
2318            | crate::StorageFormat::R8Unorm
2319            | crate::StorageFormat::R16Unorm
2320            | crate::StorageFormat::R8Snorm
2321            | crate::StorageFormat::R16Snorm
2322            | crate::StorageFormat::R8Uint
2323            | crate::StorageFormat::R16Uint
2324            | crate::StorageFormat::R32Uint
2325            | crate::StorageFormat::R8Sint
2326            | crate::StorageFormat::R16Sint
2327            | crate::StorageFormat::R32Sint
2328            | crate::StorageFormat::R64Uint => true,
2329            _ => false,
2330        }
2331    }
2332}