// naga/proc/layouter.rs

1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::{
4    arena::{Handle, HandleVec},
5    valid::MAX_TYPE_SIZE,
6};
7
/// A byte alignment, intended to always hold a nonzero power of 2.
///
/// [`Alignment::new`] is the checked way to build one: it rejects zero and
/// every value that is not a power of 2.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    pub const ONE: Self = Self::new(1).unwrap();
    pub const TWO: Self = Self::new(2).unwrap();
    pub const FOUR: Self = Self::new(4).unwrap();
    pub const EIGHT: Self = Self::new(8).unwrap();
    pub const SIXTEEN: Self = Self::new(16).unwrap();

    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Build an alignment of `n`.
    ///
    /// Returns `None` when `n` is zero or not a power of 2.
    pub const fn new(n: u32) -> Option<Self> {
        match NonZeroU32::new(n) {
            Some(value) if n.is_power_of_two() => Some(Self(value)),
            _ => None,
        }
    }

    /// Build an alignment from a width.
    ///
    /// # Panics
    /// If `width` is not a power of 2
    pub const fn from_width(width: u8) -> Self {
        let bytes = width as u32;
        Self::new(bytes).unwrap()
    }

    /// Returns whether or not `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // For a power-of-two alignment, `n % alignment` is exactly the bits of
        // `n` below the alignment bit, so a mask replaces the division.
        let low_bits = self.0.get() - 1;
        (n & low_bits) == 0
    }

    /// Round `n` up to the nearest alignment boundary.
    ///
    /// The intermediate sum `n + (alignment - 1)` uses plain `+`, so it
    /// overflows when `n` is within `alignment - 1` of `u32::MAX`.
    pub const fn round_up(&self, n: u32) -> u32 {
        // Adding `alignment - 1` pushes any unaligned `n` past the next
        // boundary; clearing the low bits then snaps it back onto it.
        // An already aligned `n` comes out unchanged.
        let low_bits = self.0.get() - 1;
        let bumped = n + low_bits;
        bumped & !low_bits
    }
}
54
55impl Display for Alignment {
56    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57        self.0.get().fmt(f)
58    }
59}
60
61impl ops::Mul<u32> for Alignment {
62    type Output = u32;
63
64    fn mul(self, rhs: u32) -> Self::Output {
65        self.0.get() * rhs
66    }
67}
68
69impl ops::Mul for Alignment {
70    type Output = Alignment;
71
72    fn mul(self, rhs: Alignment) -> Self::Output {
73        Self(NonZeroU32::new(self.0.get() * rhs.0.get()).unwrap())
74    }
75}
76
77impl From<crate::VectorSize> for Alignment {
78    fn from(size: crate::VectorSize) -> Self {
79        match size {
80            crate::VectorSize::Bi => Alignment::TWO,
81            crate::VectorSize::Tri => Alignment::FOUR,
82            crate::VectorSize::Quad => Alignment::FOUR,
83        }
84    }
85}
86
87impl From<crate::CooperativeSize> for Alignment {
88    fn from(size: crate::CooperativeSize) -> Self {
89        Self(NonZeroU32::new(size as u32).unwrap())
90    }
91}
92
93/// Size and alignment information for a type.
94#[derive(Clone, Copy, Debug, Hash, PartialEq)]
95#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
96#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
97pub struct TypeLayout {
98    pub size: u32,
99    pub alignment: Alignment,
100}
101
102impl TypeLayout {
103    /// Produce the stride as if this type is a base of an array.
104    pub const fn to_stride(&self) -> u32 {
105        self.alignment.round_up(self.size)
106    }
107}
108
109/// Helper processor that derives the sizes of all types.
110///
111/// `Layouter` uses the default layout algorithm/table, described in
112/// [WGSL §4.3.7, "Memory Layout"]
113///
114/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
115/// layout of the type whose handle is `handle`.
116///
117/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts)
118#[derive(Debug, Default)]
119pub struct Layouter {
120    /// Layouts for types in an arena.
121    layouts: HandleVec<crate::Type, TypeLayout>,
122}
123
124impl ops::Index<Handle<crate::Type>> for Layouter {
125    type Output = TypeLayout;
126    fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
127        &self.layouts[handle]
128    }
129}
130
131/// Errors generated by the `Layouter`.
132///
133/// All of these errors can be produced when validating an arbitrary module.
134/// When processing WGSL source, only the `TooLarge` error should be
135/// produced by the `Layouter`, as the front-end should not produce IR
136/// that would result in the other errors.
137#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
138pub enum LayoutErrorInner {
139    #[error("Array element type {0:?} doesn't exist")]
140    InvalidArrayElementType(Handle<crate::Type>),
141    #[error("Struct member[{0}] type {1:?} doesn't exist")]
142    InvalidStructMemberType(u32, Handle<crate::Type>),
143    #[error("Type width must be a power of two")]
144    NonPowerOfTwoWidth,
145    #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")]
146    TooLarge,
147}
148
149#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
150#[error("Error laying out type {ty:?}: {inner}")]
151pub struct LayoutError {
152    pub ty: Handle<crate::Type>,
153    pub inner: LayoutErrorInner,
154}
155
156impl LayoutErrorInner {
157    const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
158        LayoutError { ty, inner: self }
159    }
160}
161
162impl Layouter {
163    /// Remove all entries from this `Layouter`, retaining storage.
164    pub fn clear(&mut self) {
165        self.layouts.clear();
166    }
167
168    #[expect(rustdoc::private_intra_doc_links)]
169    /// Extend this `Layouter` with layouts for any new entries in `gctx.types`.
170    ///
171    /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout]
172    /// in [`Self::layouts`].
173    ///
174    /// Some front ends need to be able to compute layouts for existing types
175    /// while module construction is still in progress and new types are still
176    /// being added. This function assumes that the `TypeLayout` values already
177    /// present in `self.layouts` cover their corresponding entries in `types`,
178    /// and extends `self.layouts` as needed to cover the rest. Thus, a front
179    /// end can call this function at any time, passing its current type and
180    /// constant arenas, and then assume that layouts are available for all
181    /// types.
182    #[allow(clippy::or_fun_call)]
183    pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
184        use crate::TypeInner as Ti;
185
186        for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
187            let size = ty
188                .inner
189                .try_size(gctx)
190                .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?;
191            let layout = match ty.inner {
192                Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
193                    let alignment = Alignment::new(scalar.width as u32)
194                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
195                    TypeLayout { size, alignment }
196                }
197                Ti::Vector {
198                    size: vec_size,
199                    scalar,
200                } => {
201                    let alignment = Alignment::new(scalar.width as u32)
202                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
203                    TypeLayout {
204                        size,
205                        alignment: Alignment::from(vec_size) * alignment,
206                    }
207                }
208                Ti::Matrix {
209                    columns: _,
210                    rows,
211                    scalar,
212                } => {
213                    let alignment = Alignment::new(scalar.width as u32)
214                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
215                    TypeLayout {
216                        size,
217                        alignment: Alignment::from(rows) * alignment,
218                    }
219                }
220                Ti::CooperativeMatrix {
221                    columns: _,
222                    rows,
223                    scalar,
224                    role: _,
225                } => {
226                    let alignment = Alignment::new(scalar.width as u32)
227                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
228                    TypeLayout {
229                        size,
230                        alignment: Alignment::from(rows) * alignment,
231                    }
232                }
233                Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
234                    size,
235                    alignment: Alignment::ONE,
236                },
237                Ti::Array {
238                    base,
239                    stride: _,
240                    size: _,
241                } => TypeLayout {
242                    size,
243                    alignment: if base < ty_handle {
244                        self[base].alignment
245                    } else {
246                        return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
247                    },
248                },
249                Ti::Struct { span, ref members } => {
250                    let mut alignment = Alignment::ONE;
251                    for (index, member) in members.iter().enumerate() {
252                        alignment = if member.ty < ty_handle {
253                            alignment.max(self[member.ty].alignment)
254                        } else {
255                            return Err(LayoutErrorInner::InvalidStructMemberType(
256                                index as u32,
257                                member.ty,
258                            )
259                            .with(ty_handle));
260                        };
261                    }
262                    TypeLayout {
263                        size: span,
264                        alignment,
265                    }
266                }
267                Ti::Image { .. }
268                | Ti::Sampler { .. }
269                | Ti::AccelerationStructure { .. }
270                | Ti::RayQuery { .. }
271                | Ti::BindingArray { .. } => TypeLayout {
272                    size,
273                    alignment: Alignment::ONE,
274                },
275            };
276            debug_assert!(size <= layout.size);
277            self.layouts.insert(ty_handle, layout);
278        }
279
280        Ok(())
281    }
282}