naga/back/hlsl/
help.rs

1/*!
2Helpers for the hlsl backend
3
4Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend:
5
Due to the way the `GetDimensions` function is defined in HLSL (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>),
the backend can't use it directly as an expression.
Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on the texture info and the query function.
See the `WrappedImageQuery` struct, which represents one such unique function; these functions are generated before any statements and expressions are written.
This allows the backend to treat `Expression::ImageQuery` as an expression by emitting a call to the wrapped function.
11
12For example:
13```wgsl
14let dim_1d = textureDimensions(image_1d);
15```
16
17```hlsl
uint NagaDimensions1D(Texture1D<float4> tex)
{
   uint4 ret;
   tex.GetDimensions(0, ret.x, ret.y);
   return ret.x;
}

uint dim_1d = NagaDimensions1D(image_1d);
26```
27*/
28
29use alloc::format;
30use core::fmt::Write;
31
32use super::{
33    super::FunctionCtx,
34    writer::{
35        ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, F2I32_FUNCTION, F2I64_FUNCTION,
36        F2U32_FUNCTION, F2U64_FUNCTION, IMAGE_LOAD_EXTERNAL_FUNCTION,
37        IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION, NEG_FUNCTION,
38    },
39    BackendResult, WrappedType,
40};
41use crate::{arena::Handle, proc::NameKey, ScalarKind};
42
43#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
44pub(super) struct WrappedArrayLength {
45    pub(super) writable: bool,
46}
47
48#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
49pub(super) struct WrappedImageLoad {
50    pub(super) class: crate::ImageClass,
51}
52
53#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
54pub(super) struct WrappedImageSample {
55    pub(super) class: crate::ImageClass,
56    pub(super) clamp_to_edge: bool,
57}
58
59#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
60pub(super) struct WrappedImageQuery {
61    pub(super) dim: crate::ImageDimension,
62    pub(super) arrayed: bool,
63    pub(super) class: crate::ImageClass,
64    pub(super) query: ImageQuery,
65}
66
67#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
68pub(super) struct WrappedConstructor {
69    pub(super) ty: Handle<crate::Type>,
70}
71
72#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
73pub(super) struct WrappedStructMatrixAccess {
74    pub(super) ty: Handle<crate::Type>,
75    pub(super) index: u32,
76}
77
78#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
79pub(super) struct WrappedMatCx2 {
80    pub(super) columns: crate::VectorSize,
81}
82
83#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
84pub(super) struct WrappedMath {
85    pub(super) fun: crate::MathFunction,
86    pub(super) scalar: crate::Scalar,
87    pub(super) components: Option<u32>,
88}
89
90#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
91pub(super) struct WrappedZeroValue {
92    pub(super) ty: Handle<crate::Type>,
93}
94
95#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
96pub(super) struct WrappedUnaryOp {
97    pub(super) op: crate::UnaryOperator,
98    // This can only represent scalar or vector types. If we ever need to wrap
99    // unary ops with other types, we'll need a better representation.
100    pub(super) ty: (Option<crate::VectorSize>, crate::Scalar),
101}
102
103#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
104pub(super) struct WrappedBinaryOp {
105    pub(super) op: crate::BinaryOperator,
106    // This can only represent scalar or vector types. If we ever need to wrap
107    // binary ops with other types, we'll need a better representation.
108    pub(super) left_ty: (Option<crate::VectorSize>, crate::Scalar),
109    pub(super) right_ty: (Option<crate::VectorSize>, crate::Scalar),
110}
111
112#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
113pub(super) struct WrappedCast {
114    // This can only represent scalar or vector types. If we ever need to wrap
115    // casts with other types, we'll need a better representation.
116    pub(super) vector_size: Option<crate::VectorSize>,
117    pub(super) src_scalar: crate::Scalar,
118    pub(super) dst_scalar: crate::Scalar,
119}
120
121/// HLSL backend requires its own `ImageQuery` enum.
122///
123/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
124/// IR version can't be unique per function, because it's store mipmap level as an expression.
125///
126/// For example:
127/// ```wgsl
128/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
129/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
130/// ```
131///
132/// ```ir
133/// ImageQuery {
134///  image: [1],
135///  query: Size {
136///      level: Some(
137///          [1],
138///      ),
139///  },
140/// },
141/// ImageQuery {
142///  image: [1],
143///  query: Size {
144///      level: Some(
145///          [2],
146///      ),
147///  },
148/// },
149/// ```
150///
151/// HLSL should generate only 1 function for this case.
152#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
153pub(super) enum ImageQuery {
154    Size,
155    SizeLevel,
156    NumLevels,
157    NumLayers,
158    NumSamples,
159}
160
161impl From<crate::ImageQuery> for ImageQuery {
162    fn from(q: crate::ImageQuery) -> Self {
163        use crate::ImageQuery as Iq;
164        match q {
165            Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
166            Iq::Size { level: None } => ImageQuery::Size,
167            Iq::NumLevels => ImageQuery::NumLevels,
168            Iq::NumLayers => ImageQuery::NumLayers,
169            Iq::NumSamples => ImageQuery::NumSamples,
170        }
171    }
172}
173
174pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";
175
176impl<W: Write> super::Writer<'_, W> {
177    pub(super) fn write_image_type(
178        &mut self,
179        dim: crate::ImageDimension,
180        arrayed: bool,
181        class: crate::ImageClass,
182    ) -> BackendResult {
183        let access_str = match class {
184            crate::ImageClass::Storage { .. } => "RW",
185            _ => "",
186        };
187        let dim_str = dim.to_hlsl_str();
188        let arrayed_str = if arrayed { "Array" } else { "" };
189        write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
190        match class {
191            crate::ImageClass::Depth { multi } => {
192                let multi_str = if multi { "MS" } else { "" };
193                write!(self.out, "{multi_str}<float>")?
194            }
195            crate::ImageClass::Sampled { kind, multi } => {
196                let multi_str = if multi { "MS" } else { "" };
197                let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
198                write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
199            }
200            crate::ImageClass::Storage { format, .. } => {
201                let storage_format_str = format.to_hlsl_str();
202                write!(self.out, "<{storage_format_str}>")?
203            }
204            crate::ImageClass::External => {
205                unreachable!(
206                    "external images should be handled by `write_global_external_texture`"
207                );
208            }
209        }
210        Ok(())
211    }
212
213    pub(super) fn write_wrapped_array_length_function_name(
214        &mut self,
215        query: WrappedArrayLength,
216    ) -> BackendResult {
217        let access_str = if query.writable { "RW" } else { "" };
218        write!(self.out, "NagaBufferLength{access_str}",)?;
219
220        Ok(())
221    }
222
223    /// Helper function that write wrapped function for `Expression::ArrayLength`
224    ///
225    /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
226    pub(super) fn write_wrapped_array_length_function(
227        &mut self,
228        wal: WrappedArrayLength,
229    ) -> BackendResult {
230        use crate::back::INDENT;
231
232        const ARGUMENT_VARIABLE_NAME: &str = "buffer";
233        const RETURN_VARIABLE_NAME: &str = "ret";
234
235        // Write function return type and name
236        write!(self.out, "uint ")?;
237        self.write_wrapped_array_length_function_name(wal)?;
238
239        // Write function parameters
240        write!(self.out, "(")?;
241        let access_str = if wal.writable { "RW" } else { "" };
242        writeln!(
243            self.out,
244            "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
245        )?;
246        // Write function body
247        writeln!(self.out, "{{")?;
248
249        // Write `GetDimensions` function.
250        writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
251        writeln!(
252            self.out,
253            "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
254        )?;
255
256        // Write return value
257        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
258
259        // End of function body
260        writeln!(self.out, "}}")?;
261        // Write extra new line
262        writeln!(self.out)?;
263
264        Ok(())
265    }
266
267    /// Helper function used by [`Self::write_wrapped_image_load_function`] and
268    /// [`Self::write_wrapped_image_sample_function`] to write the shared YUV
269    /// to RGB conversion code for external textures. Expects the preceding
270    /// code to declare the Y component as a `float` variable of name `y`, the
271    /// UV components as a `float2` variable of name `uv`, and the external
272    /// texture params as a variable of name `params`. The emitted code will
273    /// return the result.
274    fn write_convert_yuv_to_rgb_and_return(
275        &mut self,
276        level: crate::back::Level,
277        y: &str,
278        uv: &str,
279        params: &str,
280    ) -> BackendResult {
281        let l1 = level;
282        let l2 = l1.next();
283
284        // Convert from YUV to non-linear RGB in the source color space. We
285        // declare our matrices as row_major in HLSL, therefore we must reverse
286        // the order of this multiplication
287        writeln!(
288            self.out,
289            "{l1}float3 srcGammaRgb = mul(float4({y}, {uv}, 1.0), {params}.yuv_conversion_matrix).rgb;"
290        )?;
291
292        // Apply the inverse of the source transfer function to convert to
293        // linear RGB in the source color space.
294        writeln!(
295            self.out,
296            "{l1}float3 srcLinearRgb = srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b ?"
297        )?;
298        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k :")?;
299        writeln!(self.out, "{l2}pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g);")?;
300
301        // Multiply by the gamut conversion matrix to convert to linear RGB in
302        // the destination color space. We declare our matrices as row_major in
303        // HLSL, therefore we must reverse the order of this multiplication.
304        writeln!(
305            self.out,
306            "{l1}float3 dstLinearRgb = mul(srcLinearRgb, {params}.gamut_conversion_matrix);"
307        )?;
308
309        // Finally, apply the dest transfer function to convert to non-linear
310        // RGB in the destination color space, and return the result.
311        writeln!(
312            self.out,
313            "{l1}float3 dstGammaRgb = dstLinearRgb < {params}.dst_tf.b ?"
314        )?;
315        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb :")?;
316        writeln!(self.out, "{l2}{params}.dst_tf.a * pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1);")?;
317
318        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
319        Ok(())
320    }
321
322    pub(super) fn write_wrapped_image_load_function(
323        &mut self,
324        module: &crate::Module,
325        load: WrappedImageLoad,
326    ) -> BackendResult {
327        match load {
328            WrappedImageLoad {
329                class: crate::ImageClass::External,
330            } => {
331                let l1 = crate::back::Level(1);
332                let l2 = l1.next();
333                let l3 = l2.next();
334                let params_ty_name = &self.names
335                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
336                writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
337                writeln!(self.out, "{l1}Texture2D<float4> plane0,")?;
338                writeln!(self.out, "{l1}Texture2D<float4> plane1,")?;
339                writeln!(self.out, "{l1}Texture2D<float4> plane2,")?;
340                writeln!(self.out, "{l1}{params_ty_name} params,")?;
341                writeln!(self.out, "{l1}uint2 coords)")?;
342                writeln!(self.out, "{{")?;
343                writeln!(self.out, "{l1}uint2 plane0_size;")?;
344                writeln!(
345                    self.out,
346                    "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);"
347                )?;
348                // Clamp coords to provided size of external texture to prevent OOB read.
349                // If params.size is zero then clamp to the actual size of the texture.
350                writeln!(
351                    self.out,
352                    "{l1}uint2 cropped_size = any(params.size) ? params.size : plane0_size;"
353                )?;
354                writeln!(self.out, "{l1}coords = min(coords, cropped_size - 1);")?;
355
356                // Apply load transformation. We declare our matrices as row_major in
357                // HLSL, therefore we must reverse the order of this multiplication
358                writeln!(self.out, "{l1}float3x2 load_transform = float3x2(")?;
359                writeln!(self.out, "{l2}params.load_transform_0,")?;
360                writeln!(self.out, "{l2}params.load_transform_1,")?;
361                writeln!(self.out, "{l2}params.load_transform_2")?;
362                writeln!(self.out, "{l1});")?;
363                writeln!(self.out, "{l1}uint2 plane0_coords = uint2(round(mul(float3(coords, 1.0), load_transform)));")?;
364                writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?;
365                // For single plane, simply read from plane0
366                writeln!(
367                    self.out,
368                    "{l2}return plane0.Load(uint3(plane0_coords, 0u));"
369                )?;
370                writeln!(self.out, "{l1}}} else {{")?;
371
372                // Chroma planes may be subsampled so we must scale the coords accordingly.
373                writeln!(self.out, "{l2}uint2 plane1_size;")?;
374                writeln!(
375                    self.out,
376                    "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);"
377                )?;
378                writeln!(self.out, "{l2}uint2 plane1_coords = uint2(floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
379
380                // For multi-plane, read the Y value from plane 0
381                writeln!(
382                    self.out,
383                    "{l2}float y = plane0.Load(uint3(plane0_coords, 0u)).x;"
384                )?;
385
386                writeln!(self.out, "{l2}float2 uv;")?;
387                writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?;
388                // Read UV from interleaved plane 1
389                writeln!(
390                    self.out,
391                    "{l3}uv = plane1.Load(uint3(plane1_coords, 0u)).xy;"
392                )?;
393                writeln!(self.out, "{l2}}} else {{")?;
394                // Read U and V from planes 1 and 2 respectively
395                writeln!(self.out, "{l3}uint2 plane2_size;")?;
396                writeln!(
397                    self.out,
398                    "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);"
399                )?;
400                writeln!(self.out, "{l3}uint2 plane2_coords = uint2(floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
401                writeln!(self.out, "{l3}uv = float2(plane1.Load(uint3(plane1_coords, 0u)).x, plane2.Load(uint3(plane2_coords, 0u)).x);")?;
402                writeln!(self.out, "{l2}}}")?;
403
404                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?;
405
406                writeln!(self.out, "{l1}}}")?;
407                writeln!(self.out, "}}")?;
408                writeln!(self.out)?;
409            }
410            _ => {}
411        }
412
413        Ok(())
414    }
415
416    pub(super) fn write_wrapped_image_sample_function(
417        &mut self,
418        module: &crate::Module,
419        sample: WrappedImageSample,
420    ) -> BackendResult {
421        match sample {
422            WrappedImageSample {
423                class: crate::ImageClass::External,
424                clamp_to_edge: true,
425            } => {
426                let l1 = crate::back::Level(1);
427                let l2 = l1.next();
428                let l3 = l2.next();
429                let params_ty_name = &self.names
430                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
431                writeln!(
432                    self.out,
433                    "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}("
434                )?;
435                writeln!(self.out, "{l1}Texture2D<float4> plane0,")?;
436                writeln!(self.out, "{l1}Texture2D<float4> plane1,")?;
437                writeln!(self.out, "{l1}Texture2D<float4> plane2,")?;
438                writeln!(self.out, "{l1}{params_ty_name} params,")?;
439                writeln!(self.out, "{l1}SamplerState samp,")?;
440                writeln!(self.out, "{l1}float2 coords)")?;
441                writeln!(self.out, "{{")?;
442                writeln!(self.out, "{l1}float2 plane0_size;")?;
443                writeln!(
444                    self.out,
445                    "{l1}plane0.GetDimensions(plane0_size.x, plane0_size.y);"
446                )?;
447                writeln!(self.out, "{l1}float3x2 sample_transform = float3x2(")?;
448                writeln!(self.out, "{l2}params.sample_transform_0,")?;
449                writeln!(self.out, "{l2}params.sample_transform_1,")?;
450                writeln!(self.out, "{l2}params.sample_transform_2")?;
451                writeln!(self.out, "{l1});")?;
452                // Apply sample transformation. We declare our matrices as row_major in
453                // HLSL, therefore we must reverse the order of this multiplication
454                writeln!(
455                    self.out,
456                    "{l1}coords = mul(float3(coords, 1.0), sample_transform);"
457                )?;
458                // Calculate the sample bounds. The purported size of the texture
459                // (params.size) is irrelevant here as we are dealing with normalized
460                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
461                // apply the sample transformation to that, also bearing in mind that it
462                // may contain a flip on either axis. We calculate and adjust for the
463                // half-texel separately for each plane as it depends on the actual
464                // texture size which may vary between planes.
465                writeln!(
466                    self.out,
467                    "{l1}float2 bounds_min = mul(float3(0.0, 0.0, 1.0), sample_transform);"
468                )?;
469                writeln!(
470                    self.out,
471                    "{l1}float2 bounds_max = mul(float3(1.0, 1.0, 1.0), sample_transform);"
472                )?;
473                writeln!(self.out, "{l1}float4 bounds = float4(min(bounds_min, bounds_max), max(bounds_min, bounds_max));")?;
474                writeln!(
475                    self.out,
476                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / plane0_size;"
477                )?;
478                writeln!(
479                    self.out,
480                    "{l1}float2 plane0_coords = clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
481                )?;
482                writeln!(self.out, "{l1}if (params.num_planes == 1u) {{")?;
483                // For single plane, simply sample from plane0
484                writeln!(
485                    self.out,
486                    "{l2}return plane0.SampleLevel(samp, plane0_coords, 0.0f);"
487                )?;
488                writeln!(self.out, "{l1}}} else {{")?;
489
490                writeln!(self.out, "{l2}float2 plane1_size;")?;
491                writeln!(
492                    self.out,
493                    "{l2}plane1.GetDimensions(plane1_size.x, plane1_size.y);"
494                )?;
495                writeln!(
496                    self.out,
497                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / plane1_size;"
498                )?;
499                writeln!(
500                    self.out,
501                    "{l2}float2 plane1_coords = clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
502                )?;
503
504                // For multi-plane, sample the Y value from plane 0
505                writeln!(
506                    self.out,
507                    "{l2}float y = plane0.SampleLevel(samp, plane0_coords, 0.0f).x;"
508                )?;
509                writeln!(self.out, "{l2}float2 uv;")?;
510                writeln!(self.out, "{l2}if (params.num_planes == 2u) {{")?;
511                // Sample UV from interleaved plane 1
512                writeln!(
513                    self.out,
514                    "{l3}uv = plane1.SampleLevel(samp, plane1_coords, 0.0f).xy;"
515                )?;
516                writeln!(self.out, "{l2}}} else {{")?;
517                // Sample U and V from planes 1 and 2 respectively
518                writeln!(self.out, "{l3}float2 plane2_size;")?;
519                writeln!(
520                    self.out,
521                    "{l3}plane2.GetDimensions(plane2_size.x, plane2_size.y);"
522                )?;
523                writeln!(
524                    self.out,
525                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / plane2_size;"
526                )?;
527                writeln!(self.out, "{l3}float2 plane2_coords = clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane2_half_texel);")?;
528                writeln!(self.out, "{l3}uv = float2(plane1.SampleLevel(samp, plane1_coords, 0.0f).x, plane2.SampleLevel(samp, plane2_coords, 0.0f).x);")?;
529                writeln!(self.out, "{l2}}}")?;
530
531                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "params")?;
532
533                writeln!(self.out, "{l1}}}")?;
534                writeln!(self.out, "}}")?;
535                writeln!(self.out)?;
536            }
537            WrappedImageSample {
538                class:
539                    crate::ImageClass::Sampled {
540                        kind: ScalarKind::Float,
541                        multi: false,
542                    },
543                clamp_to_edge: true,
544            } => {
545                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(Texture2D<float4> tex, SamplerState samp, float2 coords) {{")?;
546                let l1 = crate::back::Level(1);
547                writeln!(self.out, "{l1}float2 size;")?;
548                writeln!(self.out, "{l1}tex.GetDimensions(size.x, size.y);")?;
549                writeln!(self.out, "{l1}float2 half_texel = float2(0.5, 0.5) / size;")?;
550                writeln!(
551                    self.out,
552                    "{l1}return tex.SampleLevel(samp, clamp(coords, half_texel, 1.0 - half_texel), 0.0);"
553                )?;
554                writeln!(self.out, "}}")?;
555                writeln!(self.out)?;
556            }
557            _ => {}
558        }
559
560        Ok(())
561    }
562
563    pub(super) fn write_wrapped_image_query_function_name(
564        &mut self,
565        query: WrappedImageQuery,
566    ) -> BackendResult {
567        let dim_str = query.dim.to_hlsl_str();
568        let class_str = match query.class {
569            crate::ImageClass::Sampled { multi: true, .. } => "MS",
570            crate::ImageClass::Depth { multi: true } => "DepthMS",
571            crate::ImageClass::Depth { multi: false } => "Depth",
572            crate::ImageClass::Sampled { multi: false, .. } => "",
573            crate::ImageClass::Storage { .. } => "RW",
574            crate::ImageClass::External => "External",
575        };
576        let arrayed_str = if query.arrayed { "Array" } else { "" };
577        let query_str = match query.query {
578            ImageQuery::Size => "Dimensions",
579            ImageQuery::SizeLevel => "MipDimensions",
580            ImageQuery::NumLevels => "NumLevels",
581            ImageQuery::NumLayers => "NumLayers",
582            ImageQuery::NumSamples => "NumSamples",
583        };
584
585        write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
586
587        Ok(())
588    }
589
590    /// Helper function that write wrapped function for `Expression::ImageQuery`
591    ///
592    /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
593    pub(super) fn write_wrapped_image_query_function(
594        &mut self,
595        module: &crate::Module,
596        wiq: WrappedImageQuery,
597        expr_handle: Handle<crate::Expression>,
598        func_ctx: &FunctionCtx,
599    ) -> BackendResult {
600        use crate::{
601            back::{COMPONENTS, INDENT},
602            ImageDimension as IDim,
603        };
604
605        match wiq.class {
606            crate::ImageClass::External => {
607                if wiq.query != ImageQuery::Size {
608                    return Err(super::Error::Custom(
609                        "External images only support `Size` queries".into(),
610                    ));
611                }
612
613                write!(self.out, "uint2 ")?;
614                self.write_wrapped_image_query_function_name(wiq)?;
615                let params_name = &self.names
616                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
617                // Only plane0 and params are used by this implementation, but it's easier to
618                // always take all of them as arguments so that we can unconditionally expand an
619                // external texture expression each of its parts.
620                writeln!(self.out, "(Texture2D<float4> plane0, Texture2D<float4> plane1, Texture2D<float4> plane2, {params_name} params) {{")?;
621                let l1 = crate::back::Level(1);
622                let l2 = l1.next();
623                writeln!(self.out, "{l1}if (any(params.size)) {{")?;
624                writeln!(self.out, "{l2}return params.size;")?;
625                writeln!(self.out, "{l1}}} else {{")?;
626                // params.size == (0, 0) indicates to query and return plane 0's actual size
627                writeln!(self.out, "{l2}uint2 ret;")?;
628                writeln!(self.out, "{l2}plane0.GetDimensions(ret.x, ret.y);")?;
629                writeln!(self.out, "{l2}return ret;")?;
630                writeln!(self.out, "{l1}}}")?;
631                writeln!(self.out, "}}")?;
632                writeln!(self.out)?;
633            }
634            _ => {
635                const ARGUMENT_VARIABLE_NAME: &str = "tex";
636                const RETURN_VARIABLE_NAME: &str = "ret";
637                const MIP_LEVEL_PARAM: &str = "mip_level";
638
639                // Write function return type and name
640                let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
641                self.write_value_type(module, ret_ty)?;
642                write!(self.out, " ")?;
643                self.write_wrapped_image_query_function_name(wiq)?;
644
645                // Write function parameters
646                write!(self.out, "(")?;
647                // Texture always first parameter
648                self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
649                write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
650                // Mipmap is a second parameter if exists
651                if let ImageQuery::SizeLevel = wiq.query {
652                    write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
653                }
654                writeln!(self.out, ")")?;
655
656                // Write function body
657                writeln!(self.out, "{{")?;
658
659                let array_coords = usize::from(wiq.arrayed);
660                // extra parameter is the mip level count or the sample count
661                let extra_coords = match wiq.class {
662                    crate::ImageClass::Storage { .. } => 0,
663                    crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
664                    crate::ImageClass::External => unreachable!(),
665                };
666
667                // GetDimensions Overloaded Methods
668                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
669                let (ret_swizzle, number_of_params) = match wiq.query {
670                    ImageQuery::Size | ImageQuery::SizeLevel => {
671                        let ret = match wiq.dim {
672                            IDim::D1 => "x",
673                            IDim::D2 => "xy",
674                            IDim::D3 => "xyz",
675                            IDim::Cube => "xy",
676                        };
677                        (ret, ret.len() + array_coords + extra_coords)
678                    }
679                    ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
680                        if wiq.arrayed || wiq.dim == IDim::D3 {
681                            ("w", 4)
682                        } else {
683                            ("z", 3)
684                        }
685                    }
686                };
687
688                // Write `GetDimensions` function.
689                writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
690                write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
691                match wiq.query {
692                    ImageQuery::SizeLevel => {
693                        write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
694                    }
695                    _ => match wiq.class {
696                        crate::ImageClass::Sampled { multi: true, .. }
697                        | crate::ImageClass::Depth { multi: true }
698                        | crate::ImageClass::Storage { .. } => {}
699                        _ => {
700                            // Write zero mipmap level for supported types
701                            write!(self.out, "0, ")?;
702                        }
703                    },
704                }
705
706                for component in COMPONENTS[..number_of_params - 1].iter() {
707                    write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
708                }
709
710                // write last parameter without comma and space for last parameter
711                write!(
712                    self.out,
713                    "{}.{}",
714                    RETURN_VARIABLE_NAME,
715                    COMPONENTS[number_of_params - 1]
716                )?;
717
718                writeln!(self.out, ");")?;
719
720                // Write return value
721                writeln!(
722                    self.out,
723                    "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
724                )?;
725
726                // End of function body
727                writeln!(self.out, "}}")?;
728                // Write extra new line
729                writeln!(self.out)?;
730            }
731        }
732        Ok(())
733    }
734
735    pub(super) fn write_wrapped_constructor_function_name(
736        &mut self,
737        module: &crate::Module,
738        constructor: WrappedConstructor,
739    ) -> BackendResult {
740        let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
741        write!(self.out, "Construct{name}")?;
742        Ok(())
743    }
744
745    /// Helper function that write wrapped function for `Expression::Compose` for structures.
746    fn write_wrapped_constructor_function(
747        &mut self,
748        module: &crate::Module,
749        constructor: WrappedConstructor,
750    ) -> BackendResult {
751        use crate::back::INDENT;
752
753        const ARGUMENT_VARIABLE_NAME: &str = "arg";
754        const RETURN_VARIABLE_NAME: &str = "ret";
755
756        // Write function return type and name
757        if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
758            write!(self.out, "typedef ")?;
759            self.write_type(module, constructor.ty)?;
760            write!(self.out, " ret_")?;
761            self.write_wrapped_constructor_function_name(module, constructor)?;
762            self.write_array_size(module, base, size)?;
763            writeln!(self.out, ";")?;
764
765            write!(self.out, "ret_")?;
766            self.write_wrapped_constructor_function_name(module, constructor)?;
767        } else {
768            self.write_type(module, constructor.ty)?;
769        }
770        write!(self.out, " ")?;
771        self.write_wrapped_constructor_function_name(module, constructor)?;
772
773        // Write function parameters
774        write!(self.out, "(")?;
775
776        let mut write_arg = |i, ty| -> BackendResult {
777            if i != 0 {
778                write!(self.out, ", ")?;
779            }
780            self.write_type(module, ty)?;
781            write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
782            if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
783                self.write_array_size(module, base, size)?;
784            }
785            Ok(())
786        };
787
788        match module.types[constructor.ty].inner {
789            crate::TypeInner::Struct { ref members, .. } => {
790                for (i, member) in members.iter().enumerate() {
791                    write_arg(i, member.ty)?;
792                }
793            }
794            crate::TypeInner::Array {
795                base,
796                size: crate::ArraySize::Constant(size),
797                ..
798            } => {
799                for i in 0..size.get() as usize {
800                    write_arg(i, base)?;
801                }
802            }
803            _ => unreachable!(),
804        };
805
806        write!(self.out, ")")?;
807
808        // Write function body
809        writeln!(self.out, " {{")?;
810
811        match module.types[constructor.ty].inner {
812            crate::TypeInner::Struct { ref members, .. } => {
813                let struct_name = &self.names[&NameKey::Type(constructor.ty)];
814                writeln!(
815                    self.out,
816                    "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
817                )?;
818                for (i, member) in members.iter().enumerate() {
819                    let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
820
821                    match module.types[member.ty].inner {
822                        crate::TypeInner::Matrix {
823                            columns,
824                            rows: crate::VectorSize::Bi,
825                            ..
826                        } if member.binding.is_none() => {
827                            for j in 0..columns as u8 {
828                                writeln!(
829                                    self.out,
830                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
831                                )?;
832                            }
833                        }
834                        ref other => {
835                            // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
836                            // (where the inner matrix is represented by a struct with C `float2` members).
837                            // See the module-level block comment in mod.rs for details.
838                            if let Some(super::writer::MatrixType {
839                                columns,
840                                rows: crate::VectorSize::Bi,
841                                width: 4,
842                            }) = super::writer::get_inner_matrix_data(module, member.ty)
843                            {
844                                write!(
845                                    self.out,
846                                    "{}{}.{} = (__mat{}x2",
847                                    INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
848                                )?;
849                                if let crate::TypeInner::Array { base, size, .. } = *other {
850                                    self.write_array_size(module, base, size)?;
851                                }
852                                writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
853                            } else {
854                                writeln!(
855                                    self.out,
856                                    "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
857                                )?;
858                            }
859                        }
860                    }
861                }
862            }
863            crate::TypeInner::Array {
864                base,
865                size: crate::ArraySize::Constant(size),
866                ..
867            } => {
868                write!(self.out, "{INDENT}")?;
869                self.write_type(module, base)?;
870                write!(self.out, " {RETURN_VARIABLE_NAME}")?;
871                self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
872                write!(self.out, " = {{ ")?;
873                for i in 0..size.get() {
874                    if i != 0 {
875                        write!(self.out, ", ")?;
876                    }
877                    write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
878                }
879                writeln!(self.out, " }};",)?;
880            }
881            _ => unreachable!(),
882        }
883
884        // Write return value
885        writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
886
887        // End of function body
888        writeln!(self.out, "}}")?;
889        // Write extra new line
890        writeln!(self.out)?;
891
892        Ok(())
893    }
894
895    /// Writes the conversion from a single length storage texture load to a vec4 with the loaded
896    /// scalar in its `x` component, 1 in its `a` component and 0 everywhere else.
897    fn write_loaded_scalar_to_storage_loaded_value(
898        &mut self,
899        scalar_type: crate::Scalar,
900    ) -> BackendResult {
901        const ARGUMENT_VARIABLE_NAME: &str = "arg";
902        const RETURN_VARIABLE_NAME: &str = "ret";
903
904        let zero;
905        let one;
906        match scalar_type.kind {
907            ScalarKind::Sint => {
908                assert_eq!(
909                    scalar_type.width, 4,
910                    "Scalar {scalar_type:?} is not a result from any storage format"
911                );
912                zero = "0";
913                one = "1";
914            }
915            ScalarKind::Uint => match scalar_type.width {
916                4 => {
917                    zero = "0u";
918                    one = "1u";
919                }
920                8 => {
921                    zero = "0uL";
922                    one = "1uL"
923                }
924                _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
925            },
926            ScalarKind::Float => {
927                assert_eq!(
928                    scalar_type.width, 4,
929                    "Scalar {scalar_type:?} is not a result from any storage format"
930                );
931                zero = "0.0";
932                one = "1.0";
933            }
934            _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
935        }
936
937        let ty = scalar_type.to_hlsl_str()?;
938        writeln!(
939            self.out,
940            "{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARGUMENT_VARIABLE_NAME}) {{\
941    {ty}4 {RETURN_VARIABLE_NAME} = {ty}4({ARGUMENT_VARIABLE_NAME}, {zero}, {zero}, {one});\
942    return {RETURN_VARIABLE_NAME};\
943}}"
944        )?;
945
946        Ok(())
947    }
948
949    pub(super) fn write_wrapped_struct_matrix_get_function_name(
950        &mut self,
951        access: WrappedStructMatrixAccess,
952    ) -> BackendResult {
953        let name = &self.names[&NameKey::Type(access.ty)];
954        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
955        write!(self.out, "GetMat{field_name}On{name}")?;
956        Ok(())
957    }
958
959    /// Writes a function used to get a matCx2 from within a structure.
960    pub(super) fn write_wrapped_struct_matrix_get_function(
961        &mut self,
962        module: &crate::Module,
963        access: WrappedStructMatrixAccess,
964    ) -> BackendResult {
965        use crate::back::INDENT;
966
967        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
968
969        // Write function return type and name
970        let member = match module.types[access.ty].inner {
971            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
972            _ => unreachable!(),
973        };
974        let ret_ty = &module.types[member.ty].inner;
975        self.write_value_type(module, ret_ty)?;
976        write!(self.out, " ")?;
977        self.write_wrapped_struct_matrix_get_function_name(access)?;
978
979        // Write function parameters
980        write!(self.out, "(")?;
981        let struct_name = &self.names[&NameKey::Type(access.ty)];
982        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
983
984        // Write function body
985        writeln!(self.out, ") {{")?;
986
987        // Write return value
988        write!(self.out, "{INDENT}return ")?;
989        self.write_value_type(module, ret_ty)?;
990        write!(self.out, "(")?;
991        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
992        match module.types[member.ty].inner {
993            crate::TypeInner::Matrix { columns, .. } => {
994                for i in 0..columns as u8 {
995                    if i != 0 {
996                        write!(self.out, ", ")?;
997                    }
998                    write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
999                }
1000            }
1001            _ => unreachable!(),
1002        }
1003        writeln!(self.out, ");")?;
1004
1005        // End of function body
1006        writeln!(self.out, "}}")?;
1007        // Write extra new line
1008        writeln!(self.out)?;
1009
1010        Ok(())
1011    }
1012
1013    pub(super) fn write_wrapped_struct_matrix_set_function_name(
1014        &mut self,
1015        access: WrappedStructMatrixAccess,
1016    ) -> BackendResult {
1017        let name = &self.names[&NameKey::Type(access.ty)];
1018        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1019        write!(self.out, "SetMat{field_name}On{name}")?;
1020        Ok(())
1021    }
1022
1023    /// Writes a function used to set a matCx2 from within a structure.
1024    pub(super) fn write_wrapped_struct_matrix_set_function(
1025        &mut self,
1026        module: &crate::Module,
1027        access: WrappedStructMatrixAccess,
1028    ) -> BackendResult {
1029        use crate::back::INDENT;
1030
1031        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1032        const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
1033
1034        // Write function return type and name
1035        write!(self.out, "void ")?;
1036        self.write_wrapped_struct_matrix_set_function_name(access)?;
1037
1038        // Write function parameters
1039        write!(self.out, "(")?;
1040        let struct_name = &self.names[&NameKey::Type(access.ty)];
1041        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1042        let member = match module.types[access.ty].inner {
1043            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1044            _ => unreachable!(),
1045        };
1046        self.write_type(module, member.ty)?;
1047        write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
1048        // Write function body
1049        writeln!(self.out, ") {{")?;
1050
1051        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1052
1053        match module.types[member.ty].inner {
1054            crate::TypeInner::Matrix { columns, .. } => {
1055                for i in 0..columns as u8 {
1056                    writeln!(
1057                        self.out,
1058                        "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
1059                    )?;
1060                }
1061            }
1062            _ => unreachable!(),
1063        }
1064
1065        // End of function body
1066        writeln!(self.out, "}}")?;
1067        // Write extra new line
1068        writeln!(self.out)?;
1069
1070        Ok(())
1071    }
1072
1073    pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
1074        &mut self,
1075        access: WrappedStructMatrixAccess,
1076    ) -> BackendResult {
1077        let name = &self.names[&NameKey::Type(access.ty)];
1078        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1079        write!(self.out, "SetMatVec{field_name}On{name}")?;
1080        Ok(())
1081    }
1082
1083    /// Writes a function used to set a vec2 on a matCx2 from within a structure.
1084    pub(super) fn write_wrapped_struct_matrix_set_vec_function(
1085        &mut self,
1086        module: &crate::Module,
1087        access: WrappedStructMatrixAccess,
1088    ) -> BackendResult {
1089        use crate::back::INDENT;
1090
1091        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1092        const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
1093        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
1094
1095        // Write function return type and name
1096        write!(self.out, "void ")?;
1097        self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
1098
1099        // Write function parameters
1100        write!(self.out, "(")?;
1101        let struct_name = &self.names[&NameKey::Type(access.ty)];
1102        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1103        let member = match module.types[access.ty].inner {
1104            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1105            _ => unreachable!(),
1106        };
1107        let vec_ty = match module.types[member.ty].inner {
1108            crate::TypeInner::Matrix { rows, scalar, .. } => {
1109                crate::TypeInner::Vector { size: rows, scalar }
1110            }
1111            _ => unreachable!(),
1112        };
1113        self.write_value_type(module, &vec_ty)?;
1114        write!(
1115            self.out,
1116            " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
1117        )?;
1118
1119        // Write function body
1120        writeln!(self.out, ") {{")?;
1121
1122        writeln!(
1123            self.out,
1124            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
1125        )?;
1126
1127        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1128
1129        match module.types[member.ty].inner {
1130            crate::TypeInner::Matrix { columns, .. } => {
1131                for i in 0..columns as u8 {
1132                    writeln!(
1133                        self.out,
1134                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
1135                    )?;
1136                }
1137            }
1138            _ => unreachable!(),
1139        }
1140
1141        writeln!(self.out, "{INDENT}}}")?;
1142
1143        // End of function body
1144        writeln!(self.out, "}}")?;
1145        // Write extra new line
1146        writeln!(self.out)?;
1147
1148        Ok(())
1149    }
1150
1151    pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
1152        &mut self,
1153        access: WrappedStructMatrixAccess,
1154    ) -> BackendResult {
1155        let name = &self.names[&NameKey::Type(access.ty)];
1156        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1157        write!(self.out, "SetMatScalar{field_name}On{name}")?;
1158        Ok(())
1159    }
1160
1161    /// Writes a function used to set a float on a matCx2 from within a structure.
1162    pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
1163        &mut self,
1164        module: &crate::Module,
1165        access: WrappedStructMatrixAccess,
1166    ) -> BackendResult {
1167        use crate::back::INDENT;
1168
1169        const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
1170        const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
1171        const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
1172        const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
1173
1174        // Write function return type and name
1175        write!(self.out, "void ")?;
1176        self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
1177
1178        // Write function parameters
1179        write!(self.out, "(")?;
1180        let struct_name = &self.names[&NameKey::Type(access.ty)];
1181        write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
1182        let member = match module.types[access.ty].inner {
1183            crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
1184            _ => unreachable!(),
1185        };
1186        let scalar_ty = match module.types[member.ty].inner {
1187            crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
1188            _ => unreachable!(),
1189        };
1190        self.write_value_type(module, &scalar_ty)?;
1191        write!(
1192            self.out,
1193            " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
1194        )?;
1195
1196        // Write function body
1197        writeln!(self.out, ") {{")?;
1198
1199        writeln!(
1200            self.out,
1201            "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
1202        )?;
1203
1204        let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
1205
1206        match module.types[member.ty].inner {
1207            crate::TypeInner::Matrix { columns, .. } => {
1208                for i in 0..columns as u8 {
1209                    writeln!(
1210                        self.out,
1211                        "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
1212                    )?;
1213                }
1214            }
1215            _ => unreachable!(),
1216        }
1217
1218        writeln!(self.out, "{INDENT}}}")?;
1219
1220        // End of function body
1221        writeln!(self.out, "}}")?;
1222        // Write extra new line
1223        writeln!(self.out)?;
1224
1225        Ok(())
1226    }
1227
1228    /// Write functions to create special types.
1229    pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
1230        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
1231            match type_key {
1232                &crate::PredeclaredType::ModfResult { size, scalar }
1233                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
1234                    let arg_type_name_owner;
1235                    let arg_type_name = if let Some(size) = size {
1236                        arg_type_name_owner = format!(
1237                            "{}{}",
1238                            if scalar.width == 8 { "double" } else { "float" },
1239                            size as u8
1240                        );
1241                        &arg_type_name_owner
1242                    } else if scalar.width == 8 {
1243                        "double"
1244                    } else {
1245                        "float"
1246                    };
1247
1248                    let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
1249                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
1250                            (super::writer::MODF_FUNCTION, "modf", "whole", "")
1251                        } else {
1252                            (
1253                                super::writer::FREXP_FUNCTION,
1254                                "frexp",
1255                                "exp_",
1256                                "sign(arg) * ",
1257                            )
1258                        };
1259
1260                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
1261
1262                    writeln!(
1263                        self.out,
1264                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
1265    {arg_type_name} other;
1266    {struct_name} result;
1267    result.fract = {sign_multiplier}{called_func_name}(arg, other);
1268    result.{second_field_name} = other;
1269    return result;
1270}}"
1271                    )?;
1272                    writeln!(self.out)?;
1273                }
1274                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
1275            }
1276        }
1277        if module.special_types.ray_desc.is_some() {
1278            self.write_ray_desc_from_ray_desc_constructor_function(module)?;
1279        }
1280
1281        Ok(())
1282    }
1283
1284    /// Helper function that writes wrapped functions for expressions in a function
1285    pub(super) fn write_wrapped_expression_functions(
1286        &mut self,
1287        module: &crate::Module,
1288        expressions: &crate::Arena<crate::Expression>,
1289        context: Option<&FunctionCtx>,
1290    ) -> BackendResult {
1291        for (handle, _) in expressions.iter() {
1292            match expressions[handle] {
1293                crate::Expression::Compose { ty, .. } => {
1294                    match module.types[ty].inner {
1295                        crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
1296                            let constructor = WrappedConstructor { ty };
1297                            if self.wrapped.insert(WrappedType::Constructor(constructor)) {
1298                                self.write_wrapped_constructor_function(module, constructor)?;
1299                            }
1300                        }
1301                        _ => {}
1302                    };
1303                }
1304                crate::Expression::ImageLoad { image, .. } => {
1305                    // This can only happen in a function as this is not a valid const expression
1306                    match *context.as_ref().unwrap().resolve_type(image, &module.types) {
1307                        crate::TypeInner::Image {
1308                            class: crate::ImageClass::Storage { format, .. },
1309                            ..
1310                        } => {
1311                            if format.single_component() {
1312                                let scalar: crate::Scalar = format.into();
1313                                if self.wrapped.insert(WrappedType::ImageLoadScalar(scalar)) {
1314                                    self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
1315                                }
1316                            }
1317                        }
1318                        _ => {}
1319                    }
1320                }
1321                crate::Expression::RayQueryGetIntersection { committed, .. } => {
1322                    if committed {
1323                        if !self.written_committed_intersection {
1324                            self.write_committed_intersection_function(module)?;
1325                            self.written_committed_intersection = true;
1326                        }
1327                    } else if !self.written_candidate_intersection {
1328                        self.write_candidate_intersection_function(module)?;
1329                        self.written_candidate_intersection = true;
1330                    }
1331                }
1332                _ => {}
1333            }
1334        }
1335        Ok(())
1336    }
1337
1338    // TODO: we could merge this with iteration in write_wrapped_expression_functions...
1339    //
1340    /// Helper function that writes zero value wrapped functions
1341    pub(super) fn write_wrapped_zero_value_functions(
1342        &mut self,
1343        module: &crate::Module,
1344        expressions: &crate::Arena<crate::Expression>,
1345    ) -> BackendResult {
1346        for (handle, _) in expressions.iter() {
1347            if let crate::Expression::ZeroValue(ty) = expressions[handle] {
1348                let zero_value = WrappedZeroValue { ty };
1349                if self.wrapped.insert(WrappedType::ZeroValue(zero_value)) {
1350                    self.write_wrapped_zero_value_function(module, zero_value)?;
1351                }
1352            }
1353        }
1354        Ok(())
1355    }
1356
1357    pub(super) fn write_wrapped_math_functions(
1358        &mut self,
1359        module: &crate::Module,
1360        func_ctx: &FunctionCtx,
1361    ) -> BackendResult {
1362        for (_, expression) in func_ctx.expressions.iter() {
1363            if let crate::Expression::Math {
1364                fun,
1365                arg,
1366                arg1: _arg1,
1367                arg2: _arg2,
1368                arg3: _arg3,
1369            } = *expression
1370            {
1371                let arg_ty = func_ctx.resolve_type(arg, &module.types);
1372
1373                match fun {
1374                    crate::MathFunction::ExtractBits => {
1375                        // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need
1376                        // to first sanitize the offset and count first. If we don't do this, we will get out-of-spec
1377                        // values if the extracted range is not within the bit width.
1378                        //
1379                        // This encodes the exact formula specified by the wgsl spec:
1380                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1381                        //
1382                        // w = sizeof(x) * 8
1383                        // o = min(offset, w)
1384                        // c = min(count, w - o)
1385                        //
1386                        // bitfieldExtract(x, o, c)
1387                        let scalar = arg_ty.scalar().unwrap();
1388                        let components = arg_ty.components();
1389
1390                        let wrapped = WrappedMath {
1391                            fun,
1392                            scalar,
1393                            components,
1394                        };
1395
1396                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1397                            continue;
1398                        }
1399
1400                        // Write return type
1401                        self.write_value_type(module, arg_ty)?;
1402
1403                        let scalar_width: u8 = scalar.width * 8;
1404
1405                        // Write function name and parameters
1406                        writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?;
1407                        write!(self.out, "    ")?;
1408                        self.write_value_type(module, arg_ty)?;
1409                        writeln!(self.out, " e,")?;
1410                        writeln!(self.out, "    uint offset,")?;
1411                        writeln!(self.out, "    uint count")?;
1412                        writeln!(self.out, ") {{")?;
1413
1414                        // Write function body
1415                        writeln!(self.out, "    uint w = {scalar_width};")?;
1416                        writeln!(self.out, "    uint o = min(offset, w);")?;
1417                        writeln!(self.out, "    uint c = min(count, w - o);")?;
1418                        writeln!(
1419                            self.out,
1420                            "    return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));"
1421                        )?;
1422
1423                        // End of function body
1424                        writeln!(self.out, "}}")?;
1425                    }
1426                    crate::MathFunction::InsertBits => {
1427                        // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill.
1428
1429                        let scalar = arg_ty.scalar().unwrap();
1430                        let components = arg_ty.components();
1431
1432                        let wrapped = WrappedMath {
1433                            fun,
1434                            scalar,
1435                            components,
1436                        };
1437
1438                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1439                            continue;
1440                        }
1441
1442                        // Write return type
1443                        self.write_value_type(module, arg_ty)?;
1444
1445                        let scalar_width: u8 = scalar.width * 8;
1446                        let scalar_max: u64 = match scalar.width {
1447                            1 => 0xFF,
1448                            2 => 0xFFFF,
1449                            4 => 0xFFFFFFFF,
1450                            8 => 0xFFFFFFFFFFFFFFFF,
1451                            _ => unreachable!(),
1452                        };
1453
1454                        // Write function name and parameters
1455                        writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?;
1456                        write!(self.out, "    ")?;
1457                        self.write_value_type(module, arg_ty)?;
1458                        writeln!(self.out, " e,")?;
1459                        write!(self.out, "    ")?;
1460                        self.write_value_type(module, arg_ty)?;
1461                        writeln!(self.out, " newbits,")?;
1462                        writeln!(self.out, "    uint offset,")?;
1463                        writeln!(self.out, "    uint count")?;
1464                        writeln!(self.out, ") {{")?;
1465
1466                        // Write function body
1467                        writeln!(self.out, "    uint w = {scalar_width}u;")?;
1468                        writeln!(self.out, "    uint o = min(offset, w);")?;
1469                        writeln!(self.out, "    uint c = min(count, w - o);")?;
1470
1471                        // The `u` suffix on the literals is _extremely_ important. Otherwise it will use
1472                        // i32 shifting instead of the intended u32 shifting.
1473                        writeln!(
1474                            self.out,
1475                            "    uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);"
1476                        )?;
1477                        writeln!(
1478                            self.out,
1479                            "    return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));"
1480                        )?;
1481
1482                        // End of function body
1483                        writeln!(self.out, "}}")?;
1484                    }
1485                    // Taking the absolute value of the minimum value of a two's
1486                    // complement signed integer type causes overflow, which is
1487                    // undefined behaviour in HLSL. To avoid this, when the value is
1488                    // negative we bitcast the value to unsigned and negate it, then
1489                    // bitcast back to signed.
1490                    // This adheres to the WGSL spec in that the absolute of the type's
1491                    // minimum value should equal to the minimum value.
1492                    //
1493                    // TODO(#7109): asint()/asuint() only support 32-bit integers, so we
1494                    // must find another solution for different bit-widths.
1495                    crate::MathFunction::Abs
1496                        if matches!(arg_ty.scalar(), Some(crate::Scalar::I32)) =>
1497                    {
1498                        let scalar = arg_ty.scalar().unwrap();
1499                        let components = arg_ty.components();
1500
1501                        let wrapped = WrappedMath {
1502                            fun,
1503                            scalar,
1504                            components,
1505                        };
1506
1507                        if !self.wrapped.insert(WrappedType::Math(wrapped)) {
1508                            continue;
1509                        }
1510
1511                        self.write_value_type(module, arg_ty)?;
1512                        write!(self.out, " {ABS_FUNCTION}(")?;
1513                        self.write_value_type(module, arg_ty)?;
1514                        writeln!(self.out, " val) {{")?;
1515
1516                        let level = crate::back::Level(1);
1517                        writeln!(
1518                            self.out,
1519                            "{level}return val >= 0 ? val : asint(-asuint(val));"
1520                        )?;
1521                        writeln!(self.out, "}}")?;
1522                        writeln!(self.out)?;
1523                    }
1524                    _ => {}
1525                }
1526            }
1527        }
1528
1529        Ok(())
1530    }
1531
1532    pub(super) fn write_wrapped_unary_ops(
1533        &mut self,
1534        module: &crate::Module,
1535        func_ctx: &FunctionCtx,
1536    ) -> BackendResult {
1537        for (_, expression) in func_ctx.expressions.iter() {
1538            if let crate::Expression::Unary { op, expr } = *expression {
1539                let expr_ty = func_ctx.resolve_type(expr, &module.types);
1540                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
1541                    continue;
1542                };
1543                let wrapped = WrappedUnaryOp {
1544                    op,
1545                    ty: (vector_size, scalar),
1546                };
1547
1548                // Negating the minimum value of a two's complement signed integer type
1549                // causes overflow, which is undefined behaviour in HLSL. To avoid this
1550                // we bitcast the value to unsigned and negate it, then bitcast back to
1551                // signed. This adheres to the WGSL spec in that the negative of the
1552                // type's minimum value should equal to the minimum value.
1553                //
1554                // TODO(#7109): asint()/asuint() only support 32-bit integers, so we must
1555                // find another solution for different bit-widths.
1556                match (op, scalar) {
1557                    (crate::UnaryOperator::Negate, crate::Scalar::I32) => {
1558                        if !self.wrapped.insert(WrappedType::UnaryOp(wrapped)) {
1559                            continue;
1560                        }
1561
1562                        self.write_value_type(module, expr_ty)?;
1563                        write!(self.out, " {NEG_FUNCTION}(")?;
1564                        self.write_value_type(module, expr_ty)?;
1565                        writeln!(self.out, " val) {{")?;
1566
1567                        let level = crate::back::Level(1);
1568                        writeln!(self.out, "{level}return asint(-asuint(val));",)?;
1569                        writeln!(self.out, "}}")?;
1570                        writeln!(self.out)?;
1571                    }
1572                    _ => {}
1573                }
1574            }
1575        }
1576
1577        Ok(())
1578    }
1579
1580    pub(super) fn write_wrapped_binary_ops(
1581        &mut self,
1582        module: &crate::Module,
1583        func_ctx: &FunctionCtx,
1584    ) -> BackendResult {
1585        for (expr_handle, expression) in func_ctx.expressions.iter() {
1586            if let crate::Expression::Binary { op, left, right } = *expression {
1587                let expr_ty = func_ctx.resolve_type(expr_handle, &module.types);
1588                let left_ty = func_ctx.resolve_type(left, &module.types);
1589                let right_ty = func_ctx.resolve_type(right, &module.types);
1590
1591                match (op, expr_ty.scalar()) {
1592                    // Signed integer division of the type's minimum representable value
1593                    // divided by -1, or signed or unsigned division by zero, is
1594                    // undefined behaviour in HLSL. We override the divisor to 1 in these
1595                    // cases.
1596                    // This adheres to the WGSL spec in that:
1597                    // * TYPE_MIN / -1 == TYPE_MIN
1598                    // * x / 0 == x
1599                    (
1600                        crate::BinaryOperator::Divide,
1601                        Some(
1602                            scalar @ crate::Scalar {
1603                                kind: ScalarKind::Sint | ScalarKind::Uint,
1604                                ..
1605                            },
1606                        ),
1607                    ) => {
1608                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
1609                            continue;
1610                        };
1611                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
1612                            continue;
1613                        };
1614                        let wrapped = WrappedBinaryOp {
1615                            op,
1616                            left_ty: left_wrapped_ty,
1617                            right_ty: right_wrapped_ty,
1618                        };
1619                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
1620                            continue;
1621                        }
1622
1623                        self.write_value_type(module, expr_ty)?;
1624                        write!(self.out, " {DIV_FUNCTION}(")?;
1625                        self.write_value_type(module, left_ty)?;
1626                        write!(self.out, " lhs, ")?;
1627                        self.write_value_type(module, right_ty)?;
1628                        writeln!(self.out, " rhs) {{")?;
1629                        let level = crate::back::Level(1);
1630                        match scalar.kind {
1631                            ScalarKind::Sint => {
1632                                let min_val = match scalar.width {
1633                                    2 => crate::Literal::I16(i16::MIN),
1634                                    4 => crate::Literal::I32(i32::MIN),
1635                                    8 => crate::Literal::I64(i64::MIN),
1636                                    _ => {
1637                                        return Err(super::Error::UnsupportedScalar(scalar));
1638                                    }
1639                                };
1640                                write!(self.out, "{level}return lhs / (((lhs == ")?;
1641                                self.write_literal(min_val)?;
1642                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs);")?
1643                            }
1644                            ScalarKind::Uint => {
1645                                writeln!(self.out, "{level}return lhs / (rhs == 0u ? 1u : rhs);")?
1646                            }
1647                            _ => unreachable!(),
1648                        }
1649                        writeln!(self.out, "}}")?;
1650                        writeln!(self.out)?;
1651                    }
1652                    // The modulus operator is only defined for integers in HLSL when
1653                    // either both sides are positive or both sides are negative. To
1654                    // avoid this undefined behaviour we use the following equation:
1655                    //
1656                    // dividend - (dividend / divisor) * divisor
1657                    //
1658                    // overriding the divisor to 1 if either it is 0, or it is -1
1659                    // and the dividend is the minimum representable value.
1660                    //
1661                    // This adheres to the WGSL spec in that:
1662                    // * min_value % -1 == 0
1663                    // * x % 0 == 0
1664                    (
1665                        crate::BinaryOperator::Modulo,
1666                        Some(
1667                            scalar @ crate::Scalar {
1668                                kind: ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float,
1669                                ..
1670                            },
1671                        ),
1672                    ) => {
1673                        let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
1674                            continue;
1675                        };
1676                        let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
1677                            continue;
1678                        };
1679                        let wrapped = WrappedBinaryOp {
1680                            op,
1681                            left_ty: left_wrapped_ty,
1682                            right_ty: right_wrapped_ty,
1683                        };
1684                        if !self.wrapped.insert(WrappedType::BinaryOp(wrapped)) {
1685                            continue;
1686                        }
1687
1688                        self.write_value_type(module, expr_ty)?;
1689                        write!(self.out, " {MOD_FUNCTION}(")?;
1690                        self.write_value_type(module, left_ty)?;
1691                        write!(self.out, " lhs, ")?;
1692                        self.write_value_type(module, right_ty)?;
1693                        writeln!(self.out, " rhs) {{")?;
1694                        let level = crate::back::Level(1);
1695                        match scalar.kind {
1696                            ScalarKind::Sint => {
1697                                let min_val = match scalar.width {
1698                                    2 => crate::Literal::I16(i16::MIN),
1699                                    4 => crate::Literal::I32(i32::MIN),
1700                                    8 => crate::Literal::I64(i64::MIN),
1701                                    _ => {
1702                                        return Err(super::Error::UnsupportedScalar(scalar));
1703                                    }
1704                                };
1705                                write!(self.out, "{level}")?;
1706                                self.write_value_type(module, right_ty)?;
1707                                write!(self.out, " divisor = ((lhs == ")?;
1708                                self.write_literal(min_val)?;
1709                                writeln!(self.out, " & rhs == -1) | (rhs == 0)) ? 1 : rhs;")?;
1710                                writeln!(
1711                                    self.out,
1712                                    "{level}return lhs - (lhs / divisor) * divisor;"
1713                                )?
1714                            }
1715                            ScalarKind::Uint => {
1716                                writeln!(self.out, "{level}return lhs % (rhs == 0u ? 1u : rhs);")?
1717                            }
1718                            // HLSL's fmod has the same definition as WGSL's % operator but due
1719                            // to its implementation in DXC it is not as accurate as the WGSL spec
1720                            // requires it to be. See:
1721                            // - https://shader-playground.timjones.io/0c8572816dbb6fc4435cc5d016a978a7
1722                            // - https://github.com/llvm/llvm-project/blob/50f9b8acafdca48e87e6b8e393c1f116a2d193ee/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h#L78-L81
1723                            ScalarKind::Float => {
1724                                writeln!(self.out, "{level}return lhs - rhs * trunc(lhs / rhs);")?
1725                            }
1726                            _ => unreachable!(),
1727                        }
1728                        writeln!(self.out, "}}")?;
1729                        writeln!(self.out)?;
1730                    }
1731                    _ => {}
1732                }
1733            }
1734        }
1735
1736        Ok(())
1737    }
1738
1739    fn write_wrapped_cast_functions(
1740        &mut self,
1741        module: &crate::Module,
1742        func_ctx: &FunctionCtx,
1743    ) -> BackendResult {
1744        for (_, expression) in func_ctx.expressions.iter() {
1745            if let crate::Expression::As {
1746                expr,
1747                kind,
1748                convert: Some(width),
1749            } = *expression
1750            {
1751                // Avoid undefined behaviour when casting from a float to integer
1752                // when the value is out of range for the target type. Additionally
1753                // ensure we clamp to the correct value as per the WGSL spec.
1754                //
1755                // https://www.w3.org/TR/WGSL/#floating-point-conversion:
1756                // * If X is exactly representable in the target type T, then the
1757                //   result is that value.
1758                // * Otherwise, the result is the value in T closest to
1759                //   truncate(X) and also exactly representable in the original
1760                //   floating point type.
1761                let src_ty = func_ctx.resolve_type(expr, &module.types);
1762                let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
1763                    continue;
1764                };
1765                let dst_scalar = crate::Scalar { kind, width };
1766                if src_scalar.kind != ScalarKind::Float
1767                    || (dst_scalar.kind != ScalarKind::Sint && dst_scalar.kind != ScalarKind::Uint)
1768                {
1769                    continue;
1770                }
1771
1772                let wrapped = WrappedCast {
1773                    src_scalar,
1774                    vector_size,
1775                    dst_scalar,
1776                };
1777                if !self.wrapped.insert(WrappedType::Cast(wrapped)) {
1778                    continue;
1779                }
1780
1781                let (src_ty, dst_ty) = match vector_size {
1782                    None => (
1783                        crate::TypeInner::Scalar(src_scalar),
1784                        crate::TypeInner::Scalar(dst_scalar),
1785                    ),
1786                    Some(vector_size) => (
1787                        crate::TypeInner::Vector {
1788                            scalar: src_scalar,
1789                            size: vector_size,
1790                        },
1791                        crate::TypeInner::Vector {
1792                            scalar: dst_scalar,
1793                            size: vector_size,
1794                        },
1795                    ),
1796                };
1797                let (min, max) =
1798                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
1799                let cast_str = format!(
1800                    "{}{}",
1801                    dst_scalar.to_hlsl_str()?,
1802                    vector_size
1803                        .map(crate::common::vector_size_str)
1804                        .unwrap_or(""),
1805                );
1806                let fun_name = match dst_scalar {
1807                    crate::Scalar::I32 => F2I32_FUNCTION,
1808                    crate::Scalar::U32 => F2U32_FUNCTION,
1809                    crate::Scalar::I64 => F2I64_FUNCTION,
1810                    crate::Scalar::U64 => F2U64_FUNCTION,
1811                    _ => unreachable!(),
1812                };
1813                self.write_value_type(module, &dst_ty)?;
1814                write!(self.out, " {fun_name}(")?;
1815                self.write_value_type(module, &src_ty)?;
1816                writeln!(self.out, " value) {{")?;
1817                let level = crate::back::Level(1);
1818                write!(self.out, "{level}return {cast_str}(clamp(value, ")?;
1819                self.write_literal(min)?;
1820                write!(self.out, ", ")?;
1821                self.write_literal(max)?;
1822                writeln!(self.out, "));",)?;
1823                writeln!(self.out, "}}")?;
1824                writeln!(self.out)?;
1825            }
1826        }
1827        Ok(())
1828    }
1829
1830    /// Helper function that writes various wrapped functions
1831    pub(super) fn write_wrapped_functions(
1832        &mut self,
1833        module: &crate::Module,
1834        func_ctx: &FunctionCtx,
1835    ) -> BackendResult {
1836        self.write_wrapped_math_functions(module, func_ctx)?;
1837        self.write_wrapped_unary_ops(module, func_ctx)?;
1838        self.write_wrapped_binary_ops(module, func_ctx)?;
1839        self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
1840        self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
1841        self.write_wrapped_cast_functions(module, func_ctx)?;
1842
1843        for (handle, _) in func_ctx.expressions.iter() {
1844            match func_ctx.expressions[handle] {
1845                crate::Expression::ArrayLength(expr) => {
1846                    let global_expr = match func_ctx.expressions[expr] {
1847                        crate::Expression::GlobalVariable(_) => expr,
1848                        crate::Expression::AccessIndex { base, index: _ } => base,
1849                        ref other => unreachable!("Array length of {:?}", other),
1850                    };
1851                    let global_var = match func_ctx.expressions[global_expr] {
1852                        crate::Expression::GlobalVariable(var_handle) => {
1853                            &module.global_variables[var_handle]
1854                        }
1855                        ref other => {
1856                            return Err(super::Error::Unimplemented(format!(
1857                                "Array length of base {other:?}"
1858                            )))
1859                        }
1860                    };
1861                    let storage_access = match global_var.space {
1862                        crate::AddressSpace::Storage { access } => access,
1863                        _ => crate::StorageAccess::default(),
1864                    };
1865                    let wal = WrappedArrayLength {
1866                        writable: storage_access.contains(crate::StorageAccess::STORE),
1867                    };
1868
1869                    if self.wrapped.insert(WrappedType::ArrayLength(wal)) {
1870                        self.write_wrapped_array_length_function(wal)?;
1871                    }
1872                }
1873                crate::Expression::ImageLoad { image, .. } => {
1874                    let class = match *func_ctx.resolve_type(image, &module.types) {
1875                        crate::TypeInner::Image { class, .. } => class,
1876                        _ => unreachable!(),
1877                    };
1878                    let wrapped = WrappedImageLoad { class };
1879                    if self.wrapped.insert(WrappedType::ImageLoad(wrapped)) {
1880                        self.write_wrapped_image_load_function(module, wrapped)?;
1881                    }
1882                }
1883                crate::Expression::ImageSample {
1884                    image,
1885                    clamp_to_edge,
1886                    ..
1887                } => {
1888                    let class = match *func_ctx.resolve_type(image, &module.types) {
1889                        crate::TypeInner::Image { class, .. } => class,
1890                        _ => unreachable!(),
1891                    };
1892                    let wrapped = WrappedImageSample {
1893                        class,
1894                        clamp_to_edge,
1895                    };
1896                    if self.wrapped.insert(WrappedType::ImageSample(wrapped)) {
1897                        self.write_wrapped_image_sample_function(module, wrapped)?;
1898                    }
1899                }
1900                crate::Expression::ImageQuery { image, query } => {
1901                    let wiq = match *func_ctx.resolve_type(image, &module.types) {
1902                        crate::TypeInner::Image {
1903                            dim,
1904                            arrayed,
1905                            class,
1906                        } => WrappedImageQuery {
1907                            dim,
1908                            arrayed,
1909                            class,
1910                            query: query.into(),
1911                        },
1912                        _ => unreachable!("we only query images"),
1913                    };
1914
1915                    if self.wrapped.insert(WrappedType::ImageQuery(wiq)) {
1916                        self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
1917                    }
1918                }
1919                // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
1920                // since they will later be used by the fn `write_storage_load`
1921                crate::Expression::Load { pointer } => {
1922                    let pointer_space = func_ctx
1923                        .resolve_type(pointer, &module.types)
1924                        .pointer_space();
1925
1926                    if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
1927                        if let Some(ty) = func_ctx.info[handle].ty.handle() {
1928                            write_wrapped_constructor(self, ty, module)?;
1929                        }
1930                    }
1931
1932                    fn write_wrapped_constructor<W: Write>(
1933                        writer: &mut super::Writer<'_, W>,
1934                        ty: Handle<crate::Type>,
1935                        module: &crate::Module,
1936                    ) -> BackendResult {
1937                        match module.types[ty].inner {
1938                            crate::TypeInner::Struct { ref members, .. } => {
1939                                for member in members {
1940                                    write_wrapped_constructor(writer, member.ty, module)?;
1941                                }
1942
1943                                let constructor = WrappedConstructor { ty };
1944                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
1945                                    writer
1946                                        .write_wrapped_constructor_function(module, constructor)?;
1947                                }
1948                            }
1949                            crate::TypeInner::Array { base, .. } => {
1950                                write_wrapped_constructor(writer, base, module)?;
1951
1952                                let constructor = WrappedConstructor { ty };
1953                                if writer.wrapped.insert(WrappedType::Constructor(constructor)) {
1954                                    writer
1955                                        .write_wrapped_constructor_function(module, constructor)?;
1956                                }
1957                            }
1958                            _ => {}
1959                        };
1960
1961                        Ok(())
1962                    }
1963                }
1964                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
1965                // (see top level module docs for details).
1966                //
1967                // The functions injected here are required to get the matrix accesses working.
1968                crate::Expression::AccessIndex { base, index } => {
1969                    let base_ty_res = &func_ctx.info[base].ty;
1970                    let mut resolved = base_ty_res.inner_with(&module.types);
1971                    let base_ty_handle = match *resolved {
1972                        crate::TypeInner::Pointer { base, .. } => {
1973                            resolved = &module.types[base].inner;
1974                            Some(base)
1975                        }
1976                        _ => base_ty_res.handle(),
1977                    };
1978                    if let crate::TypeInner::Struct { ref members, .. } = *resolved {
1979                        let member = &members[index as usize];
1980
1981                        match module.types[member.ty].inner {
1982                            crate::TypeInner::Matrix {
1983                                rows: crate::VectorSize::Bi,
1984                                ..
1985                            } if member.binding.is_none() => {
1986                                let ty = base_ty_handle.unwrap();
1987                                let access = WrappedStructMatrixAccess { ty, index };
1988
1989                                if self.wrapped.insert(WrappedType::StructMatrixAccess(access)) {
1990                                    self.write_wrapped_struct_matrix_get_function(module, access)?;
1991                                    self.write_wrapped_struct_matrix_set_function(module, access)?;
1992                                    self.write_wrapped_struct_matrix_set_vec_function(
1993                                        module, access,
1994                                    )?;
1995                                    self.write_wrapped_struct_matrix_set_scalar_function(
1996                                        module, access,
1997                                    )?;
1998                                }
1999                            }
2000                            _ => {}
2001                        }
2002                    }
2003                }
2004                _ => {}
2005            };
2006        }
2007
2008        Ok(())
2009    }
2010
2011    /// Writes out the sampler heap declarations if they haven't been written yet.
2012    pub(super) fn write_sampler_heaps(&mut self) -> BackendResult {
2013        if self.wrapped.sampler_heaps {
2014            return Ok(());
2015        }
2016
2017        writeln!(
2018            self.out,
2019            "SamplerState {}[2048]: register(s{}, space{});",
2020            super::writer::SAMPLER_HEAP_VAR,
2021            self.options.sampler_heap_target.standard_samplers.register,
2022            self.options.sampler_heap_target.standard_samplers.space
2023        )?;
2024        writeln!(
2025            self.out,
2026            "SamplerComparisonState {}[2048]: register(s{}, space{});",
2027            super::writer::COMPARISON_SAMPLER_HEAP_VAR,
2028            self.options
2029                .sampler_heap_target
2030                .comparison_samplers
2031                .register,
2032            self.options.sampler_heap_target.comparison_samplers.space
2033        )?;
2034
2035        self.wrapped.sampler_heaps = true;
2036
2037        Ok(())
2038    }
2039
2040    /// Writes out the sampler index buffer declaration if it hasn't been written yet.
2041    pub(super) fn write_wrapped_sampler_buffer(
2042        &mut self,
2043        key: super::SamplerIndexBufferKey,
2044    ) -> BackendResult {
2045        // The astute will notice that we do a double hash lookup, but we do this to avoid
2046        // holding a mutable reference to `self` while trying to call `write_sampler_heaps`.
2047        //
2048        // We only pay this double lookup cost when we actually need to write out the sampler
2049        // buffer, which should be not be common.
2050
2051        if self.wrapped.sampler_index_buffers.contains_key(&key) {
2052            return Ok(());
2053        };
2054
2055        self.write_sampler_heaps()?;
2056
2057        // Because the group number can be arbitrary, we use the namer to generate a unique name
2058        // instead of adding it to the reserved name list.
2059        let sampler_array_name = self
2060            .namer
2061            .call(&format!("nagaGroup{}SamplerIndexArray", key.group));
2062
2063        let bind_target = match self.options.sampler_buffer_binding_map.get(&key) {
2064            Some(&bind_target) => bind_target,
2065            None if self.options.fake_missing_bindings => super::BindTarget {
2066                space: u8::MAX,
2067                register: key.group,
2068                binding_array_size: None,
2069                dynamic_storage_buffer_offsets_index: None,
2070                restrict_indexing: false,
2071            },
2072            None => {
2073                unreachable!("Sampler buffer of group {key:?} not bound to a register");
2074            }
2075        };
2076
2077        writeln!(
2078            self.out,
2079            "StructuredBuffer<uint> {sampler_array_name} : register(t{}, space{});",
2080            bind_target.register, bind_target.space
2081        )?;
2082
2083        self.wrapped
2084            .sampler_index_buffers
2085            .insert(key, sampler_array_name);
2086
2087        Ok(())
2088    }
2089
2090    pub(super) fn write_texture_coordinates(
2091        &mut self,
2092        kind: &str,
2093        coordinate: Handle<crate::Expression>,
2094        array_index: Option<Handle<crate::Expression>>,
2095        mip_level: Option<Handle<crate::Expression>>,
2096        module: &crate::Module,
2097        func_ctx: &FunctionCtx,
2098    ) -> BackendResult {
2099        // HLSL expects the array index to be merged with the coordinate
2100        let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
2101        if extra == 0 {
2102            self.write_expr(module, coordinate, func_ctx)?;
2103        } else {
2104            let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
2105                crate::TypeInner::Scalar { .. } => 1,
2106                crate::TypeInner::Vector { size, .. } => size as usize,
2107                _ => unreachable!(),
2108            };
2109            write!(self.out, "{}{}(", kind, num_coords + extra)?;
2110            self.write_expr(module, coordinate, func_ctx)?;
2111            if let Some(expr) = array_index {
2112                write!(self.out, ", ")?;
2113                self.write_expr(module, expr, func_ctx)?;
2114            }
2115            if let Some(expr) = mip_level {
2116                // Explicit cast if needed
2117                let cast_to_int = matches!(
2118                    *func_ctx.resolve_type(expr, &module.types),
2119                    crate::TypeInner::Scalar(crate::Scalar {
2120                        kind: ScalarKind::Uint,
2121                        ..
2122                    })
2123                );
2124
2125                write!(self.out, ", ")?;
2126
2127                if cast_to_int {
2128                    write!(self.out, "int(")?;
2129                }
2130
2131                self.write_expr(module, expr, func_ctx)?;
2132
2133                if cast_to_int {
2134                    write!(self.out, ")")?;
2135                }
2136            }
2137            write!(self.out, ")")?;
2138        }
2139        Ok(())
2140    }
2141
2142    pub(super) fn write_mat_cx2_typedef_and_functions(
2143        &mut self,
2144        WrappedMatCx2 { columns }: WrappedMatCx2,
2145    ) -> BackendResult {
2146        use crate::back::INDENT;
2147
2148        // typedef
2149        write!(self.out, "typedef struct {{ ")?;
2150        for i in 0..columns as u8 {
2151            write!(self.out, "float2 _{i}; ")?;
2152        }
2153        writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
2154
2155        // __get_col_of_mat
2156        writeln!(
2157            self.out,
2158            "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
2159            columns as u8, columns as u8
2160        )?;
2161        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2162        for i in 0..columns as u8 {
2163            writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
2164        }
2165        writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
2166        writeln!(self.out, "{INDENT}}}")?;
2167        writeln!(self.out, "}}")?;
2168
2169        // __set_col_of_mat
2170        writeln!(
2171            self.out,
2172            "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
2173            columns as u8, columns as u8
2174        )?;
2175        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2176        for i in 0..columns as u8 {
2177            writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
2178        }
2179        writeln!(self.out, "{INDENT}}}")?;
2180        writeln!(self.out, "}}")?;
2181
2182        // __set_el_of_mat
2183        writeln!(
2184            self.out,
2185            "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
2186            columns as u8, columns as u8
2187        )?;
2188        writeln!(self.out, "{INDENT}switch(idx) {{")?;
2189        for i in 0..columns as u8 {
2190            writeln!(
2191                self.out,
2192                "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
2193            )?;
2194        }
2195        writeln!(self.out, "{INDENT}}}")?;
2196        writeln!(self.out, "}}")?;
2197
2198        writeln!(self.out)?;
2199
2200        Ok(())
2201    }
2202
2203    pub(super) fn write_all_mat_cx2_typedefs_and_functions(
2204        &mut self,
2205        module: &crate::Module,
2206    ) -> BackendResult {
2207        for (handle, _) in module.global_variables.iter() {
2208            let global = &module.global_variables[handle];
2209
2210            if global.space == crate::AddressSpace::Uniform {
2211                if let Some(super::writer::MatrixType {
2212                    columns,
2213                    rows: crate::VectorSize::Bi,
2214                    width: 4,
2215                }) = super::writer::get_inner_matrix_data(module, global.ty)
2216                {
2217                    let entry = WrappedMatCx2 { columns };
2218                    if self.wrapped.insert(WrappedType::MatCx2(entry)) {
2219                        self.write_mat_cx2_typedef_and_functions(entry)?;
2220                    }
2221                }
2222            }
2223        }
2224
2225        for (_, ty) in module.types.iter() {
2226            if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
2227                for member in members.iter() {
2228                    if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
2229                        if let Some(super::writer::MatrixType {
2230                            columns,
2231                            rows: crate::VectorSize::Bi,
2232                            width: 4,
2233                        }) = super::writer::get_inner_matrix_data(module, member.ty)
2234                        {
2235                            let entry = WrappedMatCx2 { columns };
2236                            if self.wrapped.insert(WrappedType::MatCx2(entry)) {
2237                                self.write_mat_cx2_typedef_and_functions(entry)?;
2238                            }
2239                        }
2240                    }
2241                }
2242            }
2243        }
2244
2245        Ok(())
2246    }
2247
2248    pub(super) fn write_wrapped_zero_value_function_name(
2249        &mut self,
2250        module: &crate::Module,
2251        zero_value: WrappedZeroValue,
2252    ) -> BackendResult {
2253        let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
2254        write!(self.out, "ZeroValue{name}")?;
2255        Ok(())
2256    }
2257
2258    /// Helper function that write wrapped function for `Expression::ZeroValue`
2259    ///
2260    /// This is necessary since we might have a member access after the zero value expression, e.g.
2261    /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
2262    ///
2263    /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
2264    ///
2265    /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
2266    /// cases like:
2267    ///
2268    /// ```text
2269    /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
2270    ///     t_1.am = (__mat4x2[2])((float4x2[2])0);
2271    ///                                         ^
2272    /// ```
2273    fn write_wrapped_zero_value_function(
2274        &mut self,
2275        module: &crate::Module,
2276        zero_value: WrappedZeroValue,
2277    ) -> BackendResult {
2278        use crate::back::INDENT;
2279
2280        // Write function return type and name
2281        if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner {
2282            write!(self.out, "typedef ")?;
2283            self.write_type(module, zero_value.ty)?;
2284            write!(self.out, " ret_")?;
2285            self.write_wrapped_zero_value_function_name(module, zero_value)?;
2286            self.write_array_size(module, base, size)?;
2287            writeln!(self.out, ";")?;
2288
2289            write!(self.out, "ret_")?;
2290            self.write_wrapped_zero_value_function_name(module, zero_value)?;
2291        } else {
2292            self.write_type(module, zero_value.ty)?;
2293        }
2294        write!(self.out, " ")?;
2295        self.write_wrapped_zero_value_function_name(module, zero_value)?;
2296
2297        // Write function parameters (none) and start function body
2298        writeln!(self.out, "() {{")?;
2299
2300        // Write `ZeroValue` function.
2301        write!(self.out, "{INDENT}return ")?;
2302        self.write_default_init(module, zero_value.ty)?;
2303        writeln!(self.out, ";")?;
2304
2305        // End of function body
2306        writeln!(self.out, "}}")?;
2307        // Write extra new line
2308        writeln!(self.out)?;
2309
2310        Ok(())
2311    }
2312}
2313
2314impl crate::StorageFormat {
2315    /// Returns `true` if there is just one component, otherwise `false`
2316    pub(super) const fn single_component(&self) -> bool {
2317        match *self {
2318            crate::StorageFormat::R16Float
2319            | crate::StorageFormat::R32Float
2320            | crate::StorageFormat::R8Unorm
2321            | crate::StorageFormat::R16Unorm
2322            | crate::StorageFormat::R8Snorm
2323            | crate::StorageFormat::R16Snorm
2324            | crate::StorageFormat::R8Uint
2325            | crate::StorageFormat::R16Uint
2326            | crate::StorageFormat::R32Uint
2327            | crate::StorageFormat::R8Sint
2328            | crate::StorageFormat::R16Sint
2329            | crate::StorageFormat::R32Sint
2330            | crate::StorageFormat::R64Uint => true,
2331            _ => false,
2332        }
2333    }
2334}