1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::{
4 arena::{Handle, HandleVec},
5 valid::MAX_TYPE_SIZE,
6};
7
/// A power-of-two alignment, measured in bytes.
///
/// The value is stored as a `NonZeroU32`, so it is never zero. The
/// constructors in this module are meant to produce only powers of two, and
/// the bit-mask arithmetic in `round_up` and `is_aligned` depends on that.
///
/// NOTE(review): the `deserialize` derive accepts any nonzero `u32`. A
/// deserialized value is therefore not re-checked for being a power of two.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);
13
14impl Alignment {
15 pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
16 pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
17 pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
18 pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
19 pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
20
21 pub const MIN_UNIFORM: Self = Self::SIXTEEN;
22
23 pub const fn new(n: u32) -> Option<Self> {
24 if n.is_power_of_two() {
25 Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
27 } else {
28 None
29 }
30 }
31
32 pub fn from_width(width: u8) -> Self {
35 Self::new(width as u32).unwrap()
36 }
37
38 pub const fn is_aligned(&self, n: u32) -> bool {
40 n & (self.0.get() - 1) == 0
42 }
43
44 pub const fn round_up(&self, n: u32) -> u32 {
46 let mask = self.0.get() - 1;
52 (n + mask) & !mask
53 }
54}
55
56impl Display for Alignment {
57 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58 self.0.get().fmt(f)
59 }
60}
61
62impl ops::Mul<u32> for Alignment {
63 type Output = u32;
64
65 fn mul(self, rhs: u32) -> Self::Output {
66 self.0.get() * rhs
67 }
68}
69
70impl ops::Mul for Alignment {
71 type Output = Alignment;
72
73 fn mul(self, rhs: Alignment) -> Self::Output {
74 Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
76 }
77}
78
79impl From<crate::VectorSize> for Alignment {
80 fn from(size: crate::VectorSize) -> Self {
81 match size {
82 crate::VectorSize::Bi => Alignment::TWO,
83 crate::VectorSize::Tri => Alignment::FOUR,
84 crate::VectorSize::Quad => Alignment::FOUR,
85 }
86 }
87}
88
impl From<crate::CooperativeSize> for Alignment {
    /// Uses the cooperative-matrix dimension's numeric value as an alignment
    /// multiplier.
    fn from(size: crate::CooperativeSize) -> Self {
        // SAFETY: relies on every `CooperativeSize` discriminant being a
        // nonzero power of two.
        // NOTE(review): `CooperativeSize` is defined outside this file; confirm
        // its discriminants uphold this before adding new variants.
        Self(unsafe { NonZeroU32::new_unchecked(size as u32) })
    }
}
94
/// Size and alignment of a single type, as computed by [`Layouter`].
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct TypeLayout {
    /// Size of the type in bytes. This is not necessarily a multiple of
    /// `alignment`; see `to_stride` for that.
    pub size: u32,
    /// Required alignment of the type.
    pub alignment: Alignment,
}
103
104impl TypeLayout {
105 pub const fn to_stride(&self) -> u32 {
107 self.alignment.round_up(self.size)
108 }
109}
110
/// Computes and caches a [`TypeLayout`] for each type in a module's type arena.
///
/// Index a `Layouter` with a type handle to get that type's layout, once
/// `update` has processed it.
#[derive(Debug, Default)]
pub struct Layouter {
    /// One layout per type handle. `update` assumes these cover the leading
    /// entries of the type arena, in order.
    layouts: HandleVec<crate::Type, TypeLayout>,
}
125
126impl ops::Index<Handle<crate::Type>> for Layouter {
127 type Output = TypeLayout;
128 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
129 &self.layouts[handle]
130 }
131}
132
/// The specific reason a type could not be laid out.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// Raised when an array's element type does not precede the array in the
    /// type arena, so no layout for it is available yet.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// Raised when the struct member at index `.0` has type `.1`, and that type
    /// does not precede the struct in the type arena.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar's width is not a power of two, so it has no valid alignment.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
    /// The type's size could not be computed. The message attributes this to
    /// exceeding `MAX_TYPE_SIZE`.
    #[error("Size exceeds limit of {MAX_TYPE_SIZE} bytes")]
    TooLarge,
}
150
/// A layout failure, tagged with the handle of the type being laid out.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// The type whose layout could not be computed.
    pub ty: Handle<crate::Type>,
    /// Why the layout failed.
    pub inner: LayoutErrorInner,
}
157
158impl LayoutErrorInner {
159 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
160 LayoutError { ty, inner: self }
161 }
162}
163
164impl Layouter {
165 pub fn clear(&mut self) {
167 self.layouts.clear();
168 }
169
170 #[expect(rustdoc::private_intra_doc_links)]
171 #[allow(clippy::or_fun_call)]
185 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
186 use crate::TypeInner as Ti;
187
188 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
189 let size = ty
190 .inner
191 .try_size(gctx)
192 .ok_or_else(|| LayoutErrorInner::TooLarge.with(ty_handle))?;
193 let layout = match ty.inner {
194 Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
195 let alignment = Alignment::new(scalar.width as u32)
196 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
197 TypeLayout { size, alignment }
198 }
199 Ti::Vector {
200 size: vec_size,
201 scalar,
202 } => {
203 let alignment = Alignment::new(scalar.width as u32)
204 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
205 TypeLayout {
206 size,
207 alignment: Alignment::from(vec_size) * alignment,
208 }
209 }
210 Ti::Matrix {
211 columns: _,
212 rows,
213 scalar,
214 } => {
215 let alignment = Alignment::new(scalar.width as u32)
216 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
217 TypeLayout {
218 size,
219 alignment: Alignment::from(rows) * alignment,
220 }
221 }
222 Ti::CooperativeMatrix {
223 columns: _,
224 rows,
225 scalar,
226 role: _,
227 } => {
228 let alignment = Alignment::new(scalar.width as u32)
229 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
230 TypeLayout {
231 size,
232 alignment: Alignment::from(rows) * alignment,
233 }
234 }
235 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
236 size,
237 alignment: Alignment::ONE,
238 },
239 Ti::Array {
240 base,
241 stride: _,
242 size: _,
243 } => TypeLayout {
244 size,
245 alignment: if base < ty_handle {
246 self[base].alignment
247 } else {
248 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
249 },
250 },
251 Ti::Struct { span, ref members } => {
252 let mut alignment = Alignment::ONE;
253 for (index, member) in members.iter().enumerate() {
254 alignment = if member.ty < ty_handle {
255 alignment.max(self[member.ty].alignment)
256 } else {
257 return Err(LayoutErrorInner::InvalidStructMemberType(
258 index as u32,
259 member.ty,
260 )
261 .with(ty_handle));
262 };
263 }
264 TypeLayout {
265 size: span,
266 alignment,
267 }
268 }
269 Ti::Image { .. }
270 | Ti::Sampler { .. }
271 | Ti::AccelerationStructure { .. }
272 | Ti::RayQuery { .. }
273 | Ti::BindingArray { .. } => TypeLayout {
274 size,
275 alignment: Alignment::ONE,
276 },
277 };
278 debug_assert!(size <= layout.size);
279 self.layouts.insert(ty_handle, layout);
280 }
281
282 Ok(())
283 }
284}