naga/proc/
layouter.rs
1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::{
4 arena::{Handle, HandleVec},
5 valid::MAX_TYPE_SIZE,
6};
7
/// An alignment value whose only valid values are non-zero powers of two.
///
/// Instances come from [`Alignment::new`] (which rejects anything that is not a
/// power of two), the associated constants, or multiplying two alignments.
///
/// NOTE(review): with the `deserialize` feature, the derived `Deserialize` only
/// rejects zero (via `NonZeroU32`) and does not re-check the power-of-two
/// invariant that `is_aligned`/`round_up` rely on — confirm inputs are trusted.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);
13
14impl Alignment {
15 pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
16 pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
17 pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
18 pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
19 pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
20
21 pub const MIN_UNIFORM: Self = Self::SIXTEEN;
22
23 pub const fn new(n: u32) -> Option<Self> {
24 if n.is_power_of_two() {
25 Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
27 } else {
28 None
29 }
30 }
31
32 pub fn from_width(width: u8) -> Self {
35 Self::new(width as u32).unwrap()
36 }
37
38 pub const fn is_aligned(&self, n: u32) -> bool {
40 n & (self.0.get() - 1) == 0
42 }
43
44 pub const fn round_up(&self, n: u32) -> u32 {
46 let mask = self.0.get() - 1;
52 (n + mask) & !mask
53 }
54}
55
56impl Display for Alignment {
57 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58 self.0.get().fmt(f)
59 }
60}
61
62impl ops::Mul<u32> for Alignment {
63 type Output = u32;
64
65 fn mul(self, rhs: u32) -> Self::Output {
66 self.0.get() * rhs
67 }
68}
69
70impl ops::Mul for Alignment {
71 type Output = Alignment;
72
73 fn mul(self, rhs: Alignment) -> Self::Output {
74 Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
76 }
77}
78
79impl From<crate::VectorSize> for Alignment {
80 fn from(size: crate::VectorSize) -> Self {
81 match size {
82 crate::VectorSize::Bi => Alignment::TWO,
83 crate::VectorSize::Tri => Alignment::FOUR,
84 crate::VectorSize::Quad => Alignment::FOUR,
85 }
86 }
87}
88
89#[derive(Clone, Copy, Debug, Hash, PartialEq)]
91#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
92#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
93pub struct TypeLayout {
94 pub size: u32,
95 pub alignment: Alignment,
96}
97
98impl TypeLayout {
99 pub const fn to_stride(&self) -> u32 {
101 self.alignment.round_up(self.size)
102 }
103}
104
105#[derive(Debug, Default)]
115pub struct Layouter {
116 layouts: HandleVec<crate::Type, TypeLayout>,
118}
119
120impl ops::Index<Handle<crate::Type>> for Layouter {
121 type Output = TypeLayout;
122 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
123 &self.layouts[handle]
124 }
125}
126
127#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
134pub enum LayoutErrorInner {
135 #[error("Array element type {0:?} doesn't exist")]
136 InvalidArrayElementType(Handle<crate::Type>),
137 #[error("Struct member[{0}] type {1:?} doesn't exist")]
138 InvalidStructMemberType(u32, Handle<crate::Type>),
139 #[error("Type width must be a power of two")]
140 NonPowerOfTwoWidth,
141 #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")]
142 TooLarge,
143}
144
145#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
146#[error("Error laying out type {ty:?}: {inner}")]
147pub struct LayoutError {
148 pub ty: Handle<crate::Type>,
149 pub inner: LayoutErrorInner,
150}
151
152impl LayoutErrorInner {
153 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
154 LayoutError { ty, inner: self }
155 }
156}
157
158impl Layouter {
159 pub fn clear(&mut self) {
161 self.layouts.clear();
162 }
163
164 #[allow(clippy::or_fun_call)]
178 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
179 use crate::TypeInner as Ti;
180
181 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
182 let size = ty
183 .inner
184 .try_size(gctx)
185 .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?;
186 let layout = match ty.inner {
187 Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
188 let alignment = Alignment::new(scalar.width as u32)
189 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
190 TypeLayout { size, alignment }
191 }
192 Ti::Vector {
193 size: vec_size,
194 scalar,
195 } => {
196 let alignment = Alignment::new(scalar.width as u32)
197 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
198 TypeLayout {
199 size,
200 alignment: Alignment::from(vec_size) * alignment,
201 }
202 }
203 Ti::Matrix {
204 columns: _,
205 rows,
206 scalar,
207 } => {
208 let alignment = Alignment::new(scalar.width as u32)
209 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
210 TypeLayout {
211 size,
212 alignment: Alignment::from(rows) * alignment,
213 }
214 }
215 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
216 size,
217 alignment: Alignment::ONE,
218 },
219 Ti::Array {
220 base,
221 stride: _,
222 size: _,
223 } => TypeLayout {
224 size,
225 alignment: if base < ty_handle {
226 self[base].alignment
227 } else {
228 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
229 },
230 },
231 Ti::Struct { span, ref members } => {
232 let mut alignment = Alignment::ONE;
233 for (index, member) in members.iter().enumerate() {
234 alignment = if member.ty < ty_handle {
235 alignment.max(self[member.ty].alignment)
236 } else {
237 return Err(LayoutErrorInner::InvalidStructMemberType(
238 index as u32,
239 member.ty,
240 )
241 .with(ty_handle));
242 };
243 }
244 TypeLayout {
245 size: span,
246 alignment,
247 }
248 }
249 Ti::Image { .. }
250 | Ti::Sampler { .. }
251 | Ti::AccelerationStructure { .. }
252 | Ti::RayQuery { .. }
253 | Ti::BindingArray { .. } => TypeLayout {
254 size,
255 alignment: Alignment::ONE,
256 },
257 };
258 debug_assert!(size <= layout.size);
259 self.layouts.insert(ty_handle, layout);
260 }
261
262 Ok(())
263 }
264}