naga/back/hlsl/
storage.rs

1/*!
2Generating accesses to [`ByteAddressBuffer`] contents.
3
4Naga IR globals in the [`Storage`] address space are rendered as
5[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These
6buffers don't have HLSL types (structs, arrays, etc.); instead, they
7are just raw blocks of bytes, with methods to load and store values of
8specific types at particular byte offsets. This means that Naga must
9translate chains of [`Access`] and [`AccessIndex`] expressions into
10HLSL expressions that compute byte offsets into the buffer.
11
12To generate code for a [`Storage`] access:
13
14- Call [`Writer::fill_access_chain`] on the expression referring to
15  the value. This populates [`Writer::temp_access_chain`] with the
16  appropriate byte offset calculations, as a vector of [`SubAccess`]
17  values.
18
19- Call [`Writer::write_storage_address`] to emit an HLSL expression
20  for a given slice of [`SubAccess`] values.
21
22Naga IR expressions can operate on composite values of any type, but
23[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed
24set of `Load` and `Store` methods, to access one through four
25consecutive 32-bit values. To synthesize a Naga access, you can
26initialize [`temp_access_chain`] to refer to the composite, and then
27temporarily push and pop additional steps on
28[`Writer::temp_access_chain`] to generate accesses to the individual
29elements/members.
30
31The [`temp_access_chain`] field is a member of [`Writer`] solely to
32allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
33needed once HLSL for the access has been generated.
34
35Note about DXC and Load/Store functions:
36
DXC's HLSL has generic [`Load` and `Store`] functions for [`ByteAddressBuffer`] and
[`RWByteAddressBuffer`]. These are not available in FXC's HLSL, so we use
them only for types that are only available in DXC, notably 64-bit and 16-bit types.
40
FXC's HLSL has the functions `Load`, `Load2`, `Load3`, and `Load4`, and `Store`, `Store2`,
`Store3`, and `Store4`. These load/store a vector of length 1, 2, 3, or 4. We use them for
32-bit types, bitcasting to the correct type if necessary.
44
45[`Storage`]: crate::AddressSpace::Storage
46[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
47[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
48[`Access`]: crate::Expression::Access
49[`AccessIndex`]: crate::Expression::AccessIndex
50[`Writer::fill_access_chain`]: super::Writer::fill_access_chain
51[`Writer::write_storage_address`]: super::Writer::write_storage_address
52[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
53[`temp_access_chain`]: super::Writer::temp_access_chain
54[`Writer`]: super::Writer
55[`Load` and `Store`]: https://github.com/microsoft/DirectXShaderCompiler/wiki/ByteAddressBuffer-Load-Store-Additions
56*/
57
58use alloc::format;
59use core::{fmt, mem};
60
61use super::{super::FunctionCtx, BackendResult, Error};
62use crate::{
63    proc::{Alignment, NameKey, TypeResolution},
64    Handle,
65};
66
/// Name prefix of the HLSL temporaries declared while storing composite
/// values. The nesting depth is appended to it (`_value1`, `_value2`, ...).
const STORE_TEMP_NAME: &str = "_value";
68
/// One step in accessing a [`Storage`] global's component or element.
///
/// [`Writer::temp_access_chain`] holds a series of these structures,
/// describing how to compute the byte offset of a particular element
/// or member of some global variable in the [`Storage`] address
/// space.
///
/// When rendered by `write_storage_address`, the steps are joined with
/// `+`, so each one contributes an additive term to the byte offset.
///
/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain
/// [`Storage`]: crate::AddressSpace::Storage
#[derive(Debug)]
pub(super) enum SubAccess {
    /// Add the dynamic offset of the buffer binding itself.
    ///
    /// This is rendered as `__dynamic_buffer_offsets{group}._{offset}`.
    /// `fill_access_chain` pushes it when the global's resolved binding
    /// target has a `dynamic_storage_buffer_offsets_index`: `group` is
    /// the binding's group and `offset` is that index.
    BufferOffset {
        group: u32,
        offset: u32,
    },

    /// Add the given byte offset. This is used for struct members, or
    /// known components of a vector or matrix. In all those cases,
    /// the byte offset is a compile-time constant.
    Offset(u32),

    /// Scale `value` by `stride`, and add that to the current byte
    /// offset. This is used to compute the offset of an array element
    /// whose index is computed at runtime.
    Index {
        value: Handle<crate::Expression>,
        stride: u32,
    },
}
98
/// The source of a value being stored by `write_storage_store`.
///
/// Composite values are first copied into an HLSL temporary named
/// `_value{depth}`; the `Temp*` variants refer to pieces of such a
/// temporary.
pub(super) enum StoreValue {
    /// A Naga expression, emitted with `write_expr`.
    Expression(Handle<crate::Expression>),
    /// Element `index` of the temporary `_value{depth}`, rendered as
    /// `_value{depth}[{index}]`. `ty` is the type of that element.
    TempIndex {
        depth: usize,
        index: u32,
        ty: TypeResolution,
    },
    /// Member `member_index` of the struct type `base`, read from the
    /// temporary `_value{depth}` and rendered as `_value{depth}.{member}`.
    TempAccess {
        depth: usize,
        base: Handle<crate::Type>,
        member_index: u32,
    },
    /// Access to a single column of a Cx2 matrix within a struct,
    /// rendered as `_value{depth}.{member}_{column}`.
    TempColumnAccess {
        depth: usize,
        base: Handle<crate::Type>,
        member_index: u32,
        column: u32,
    },
}
119
120impl<W: fmt::Write> super::Writer<'_, W> {
121    pub(super) fn write_storage_address(
122        &mut self,
123        module: &crate::Module,
124        chain: &[SubAccess],
125        func_ctx: &FunctionCtx,
126    ) -> BackendResult {
127        if chain.is_empty() {
128            write!(self.out, "0")?;
129        }
130        for (i, access) in chain.iter().enumerate() {
131            if i != 0 {
132                write!(self.out, "+")?;
133            }
134            match *access {
135                SubAccess::BufferOffset { group, offset } => {
136                    write!(self.out, "__dynamic_buffer_offsets{group}._{offset}")?;
137                }
138                SubAccess::Offset(offset) => {
139                    write!(self.out, "{offset}")?;
140                }
141                SubAccess::Index { value, stride } => {
142                    self.write_expr(module, value, func_ctx)?;
143                    write!(self.out, "*{stride}")?;
144                }
145            }
146        }
147        Ok(())
148    }
149
150    fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
151        &mut self,
152        module: &crate::Module,
153        var_handle: Handle<crate::GlobalVariable>,
154        sequence: I,
155        func_ctx: &FunctionCtx,
156    ) -> BackendResult {
157        for (i, (ty_resolution, offset)) in sequence.enumerate() {
158            // add the index temporarily
159            self.temp_access_chain.push(SubAccess::Offset(offset));
160            if i != 0 {
161                write!(self.out, ", ")?;
162            };
163            self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
164            self.temp_access_chain.pop();
165        }
166        Ok(())
167    }
168
169    /// Emit code to access a [`Storage`] global's component.
170    ///
171    /// Emit HLSL to access the component of `var_handle`, a global
172    /// variable in the [`Storage`] address space, whose type is
173    /// `result_ty` and whose location within the global is given by
174    /// [`self.temp_access_chain`]. See the [`storage`] module's
175    /// documentation for background.
176    ///
177    /// [`Storage`]: crate::AddressSpace::Storage
178    /// [`self.temp_access_chain`]: super::Writer::temp_access_chain
179    pub(super) fn write_storage_load(
180        &mut self,
181        module: &crate::Module,
182        var_handle: Handle<crate::GlobalVariable>,
183        result_ty: TypeResolution,
184        func_ctx: &FunctionCtx,
185    ) -> BackendResult {
186        match *result_ty.inner_with(&module.types) {
187            crate::TypeInner::Scalar(scalar) => {
188                // working around the borrow checker in `self.write_expr`
189                let chain = mem::take(&mut self.temp_access_chain);
190                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
191                // See note about DXC and Load/Store in the module's documentation.
192                if scalar.width == 4 {
193                    let cast = scalar.kind.to_hlsl_cast();
194                    write!(self.out, "{cast}({var_name}.Load(")?;
195                } else {
196                    let ty = scalar.to_hlsl_str()?;
197                    write!(self.out, "{var_name}.Load<{ty}>(")?;
198                };
199                self.write_storage_address(module, &chain, func_ctx)?;
200                write!(self.out, ")")?;
201                if scalar.width == 4 {
202                    write!(self.out, ")")?;
203                }
204                self.temp_access_chain = chain;
205            }
206            crate::TypeInner::Vector { size, scalar } => {
207                // working around the borrow checker in `self.write_expr`
208                let chain = mem::take(&mut self.temp_access_chain);
209                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
210                let size = size as u8;
211                // See note about DXC and Load/Store in the module's documentation.
212                if scalar.width == 4 {
213                    let cast = scalar.kind.to_hlsl_cast();
214                    write!(self.out, "{cast}({var_name}.Load{size}(")?;
215                } else {
216                    let ty = scalar.to_hlsl_str()?;
217                    write!(self.out, "{var_name}.Load<{ty}{size}>(")?;
218                };
219                self.write_storage_address(module, &chain, func_ctx)?;
220                write!(self.out, ")")?;
221                if scalar.width == 4 {
222                    write!(self.out, ")")?;
223                }
224                self.temp_access_chain = chain;
225            }
226            crate::TypeInner::Matrix {
227                columns,
228                rows,
229                scalar,
230            } => {
231                write!(
232                    self.out,
233                    "{}{}x{}(",
234                    scalar.to_hlsl_str()?,
235                    columns as u8,
236                    rows as u8,
237                )?;
238
239                // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
240                let row_stride = Alignment::from(rows) * scalar.width as u32;
241                let iter = (0..columns as u32).map(|i| {
242                    let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
243                    (TypeResolution::Value(ty_inner), i * row_stride)
244                });
245                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
246                write!(self.out, ")")?;
247            }
248            crate::TypeInner::Array {
249                base,
250                size: crate::ArraySize::Constant(size),
251                stride,
252            } => {
253                let constructor = super::help::WrappedConstructor {
254                    ty: result_ty.handle().unwrap(),
255                };
256                self.write_wrapped_constructor_function_name(module, constructor)?;
257                write!(self.out, "(")?;
258                let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
259                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
260                write!(self.out, ")")?;
261            }
262            crate::TypeInner::Struct { ref members, .. } => {
263                let constructor = super::help::WrappedConstructor {
264                    ty: result_ty.handle().unwrap(),
265                };
266                self.write_wrapped_constructor_function_name(module, constructor)?;
267                write!(self.out, "(")?;
268                let iter = members
269                    .iter()
270                    .map(|m| (TypeResolution::Handle(m.ty), m.offset));
271                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
272                write!(self.out, ")")?;
273            }
274            _ => unreachable!(),
275        }
276        Ok(())
277    }
278
279    fn write_store_value(
280        &mut self,
281        module: &crate::Module,
282        value: &StoreValue,
283        func_ctx: &FunctionCtx,
284    ) -> BackendResult {
285        match *value {
286            StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
287            StoreValue::TempIndex {
288                depth,
289                index,
290                ty: _,
291            } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
292            StoreValue::TempAccess {
293                depth,
294                base,
295                member_index,
296            } => {
297                let name = &self.names[&NameKey::StructMember(base, member_index)];
298                write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
299            }
300            StoreValue::TempColumnAccess {
301                depth,
302                base,
303                member_index,
304                column,
305            } => {
306                let name = &self.names[&NameKey::StructMember(base, member_index)];
307                write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")?
308            }
309        }
310        Ok(())
311    }
312
    /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
    ///
    /// Emits one or more HLSL statements that store `value` into the global
    /// variable `var_handle`, at the byte offset currently described by
    /// `self.temp_access_chain`.
    ///
    /// Scalars and vectors become a single `Store` call. Matrices,
    /// constant-sized arrays and structs are written as a braced block that
    /// (usually) copies the value into a temporary named `_value{depth}` and
    /// then recurses to store each column, element or member, temporarily
    /// pushing the piece's offset onto `self.temp_access_chain`.
    ///
    /// - `level` is written as the prefix of every emitted line (presumably
    ///   the indentation); nested statements use `level.next()`.
    /// - `within_struct` is `Some(struct_ty)` when `value` is a member of a
    ///   struct that already lives in a temporary. It only matters for
    ///   matrices with two rows, whose columns are then read from the
    ///   struct temporary instead of from a new matrix temporary.
    ///
    /// # Panics
    ///
    /// Panics if `value` is a `StoreValue::TempColumnAccess`, if the value's
    /// type is not a scalar, vector, matrix, constant-sized array or struct,
    /// or if a two-row matrix is stored with `within_struct` set but `value`
    /// is not a `StoreValue::TempAccess`.
    pub(super) fn write_storage_store(
        &mut self,
        module: &crate::Module,
        var_handle: Handle<crate::GlobalVariable>,
        value: StoreValue,
        func_ctx: &FunctionCtx,
        level: crate::back::Level,
        within_struct: Option<Handle<crate::Type>>,
    ) -> BackendResult {
        // Owns the resolution built for `TempAccess`, so that `ty_resolution`
        // can be a reference in every case.
        let temp_resolution;
        // Determine the type of the value being stored.
        let ty_resolution = match value {
            StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
            StoreValue::TempIndex {
                depth: _,
                index: _,
                ref ty,
            } => ty,
            StoreValue::TempAccess {
                depth: _,
                base,
                member_index,
            } => {
                // The type is that of the selected struct member.
                let ty_handle = match module.types[base].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        members[member_index as usize].ty
                    }
                    _ => unreachable!(),
                };
                temp_resolution = TypeResolution::Handle(ty_handle);
                &temp_resolution
            }
            StoreValue::TempColumnAccess { .. } => {
                unreachable!("attempting write_storage_store for TempColumnAccess");
            }
        };
        match *ty_resolution.inner_with(&module.types) {
            crate::TypeInner::Scalar(scalar) => {
                // working around the borrow checker in `self.write_expr`
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                // See note about DXC and Load/Store in the module's documentation.
                if scalar.width == 4 {
                    // 32-bit: store the value's bits via `asuint`.
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", asuint(")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, "));")?;
                } else {
                    // Other widths: pass the value to `Store` unchanged.
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, ");")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Vector { size, scalar } => {
                // working around the borrow checker in `self.write_expr`
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                // See note about DXC and Load/Store in the module's documentation.
                if scalar.width == 4 {
                    // 32-bit: `Store2`/`Store3`/`Store4` according to the vector size.
                    write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", asuint(")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, "));")?;
                } else {
                    // Other widths: pass the value to `Store` unchanged.
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, ");")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => {
                // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
                let row_stride = Alignment::from(rows) * scalar.width as u32;

                // Open the block that holds the per-column stores.
                writeln!(self.out, "{level}{{")?;

                match within_struct {
                    Some(containing_struct) if rows == crate::VectorSize::Bi => {
                        // If we are within a struct, then the struct was already assigned to
                        // a temporary, we don't need to make another.
                        let mut chain = mem::take(&mut self.temp_access_chain);
                        for i in 0..columns as u32 {
                            // Point the chain at column `i` for this store.
                            chain.push(SubAccess::Offset(i * row_stride));
                            // working around the borrow checker in `self.write_expr`
                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                            let StoreValue::TempAccess { member_index, .. } = value else {
                                unreachable!(
                                    "write_storage_store within_struct but not TempAccess"
                                );
                            };
                            // Read the column from the struct temporary.
                            let column_value = StoreValue::TempColumnAccess {
                                depth: level.0, // note not incrementing, b/c no temp
                                base: containing_struct,
                                member_index,
                                column: i,
                            };
                            // See note about DXC and Load/Store in the module's documentation.
                            if scalar.width == 4 {
                                write!(
                                    self.out,
                                    "{}{}.Store{}(",
                                    level.next(),
                                    var_name,
                                    rows as u8
                                )?;
                                self.write_storage_address(module, &chain, func_ctx)?;
                                write!(self.out, ", asuint(")?;
                                self.write_store_value(module, &column_value, func_ctx)?;
                                writeln!(self.out, "));")?;
                            } else {
                                write!(self.out, "{}{var_name}.Store(", level.next())?;
                                self.write_storage_address(module, &chain, func_ctx)?;
                                write!(self.out, ", ")?;
                                self.write_store_value(module, &column_value, func_ctx)?;
                                writeln!(self.out, ");")?;
                            }
                            chain.pop();
                        }
                        self.temp_access_chain = chain;
                    }
                    _ => {
                        // first, assign the value to a temporary
                        let depth = level.0 + 1;
                        write!(
                            self.out,
                            "{}{}{}x{} {}{} = ",
                            level.next(),
                            scalar.to_hlsl_str()?,
                            columns as u8,
                            rows as u8,
                            STORE_TEMP_NAME,
                            depth,
                        )?;
                        self.write_store_value(module, &value, func_ctx)?;
                        writeln!(self.out, ";")?;

                        // then iterate the stores
                        for i in 0..columns as u32 {
                            self.temp_access_chain
                                .push(SubAccess::Offset(i * row_stride));
                            // Each column is stored as a vector of `rows` scalars.
                            let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
                            let sv = StoreValue::TempIndex {
                                depth,
                                index: i,
                                ty: TypeResolution::Value(ty_inner),
                            };
                            self.write_storage_store(
                                module,
                                var_handle,
                                sv,
                                func_ctx,
                                level.next(),
                                None,
                            )?;
                            self.temp_access_chain.pop();
                        }
                    }
                }
                // done
                writeln!(self.out, "{level}}}")?;
            }
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                stride,
            } => {
                // first, assign the value to a temporary
                writeln!(self.out, "{level}{{")?;
                write!(self.out, "{}", level.next())?;
                self.write_type(module, base)?;
                let depth = level.next().0;
                write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
                self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
                write!(self.out, " = ")?;
                self.write_store_value(module, &value, func_ctx)?;
                writeln!(self.out, ";")?;
                // then iterate the stores
                for i in 0..size.get() {
                    // Element `i` lives `i * stride` bytes into the array.
                    self.temp_access_chain.push(SubAccess::Offset(i * stride));
                    let sv = StoreValue::TempIndex {
                        depth,
                        index: i,
                        ty: TypeResolution::Handle(base),
                    };
                    self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?;
                    self.temp_access_chain.pop();
                }
                // done
                writeln!(self.out, "{level}}}")?;
            }
            crate::TypeInner::Struct { ref members, .. } => {
                // first, assign the value to a temporary
                writeln!(self.out, "{level}{{")?;
                let depth = level.next().0;
                let struct_ty = ty_resolution.handle().unwrap();
                let struct_name = &self.names[&NameKey::Type(struct_ty)];
                write!(
                    self.out,
                    "{}{} {}{} = ",
                    level.next(),
                    struct_name,
                    STORE_TEMP_NAME,
                    depth
                )?;
                self.write_store_value(module, &value, func_ctx)?;
                writeln!(self.out, ";")?;
                // then iterate the stores
                for (i, member) in members.iter().enumerate() {
                    self.temp_access_chain
                        .push(SubAccess::Offset(member.offset));
                    let sv = StoreValue::TempAccess {
                        depth,
                        base: struct_ty,
                        member_index: i as u32,
                    };
                    // Pass the struct type so that two-row matrix members can
                    // be stored column by column from this temporary.
                    self.write_storage_store(
                        module,
                        var_handle,
                        sv,
                        func_ctx,
                        level.next(),
                        Some(struct_ty),
                    )?;
                    self.temp_access_chain.pop();
                }
                // done
                writeln!(self.out, "{level}}}")?;
            }
            _ => unreachable!(),
        }
        Ok(())
    }
556
557    /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`.
558    ///
559    /// The `cur_expr` expression must be a reference to a global
560    /// variable in the [`Storage`] address space, or a chain of
561    /// [`Access`] and [`AccessIndex`] expressions referring to some
562    /// component of such a global.
563    ///
564    /// [`temp_access_chain`]: super::Writer::temp_access_chain
565    /// [`Storage`]: crate::AddressSpace::Storage
566    /// [`Access`]: crate::Expression::Access
567    /// [`AccessIndex`]: crate::Expression::AccessIndex
568    pub(super) fn fill_access_chain(
569        &mut self,
570        module: &crate::Module,
571        mut cur_expr: Handle<crate::Expression>,
572        func_ctx: &FunctionCtx,
573    ) -> Result<Handle<crate::GlobalVariable>, Error> {
574        enum AccessIndex {
575            Expression(Handle<crate::Expression>),
576            Constant(u32),
577        }
578        enum Parent<'a> {
579            Array { stride: u32 },
580            Struct(&'a [crate::StructMember]),
581        }
582        self.temp_access_chain.clear();
583
584        loop {
585            let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
586                crate::Expression::GlobalVariable(handle) => {
587                    if let Some(ref binding) = module.global_variables[handle].binding {
588                        // this was already resolved earlier when we started evaluating an entry point.
589                        let bt = self.options.resolve_resource_binding(binding).unwrap();
590                        if let Some(dynamic_storage_buffer_offsets_index) =
591                            bt.dynamic_storage_buffer_offsets_index
592                        {
593                            self.temp_access_chain.push(SubAccess::BufferOffset {
594                                group: binding.group,
595                                offset: dynamic_storage_buffer_offsets_index,
596                            });
597                        }
598                    }
599                    return Ok(handle);
600                }
601                crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
602                crate::Expression::AccessIndex { base, index } => {
603                    (base, AccessIndex::Constant(index))
604                }
605                ref other => {
606                    return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
607                }
608            };
609
610            let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
611                crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
612                    crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
613                    crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
614                    crate::TypeInner::Vector { scalar, .. } => Parent::Array {
615                        stride: scalar.width as u32,
616                    },
617                    crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
618                        // The stride between matrices is the count of rows as this is how
619                        // long each column is.
620                        stride: Alignment::from(rows) * scalar.width as u32,
621                    },
622                    _ => unreachable!(),
623                },
624                crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
625                    stride: scalar.width as u32,
626                },
627                _ => unreachable!(),
628            };
629
630            let sub = match (parent, access_index) {
631                (Parent::Array { stride }, AccessIndex::Expression(value)) => {
632                    SubAccess::Index { value, stride }
633                }
634                (Parent::Array { stride }, AccessIndex::Constant(index)) => {
635                    SubAccess::Offset(stride * index)
636                }
637                (Parent::Struct(members), AccessIndex::Constant(index)) => {
638                    SubAccess::Offset(members[index as usize].offset)
639                }
640                (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
641            };
642
643            self.temp_access_chain.push(sub);
644            cur_expr = next_expr;
645        }
646    }
647}