1use crate::{Epoch, Index};
2use core::{
3 cmp::Ordering,
4 fmt::{self, Debug},
5 hash::Hash,
6 marker::PhantomData,
7 num::NonZeroU64,
8};
9use wgt::WasmNotSendSync;
10
11const _: () = {
12 if size_of::<Index>() != 4 {
13 panic!()
14 }
15};
16const _: () = {
17 if size_of::<Epoch>() != 4 {
18 panic!()
19 }
20};
21const _: () = {
22 if size_of::<RawId>() != 8 {
23 panic!()
24 }
25};
26
/// Untyped identifier stored as a single non-zero 64-bit word.
///
/// The word is built by [`RawId::zip`]: the index occupies the low 32 bits and
/// the epoch the high 32 bits. When the serde-related features are enabled the
/// value is (de)serialized by converting to and from `SerialId`.
///
/// `Debug` is derived so that the id can appear in `{:?}` output, in
/// `assert_eq!` failures and inside other `#[derive(Debug)]` types.
#[repr(transparent)]
#[cfg_attr(
    any(feature = "serde", feature = "trace"),
    derive(serde::Serialize),
    serde(into = "SerialId")
)]
#[cfg_attr(
    any(feature = "serde", feature = "replay"),
    derive(serde::Deserialize),
    serde(try_from = "SerialId")
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RawId(NonZeroU64);
41
42impl RawId {
43 pub fn zip(index: Index, epoch: Epoch) -> RawId {
49 let v = (index as u64) | ((epoch as u64) << 32);
50 Self(NonZeroU64::new(v).expect("IDs may not be zero"))
51 }
52
53 pub fn unzip(self) -> (Index, Epoch) {
55 (self.0.get() as Index, (self.0.get() >> 32) as Epoch)
56 }
57}
58
59#[repr(transparent)]
85#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
86#[cfg_attr(feature = "serde", serde(transparent))]
87pub struct Id<T: Marker>(RawId, PhantomData<T>);
88
/// Serialized representation of a [`RawId`]: the index and the epoch stored
/// as two separate fields instead of one packed word.
///
/// `RawId` converts into this type when serializing and is rebuilt from it
/// (fallibly) when deserializing.
// NOTE(review): presumably a single-variant enum is used, rather than a tuple
// struct, so that the serialized text keeps the `Id` tag — confirm.
#[cfg(feature = "serde")]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum SerialId {
    Id(Index, Epoch),
}
97
/// Infallible conversion used when serializing: unpacks the raw id into its
/// index and epoch.
#[cfg(feature = "serde")]
impl From<RawId> for SerialId {
    fn from(id: RawId) -> Self {
        let parts = id.unzip();
        SerialId::Id(parts.0, parts.1)
    }
}
105
/// Error returned by `TryFrom<SerialId> for RawId` when the serialized index
/// and epoch are both zero, which cannot be represented as a [`RawId`].
///
/// The common traits are derived so the error can be used with
/// `Result::unwrap`/`expect`, compared in tests and printed with `{:?}`.
#[cfg(feature = "serde")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ZeroIdError;
108
/// Human-readable message for [`ZeroIdError`].
#[cfg(feature = "serde")]
impl fmt::Display for ZeroIdError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message has no arguments, so it is written directly.
        f.write_str("IDs may not be zero")
    }
}
115
/// Fallible conversion used when deserializing.
///
/// An all-zero `(index, epoch)` pair is rejected with [`ZeroIdError`] instead
/// of reaching `RawId::zip`, which would panic on it.
#[cfg(feature = "serde")]
impl TryFrom<SerialId> for RawId {
    type Error = ZeroIdError;

    fn try_from(id: SerialId) -> Result<Self, ZeroIdError> {
        match id {
            SerialId::Id(0, 0) => Err(ZeroIdError),
            SerialId::Id(index, epoch) => Ok(RawId::zip(index, epoch)),
        }
    }
}
128
/// Identifier made of a plain address (`usize`) tagged with a [`Marker`] type.
///
/// The `From<&Arc<T>>` impl in this file fills the address with the value of
/// `Arc::as_ptr`. The `PhantomData` field is skipped by serde, so only the
/// address is serialized.
// NOTE(review): `dead_code` is allowed presumably because some feature
// combinations never construct this type — confirm.
#[cfg(feature = "serde")]
#[allow(dead_code)]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum PointerId<T>
where
    T: Marker,
{
    PointerId(usize, #[serde(skip)] PhantomData<T>),
}
139
// Implemented by hand so that `Copy` is available for every marker `T`,
// without the `T: Copy` bound a derive would add.
#[cfg(feature = "serde")]
impl<T> Copy for PointerId<T> where T: Marker {}
142
// Implemented by hand so that `Clone` does not require `T: Clone`.
#[cfg(feature = "serde")]
impl<T> Clone for PointerId<T>
where
    T: Marker,
{
    /// Returns a bitwise copy; the type is `Copy`.
    fn clone(&self) -> Self {
        *self
    }
}
149
#[cfg(feature = "serde")]
impl<T> PartialEq for PointerId<T>
where
    T: Marker,
{
    /// Two ids are equal when their stored addresses are equal; the marker
    /// field carries no data and is ignored.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (PointerId::PointerId(lhs, _), PointerId::PointerId(rhs, _)) => lhs == rhs,
        }
    }
}
158
// Equality compares a `usize`, so it is a total equivalence relation.
#[cfg(feature = "serde")]
impl<T> Eq for PointerId<T> where T: Marker {}
161
#[cfg(feature = "serde")]
impl<T> Hash for PointerId<T>
where
    T: Marker,
{
    /// Hashes only the stored address, matching the `PartialEq` impl.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        match self {
            PointerId::PointerId(address, _) => Hash::hash(address, state),
        }
    }
}
169
/// Builds a [`PointerId`] from the address returned by `Arc::as_ptr`, tagged
/// with the storage item's marker type.
#[cfg(feature = "serde")]
impl<T: crate::storage::StorageItem> From<&alloc::sync::Arc<T>> for PointerId<T::Marker> {
    fn from(arc: &alloc::sync::Arc<T>) -> Self {
        let address = alloc::sync::Arc::as_ptr(arc) as usize;
        Self::PointerId(address, PhantomData)
    }
}
183
184impl<T> Id<T>
185where
186 T: Marker,
187{
188 pub unsafe fn from_raw(raw: RawId) -> Self {
192 Self(raw, PhantomData)
193 }
194
195 pub fn into_raw(self) -> RawId {
197 self.0
198 }
199
200 #[inline]
201 pub fn zip(index: Index, epoch: Epoch) -> Self {
202 Id(RawId::zip(index, epoch), PhantomData)
203 }
204
205 #[inline]
206 pub fn unzip(self) -> (Index, Epoch) {
207 self.0.unzip()
208 }
209}
210
211impl<T> Copy for Id<T> where T: Marker {}
212
213impl<T> Clone for Id<T>
214where
215 T: Marker,
216{
217 #[inline]
218 fn clone(&self) -> Self {
219 *self
220 }
221}
222
223impl<T> Debug for Id<T>
224where
225 T: Marker,
226{
227 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
228 let (index, epoch) = self.unzip();
229 write!(formatter, "Id({index},{epoch})")?;
230 Ok(())
231 }
232}
233
234impl<T> Hash for Id<T>
235where
236 T: Marker,
237{
238 #[inline]
239 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
240 self.0.hash(state);
241 }
242}
243
244impl<T> PartialEq for Id<T>
245where
246 T: Marker,
247{
248 #[inline]
249 fn eq(&self, other: &Self) -> bool {
250 self.0 == other.0
251 }
252}
253
254impl<T> Eq for Id<T> where T: Marker {}
255
256impl<T> PartialOrd for Id<T>
257where
258 T: Marker,
259{
260 #[inline]
261 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
262 Some(self.cmp(other))
263 }
264}
265
266impl<T> Ord for Id<T>
267where
268 T: Marker,
269{
270 #[inline]
271 fn cmp(&self, other: &Self) -> Ordering {
272 self.0.cmp(&other.0)
273 }
274}
275
276pub trait Marker: 'static + WasmNotSendSync {}
281
282#[cfg(test)]
287impl Marker for () {}
288
// Declares the typed id aliases.
//
// Input: any number of `pub type <Alias> <Tag>;` entries, each optionally
// preceded by attributes. Output:
//   * one `pub mod markers` containing, per entry, an uninhabited
//     `pub enum <Tag> {}` that derives `Debug` and implements `Marker`;
//   * per entry, `pub type <Alias> = Id<markers::<Tag>>;` carrying the
//     attributes that were written on the entry.
macro_rules! ids {
    (
        $(
            $( #[ $( $attr:meta )* ] )*
            pub type $alias:ident $tag:ident;
        )*
    ) => {
        pub mod markers {
            $(
                #[derive(Debug)]
                pub enum $tag {}
                impl super::Marker for $tag {}
            )*
        }

        $(
            $( #[ $( $attr )* ] )*
            pub type $alias = Id<self::markers::$tag>;
        )*
    };
}
310
311ids! {
312 pub type AdapterId Adapter;
313 pub type SurfaceId Surface;
314 pub type DeviceId Device;
315 pub type QueueId Queue;
316 pub type BufferId Buffer;
317 pub type StagingBufferId StagingBuffer;
318 pub type TextureViewId TextureView;
319 pub type TextureId Texture;
320 pub type ExternalTextureId ExternalTexture;
321 pub type SamplerId Sampler;
322 pub type BindGroupLayoutId BindGroupLayout;
323 pub type PipelineLayoutId PipelineLayout;
324 pub type BindGroupId BindGroup;
325 pub type ShaderModuleId ShaderModule;
326 pub type RenderPipelineId RenderPipeline;
327 pub type ComputePipelineId ComputePipeline;
328 pub type PipelineCacheId PipelineCache;
329 pub type CommandEncoderId CommandEncoder;
330 pub type CommandBufferId CommandBuffer;
331 pub type RenderPassEncoderId RenderPassEncoder;
332 pub type ComputePassEncoderId ComputePassEncoder;
333 pub type RenderBundleEncoderId RenderBundleEncoder;
334 pub type RenderBundleId RenderBundle;
335 pub type QuerySetId QuerySet;
336 pub type BlasId Blas;
337 pub type TlasId Tlas;
338}
339
/// Round-trip check: `zip` followed by `unzip` returns the original pair for
/// boundary and mid-range values of both halves. Epochs start at 1, so the
/// all-zero pair that `zip` rejects is never produced.
#[test]
fn test_id() {
    let indexes = [0, Index::MAX / 2 - 1, Index::MAX / 2 + 1, Index::MAX];
    let epochs = [1, Epoch::MAX / 2 - 1, Epoch::MAX / 2 + 1, Epoch::MAX];
    for i in indexes {
        for e in epochs {
            let round_trip = Id::<()>::zip(i, e).unzip();
            assert_eq!(round_trip, (i, e));
        }
    }
}