naga/proc/
layouter.rs

1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::{
4    arena::{Handle, HandleVec},
5    valid::MAX_TYPE_SIZE,
6};
7
/// A newtype around [`NonZeroU32`] whose only valid values are powers of 2.
///
/// It is used as the alignment of a type (see `TypeLayout`). The constructors
/// in this file check the power-of-two invariant, which the bit-masking in
/// [`Alignment::is_aligned`] and [`Alignment::round_up`] relies on.
///
/// NOTE(review): the derived `Deserialize` impl (feature `deserialize`) does
/// not re-check the power-of-two invariant.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    pub const ONE: Self = Self::from_const(1);
    pub const TWO: Self = Self::from_const(2);
    pub const FOUR: Self = Self::from_const(4);
    pub const EIGHT: Self = Self::from_const(8);
    pub const SIXTEEN: Self = Self::from_const(16);

    /// Alias for [`Alignment::SIXTEEN`]; by its name, the minimum alignment
    /// used for uniform-buffer layouts (not enforced in this file).
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Builds an alignment constant from `n`.
    ///
    /// Panics if `n` is not a power of 2. Because this is only used in `const`
    /// initializers, such a panic is a compile-time error.
    const fn from_const(n: u32) -> Self {
        match Self::new(n) {
            Some(alignment) => alignment,
            None => panic!("alignment constant must be a power of two"),
        }
    }

    /// Returns `Some` if `n` is a power of 2, and `None` otherwise (including for 0).
    pub const fn new(n: u32) -> Option<Self> {
        if !n.is_power_of_two() {
            return None;
        }
        // Powers of two are never zero, so this always takes the `Some` arm.
        match NonZeroU32::new(n) {
            Some(nonzero) => Some(Self(nonzero)),
            None => None,
        }
    }

    /// Converts a width (in bytes) into an `Alignment` of the same value.
    ///
    /// # Panics
    /// If `width` is not a power of 2
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).expect("width must be a power of two")
    }

    /// Returns whether or not `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // Since the alignment is a power of two, `n % align == 0` is the same
        // as all bits below the alignment bit being clear.
        n & (self.0.get() - 1) == 0
    }

    /// Round `n` up to the nearest multiple of this alignment.
    ///
    /// # Panics
    /// If the rounded value does not fit in a `u32` and overflow checks are
    /// enabled. With overflow checks disabled, the addition wraps instead.
    pub const fn round_up(&self, n: u32) -> u32 {
        // Equivalent to:
        // match n % align {
        //     0 => n,
        //     rem => n + (align - rem),
        // }
        let mask = self.0.get() - 1;
        (n + mask) & !mask
    }
}

impl Display for Alignment {
    /// Formats the alignment as its plain numeric value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        Display::fmt(&self.0.get(), f)
    }
}

impl ops::Mul<u32> for Alignment {
    type Output = u32;

    /// Scales a plain integer by this alignment's value.
    ///
    /// Overflow follows ordinary `u32` multiplication semantics: it panics
    /// when overflow checks are enabled.
    fn mul(self, rhs: u32) -> Self::Output {
        self.0.get() * rhs
    }
}

impl ops::Mul for Alignment {
    type Output = Alignment;

    /// Multiplies two alignments. The product of two powers of 2 is itself a
    /// power of 2, so the result is a valid `Alignment`.
    ///
    /// # Panics
    /// If the product does not fit in a `u32`.
    fn mul(self, rhs: Alignment) -> Self::Output {
        // Checked so that an overflowing product panics rather than wrapping
        // to zero, which would break the `NonZeroU32` invariant.
        Self(
            self.0
                .checked_mul(rhs.0)
                .expect("Alignment multiplication overflowed u32"),
        )
    }
}
78
79impl From<crate::VectorSize> for Alignment {
80    fn from(size: crate::VectorSize) -> Self {
81        match size {
82            crate::VectorSize::Bi => Alignment::TWO,
83            crate::VectorSize::Tri => Alignment::FOUR,
84            crate::VectorSize::Quad => Alignment::FOUR,
85        }
86    }
87}
88
89/// Size and alignment information for a type.
90#[derive(Clone, Copy, Debug, Hash, PartialEq)]
91#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
92#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
93pub struct TypeLayout {
94    pub size: u32,
95    pub alignment: Alignment,
96}
97
98impl TypeLayout {
99    /// Produce the stride as if this type is a base of an array.
100    pub const fn to_stride(&self) -> u32 {
101        self.alignment.round_up(self.size)
102    }
103}
104
/// Helper processor that derives the sizes and alignments of all types.
///
/// `Layouter` uses the default layout algorithm/table, described in
/// [WGSL §4.3.7, "Memory Layout"].
///
/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
/// layout of the type whose handle is `handle`.
///
/// Layouts are filled in incrementally, in handle order, by [`Layouter::update`].
///
/// [WGSL §4.3.7, "Memory Layout"]: https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts
#[derive(Debug, Default)]
pub struct Layouter {
    /// Layouts for types in an arena.
    ///
    /// Only a prefix of the arena may be covered. `update` skips the first
    /// `layouts.len()` types and appends layouts for the rest, in order.
    layouts: HandleVec<crate::Type, TypeLayout>,
}
119
120impl ops::Index<Handle<crate::Type>> for Layouter {
121    type Output = TypeLayout;
122    fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
123        &self.layouts[handle]
124    }
125}
126
/// Errors generated by the `Layouter`.
///
/// All of these errors can be produced when validating an arbitrary module.
/// When processing WGSL source, only the `TooLarge` error should be
/// produced by the `Layouter`, as the front-end should not produce IR
/// that would result in the other errors.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// An array's element type handle is not less than the array's own handle,
    /// so its layout is not available yet.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// The struct member at the given index refers to a type handle that is
    /// not less than the struct's own handle.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar, atomic, vector or matrix has a component width that is not a
    /// power of two.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
    /// The type's size could not be computed. Per the message, the size
    /// exceeds [`MAX_TYPE_SIZE`].
    #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")]
    TooLarge,
}
144
/// A [`LayoutErrorInner`] paired with the handle of the type that was being
/// laid out when it occurred.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// The type whose layout could not be computed.
    pub ty: Handle<crate::Type>,
    /// The specific failure.
    pub inner: LayoutErrorInner,
}
151
152impl LayoutErrorInner {
153    const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
154        LayoutError { ty, inner: self }
155    }
156}
157
158impl Layouter {
159    /// Remove all entries from this `Layouter`, retaining storage.
160    pub fn clear(&mut self) {
161        self.layouts.clear();
162    }
163
164    /// Extend this `Layouter` with layouts for any new entries in `gctx.types`.
165    ///
166    /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout]
167    /// in [`self.layouts`].
168    ///
169    /// Some front ends need to be able to compute layouts for existing types
170    /// while module construction is still in progress and new types are still
171    /// being added. This function assumes that the `TypeLayout` values already
172    /// present in `self.layouts` cover their corresponding entries in `types`,
173    /// and extends `self.layouts` as needed to cover the rest. Thus, a front
174    /// end can call this function at any time, passing its current type and
175    /// constant arenas, and then assume that layouts are available for all
176    /// types.
177    #[allow(clippy::or_fun_call)]
178    pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
179        use crate::TypeInner as Ti;
180
181        for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
182            let size = ty
183                .inner
184                .try_size(gctx)
185                .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?;
186            let layout = match ty.inner {
187                Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
188                    let alignment = Alignment::new(scalar.width as u32)
189                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
190                    TypeLayout { size, alignment }
191                }
192                Ti::Vector {
193                    size: vec_size,
194                    scalar,
195                } => {
196                    let alignment = Alignment::new(scalar.width as u32)
197                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
198                    TypeLayout {
199                        size,
200                        alignment: Alignment::from(vec_size) * alignment,
201                    }
202                }
203                Ti::Matrix {
204                    columns: _,
205                    rows,
206                    scalar,
207                } => {
208                    let alignment = Alignment::new(scalar.width as u32)
209                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
210                    TypeLayout {
211                        size,
212                        alignment: Alignment::from(rows) * alignment,
213                    }
214                }
215                Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
216                    size,
217                    alignment: Alignment::ONE,
218                },
219                Ti::Array {
220                    base,
221                    stride: _,
222                    size: _,
223                } => TypeLayout {
224                    size,
225                    alignment: if base < ty_handle {
226                        self[base].alignment
227                    } else {
228                        return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
229                    },
230                },
231                Ti::Struct { span, ref members } => {
232                    let mut alignment = Alignment::ONE;
233                    for (index, member) in members.iter().enumerate() {
234                        alignment = if member.ty < ty_handle {
235                            alignment.max(self[member.ty].alignment)
236                        } else {
237                            return Err(LayoutErrorInner::InvalidStructMemberType(
238                                index as u32,
239                                member.ty,
240                            )
241                            .with(ty_handle));
242                        };
243                    }
244                    TypeLayout {
245                        size: span,
246                        alignment,
247                    }
248                }
249                Ti::Image { .. }
250                | Ti::Sampler { .. }
251                | Ti::AccelerationStructure { .. }
252                | Ti::RayQuery { .. }
253                | Ti::BindingArray { .. } => TypeLayout {
254                    size,
255                    alignment: Alignment::ONE,
256                },
257            };
258            debug_assert!(size <= layout.size);
259            self.layouts.insert(ty_handle, layout);
260        }
261
262        Ok(())
263    }
264}