1use core::{fmt, ops};
2
/// A set of 4-byte slots in an immediate-data block, stored as a 64-bit mask.
///
/// Bit `i` stands for bytes `4 * i .. 4 * i + 4`, so the set can describe the
/// first 256 bytes of immediate data. The default value is the empty set.
#[derive(Clone, Copy, Debug, Default)]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ImmediateSlots(
    /// Slot mask: bit `i` set means bytes `4 * i .. 4 * i + 4` are in the set.
    u64,
);
9
10impl ImmediateSlots {
11 pub const fn from_raw(raw: u64) -> Self {
12 Self(raw)
13 }
14
15 pub const fn from_range(offset: u32, size_bytes: u32) -> Self {
17 if size_bytes == 0 {
18 return Self(0);
19 }
20 let lo = offset / 4;
21 let hi = (offset + size_bytes).div_ceil(4);
22 Self(u64::MAX << lo & u64::MAX >> (64 - hi))
23 }
24
25 pub fn from_type(
28 ty: &crate::TypeInner,
29 offset: u32,
30 types: &crate::UniqueArena<crate::Type>,
31 gctx: crate::proc::GlobalCtx,
32 ) -> Self {
33 match *ty {
34 crate::TypeInner::Struct { ref members, .. } => {
35 let mut slots = Self::default();
36 for member in members {
37 let member_ty = &types[member.ty].inner;
38 slots |= Self::from_type(member_ty, offset + member.offset, types, gctx);
39 }
40 slots
41 }
42 _ => Self::from_range(offset, ty.size(gctx)),
43 }
44 }
45
46 pub const fn contains(self, other: Self) -> bool {
48 other.0 & !self.0 == 0
49 }
50
51 pub const fn difference(self, other: Self) -> Self {
53 Self(self.0 & !other.0)
54 }
55
56 pub fn size_for_module(module: &crate::Module) -> u32 {
59 module
60 .global_variables
61 .iter()
62 .find(|&(_, var)| var.space == crate::AddressSpace::Immediate)
63 .map(|(_, var)| module.types[var.ty].inner.size(module.to_ctx()))
64 .unwrap_or(0)
65 }
66
67 pub(crate) fn for_pointer(
73 pointer: crate::arena::Handle<crate::Expression>,
74 global: crate::arena::Handle<crate::GlobalVariable>,
75 expression_arena: &crate::Arena<crate::Expression>,
76 global_vars: &crate::Arena<crate::GlobalVariable>,
77 types: &crate::UniqueArena<crate::Type>,
78 ) -> Self {
79 use crate::Expression as E;
80 use crate::TypeInner;
81
82 let gctx = crate::proc::GlobalCtx {
83 types,
84 constants: &crate::Arena::new(),
85 overrides: &crate::Arena::new(),
86 global_expressions: &crate::Arena::new(),
87 };
88
89 let global_ty = &types[global_vars[global].ty].inner;
90
91 match expression_arena[pointer] {
92 E::GlobalVariable(_) => Self::from_type(global_ty, 0, types, gctx),
93 E::AccessIndex { base, index } => {
94 if let E::GlobalVariable(_) = expression_arena[base] {
95 if let TypeInner::Struct { ref members, .. } = *global_ty {
96 let member = &members[index as usize];
97 let member_ty = &types[member.ty].inner;
98 return Self::from_type(member_ty, member.offset, types, gctx);
99 }
100 }
101 Self::from_type(global_ty, 0, types, gctx)
102 }
103 _ => Self::from_type(global_ty, 0, types, gctx),
104 }
105 }
106}
107
108impl ops::BitOrAssign for ImmediateSlots {
109 fn bitor_assign(&mut self, rhs: Self) {
110 self.0 |= rhs.0;
111 }
112}
113
114impl ops::BitOr for ImmediateSlots {
115 type Output = Self;
116 fn bitor(self, rhs: Self) -> Self {
117 Self(self.0 | rhs.0)
118 }
119}
120
121impl fmt::Display for ImmediateSlots {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 if self.0 == 0 {
124 return write!(f, "(none)");
125 }
126 let mut first = true;
127 let mut bit = 0u32;
128 while bit < 64 {
129 if self.0 & (1u64 << bit) != 0 {
130 let start = bit * 4;
131 while bit < 64 && self.0 & (1u64 << bit) != 0 {
132 bit += 1;
133 }
134 let end = bit * 4;
135 if !first {
136 write!(f, ", ")?;
137 }
138 write!(f, "{start}..{end}")?;
139 first = false;
140 } else {
141 bit += 1;
142 }
143 }
144 Ok(())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::ImmediateSlots;

    /// Asserts that `slots` formats (via `Display`) exactly as `expected`.
    ///
    /// Compares piece by piece through a `core::fmt::Write` sink instead of
    /// building a `String`, so the test needs no allocator imports.
    fn assert_display(slots: ImmediateSlots, expected: &str) {
        use core::fmt::{self, Write};

        /// Consumes the expected text as pieces arrive; errors on the first mismatch.
        struct Expect<'a>(&'a str);

        impl Write for Expect<'_> {
            fn write_str(&mut self, s: &str) -> fmt::Result {
                self.0 = self.0.strip_prefix(s).ok_or(fmt::Error)?;
                Ok(())
            }
        }

        let mut sink = Expect(expected);
        let matched = write!(sink, "{slots}").is_ok() && sink.0.is_empty();
        assert!(matched, "`{slots}` does not format as `{expected}`");
    }

    #[test]
    fn range_single() {
        assert_eq!(
            ImmediateSlots::from_range(0, 4),
            ImmediateSlots::from_raw(0b1)
        );
        assert_eq!(
            ImmediateSlots::from_range(4, 4),
            ImmediateSlots::from_raw(0b10)
        );
        assert_eq!(
            ImmediateSlots::from_range(8, 4),
            ImmediateSlots::from_raw(0b100)
        );
    }

    #[test]
    fn range_vec4() {
        assert_eq!(
            ImmediateSlots::from_range(0, 16),
            ImmediateSlots::from_raw(0b1111)
        );
        assert_eq!(
            ImmediateSlots::from_range(16, 16),
            ImmediateSlots::from_raw(0b1111_0000)
        );
    }

    #[test]
    fn range_full_256() {
        assert_eq!(
            ImmediateSlots::from_range(0, 256),
            ImmediateSlots::from_raw(u64::MAX)
        );
    }

    #[test]
    fn range_empty() {
        // A zero-sized range occupies no slots, wherever it starts.
        assert_eq!(ImmediateSlots::from_range(0, 0), ImmediateSlots::default());
        assert_eq!(ImmediateSlots::from_range(12, 0), ImmediateSlots::default());
    }

    #[test]
    fn from_type_excludes_struct_padding() {
        let module = crate::front::wgsl::parse_str("struct S { a: f32, b: vec4<f32> }").unwrap();
        let struct_ty = module
            .types
            .iter()
            .map(|(_, ty)| ty)
            .find(|ty| ty.name.as_deref() == Some("S"))
            .unwrap();
        let slots = ImmediateSlots::from_type(&struct_ty.inner, 0, &module.types, module.to_ctx());
        assert_eq!(slots, ImmediateSlots::from_raw(0b1111_0001));
    }

    #[test]
    fn range_unaligned() {
        assert_eq!(
            ImmediateSlots::from_range(0, 3),
            ImmediateSlots::from_raw(0b1)
        );
        assert_eq!(
            ImmediateSlots::from_range(0, 5),
            ImmediateSlots::from_raw(0b11)
        );
    }

    #[test]
    fn contains() {
        let required = ImmediateSlots::from_raw(0b1111_0001);
        let mut set = ImmediateSlots::default();
        assert!(!set.contains(required));
        set |= ImmediateSlots::from_range(0, 4);
        assert!(!set.contains(required));
        set |= ImmediateSlots::from_range(16, 16);
        assert!(set.contains(required));
    }

    #[test]
    fn difference() {
        let required = ImmediateSlots::from_raw(0b1111_0001);
        let set = ImmediateSlots::from_range(0, 4);
        assert_eq!(
            required.difference(set),
            ImmediateSlots::from_raw(0b1111_0000)
        );
    }

    #[test]
    fn union() {
        let a = ImmediateSlots::from_raw(0b0011);
        let b = ImmediateSlots::from_raw(0b0110);
        assert_eq!(a | b, ImmediateSlots::from_raw(0b0111));
        let mut c = a;
        c |= b;
        assert_eq!(c, a | b);
    }

    #[test]
    fn display() {
        assert_display(ImmediateSlots::default(), "(none)");
        assert_display(ImmediateSlots::from_raw(0b1111_0001), "0..4, 16..32");
        assert_display(ImmediateSlots::from_raw(u64::MAX), "0..256");
        assert_display(ImmediateSlots::from_raw(1 << 63), "252..256");
    }
}