1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::{
4 arena::{Handle, HandleVec},
5 valid::MAX_TYPE_SIZE,
6};
7
/// A non-zero, power-of-two alignment value.
///
/// Values built through [`Alignment::new`] are guaranteed to be powers of
/// two; the bit-mask arithmetic in [`Alignment::is_aligned`] and
/// [`Alignment::round_up`] relies on that.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    /// Alignment of 1.
    pub const ONE: Self = Self::new(1).unwrap();
    /// Alignment of 2.
    pub const TWO: Self = Self::new(2).unwrap();
    /// Alignment of 4.
    pub const FOUR: Self = Self::new(4).unwrap();
    /// Alignment of 8.
    pub const EIGHT: Self = Self::new(8).unwrap();
    /// Alignment of 16.
    pub const SIXTEEN: Self = Self::new(16).unwrap();

    /// Alias for [`Alignment::SIXTEEN`].
    ///
    /// NOTE(review): presumably the minimum alignment for uniform-buffer
    /// data; confirm against the callers.
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Builds an alignment from `n`.
    ///
    /// Returns `None` unless `n` is a power of two (which also rules out 0).
    pub const fn new(n: u32) -> Option<Self> {
        match NonZeroU32::new(n) {
            Some(nonzero) if n.is_power_of_two() => Some(Self(nonzero)),
            _ => None,
        }
    }

    /// Builds an alignment equal to `width`.
    ///
    /// # Panics
    ///
    /// Panics if `width` is not a power of two.
    pub const fn from_width(width: u8) -> Self {
        let widened = width as u32;
        Self::new(widened).unwrap()
    }

    /// Returns `true` when `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // For a power-of-two alignment the low bits of a multiple are all zero.
        let low_bits = self.0.get() - 1;
        (n & low_bits) == 0
    }

    /// Rounds `n` up to the nearest multiple of this alignment.
    ///
    /// The intermediate `n + (alignment - 1)` is a plain `u32` addition, so
    /// it overflows when `n` is within `alignment - 1` of `u32::MAX`.
    pub const fn round_up(&self, n: u32) -> u32 {
        let low_bits = self.0.get() - 1;
        let bumped = n + low_bits;
        bumped & !low_bits
    }
}
54
55impl Display for Alignment {
56 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57 self.0.get().fmt(f)
58 }
59}
60
61impl ops::Mul<u32> for Alignment {
62 type Output = u32;
63
64 fn mul(self, rhs: u32) -> Self::Output {
65 self.0.get() * rhs
66 }
67}
68
69impl ops::Mul for Alignment {
70 type Output = Alignment;
71
72 fn mul(self, rhs: Alignment) -> Self::Output {
73 Self(NonZeroU32::new(self.0.get() * rhs.0.get()).unwrap())
74 }
75}
76
77impl From<crate::VectorSize> for Alignment {
78 fn from(size: crate::VectorSize) -> Self {
79 match size {
80 crate::VectorSize::Bi => Alignment::TWO,
81 crate::VectorSize::Tri => Alignment::FOUR,
82 crate::VectorSize::Quad => Alignment::FOUR,
83 }
84 }
85}
86
87impl From<crate::CooperativeSize> for Alignment {
88 fn from(size: crate::CooperativeSize) -> Self {
89 Self(NonZeroU32::new(size as u32).unwrap())
90 }
91}
92
93#[derive(Clone, Copy, Debug, Hash, PartialEq)]
95#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
96#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
97pub struct TypeLayout {
98 pub size: u32,
99 pub alignment: Alignment,
100}
101
102impl TypeLayout {
103 pub const fn to_stride(&self) -> u32 {
105 self.alignment.round_up(self.size)
106 }
107}
108
/// Holds a [`TypeLayout`] for each type handle it has processed.
///
/// Entries are added by `Layouter::update` and read back by indexing the
/// `Layouter` with a `Handle<crate::Type>`.
#[derive(Debug, Default)]
pub struct Layouter {
    /// Layouts keyed by type handle.
    layouts: HandleVec<crate::Type, TypeLayout>,
}
123
124impl ops::Index<Handle<crate::Type>> for Layouter {
125 type Output = TypeLayout;
126 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
127 &self.layouts[handle]
128 }
129}
130
/// The specific reason a type's layout could not be computed.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// An array's element type handle does not come before the handle of
    /// the array type itself. Carries the offending element handle.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// A struct member's type handle does not come before the handle of the
    /// struct type itself. Carries the member index and the offending handle.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar width was rejected by `Alignment::new` because it is not a
    /// power of two.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
    /// The type's size could not be obtained from `try_size`.
    #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")]
    TooLarge,
}
148
/// A layout failure paired with the type that triggered it.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// Handle of the type whose layout failed.
    pub ty: Handle<crate::Type>,
    /// The underlying reason for the failure.
    pub inner: LayoutErrorInner,
}
155
156impl LayoutErrorInner {
157 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
158 LayoutError { ty, inner: self }
159 }
160}
161
162impl Layouter {
163 pub fn clear(&mut self) {
165 self.layouts.clear();
166 }
167
168 #[expect(rustdoc::private_intra_doc_links)]
169 #[allow(clippy::or_fun_call)]
183 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
184 use crate::TypeInner as Ti;
185
186 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
187 let size = ty
188 .inner
189 .try_size(gctx)
190 .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?;
191 let layout = match ty.inner {
192 Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
193 let alignment = Alignment::new(scalar.width as u32)
194 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
195 TypeLayout { size, alignment }
196 }
197 Ti::Vector {
198 size: vec_size,
199 scalar,
200 } => {
201 let alignment = Alignment::new(scalar.width as u32)
202 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
203 TypeLayout {
204 size,
205 alignment: Alignment::from(vec_size) * alignment,
206 }
207 }
208 Ti::Matrix {
209 columns: _,
210 rows,
211 scalar,
212 } => {
213 let alignment = Alignment::new(scalar.width as u32)
214 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
215 TypeLayout {
216 size,
217 alignment: Alignment::from(rows) * alignment,
218 }
219 }
220 Ti::CooperativeMatrix {
221 columns: _,
222 rows,
223 scalar,
224 role: _,
225 } => {
226 let alignment = Alignment::new(scalar.width as u32)
227 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
228 TypeLayout {
229 size,
230 alignment: Alignment::from(rows) * alignment,
231 }
232 }
233 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
234 size,
235 alignment: Alignment::ONE,
236 },
237 Ti::Array {
238 base,
239 stride: _,
240 size: _,
241 } => TypeLayout {
242 size,
243 alignment: if base < ty_handle {
244 self[base].alignment
245 } else {
246 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
247 },
248 },
249 Ti::Struct { span, ref members } => {
250 let mut alignment = Alignment::ONE;
251 for (index, member) in members.iter().enumerate() {
252 alignment = if member.ty < ty_handle {
253 alignment.max(self[member.ty].alignment)
254 } else {
255 return Err(LayoutErrorInner::InvalidStructMemberType(
256 index as u32,
257 member.ty,
258 )
259 .with(ty_handle));
260 };
261 }
262 TypeLayout {
263 size: span,
264 alignment,
265 }
266 }
267 Ti::Image { .. }
268 | Ti::Sampler { .. }
269 | Ti::AccelerationStructure { .. }
270 | Ti::RayQuery { .. }
271 | Ti::BindingArray { .. } => TypeLayout {
272 size,
273 alignment: Alignment::ONE,
274 },
275 };
276 debug_assert!(size <= layout.size);
277 self.layouts.insert(ty_handle, layout);
278 }
279
280 Ok(())
281 }
282}