wgpu_core/
pool.rs

1use alloc::sync::{Arc, Weak};
2use core::hash::Hash;
3
4use hashbrown::{hash_map::Entry, HashMap};
5use once_cell::sync::OnceCell;
6
7use crate::lock::{rank, Mutex};
8use crate::FastHashMap;
9
10type SlotInner<V> = Weak<V>;
11type ResourcePoolSlot<V> = Arc<OnceCell<SlotInner<V>>>;
12
13pub struct ResourcePool<K, V> {
14    inner: Mutex<FastHashMap<K, ResourcePoolSlot<V>>>,
15}
16
17impl<K: Clone + Eq + Hash, V> ResourcePool<K, V> {
18    pub fn new() -> Self {
19        Self {
20            inner: Mutex::new(rank::RESOURCE_POOL_INNER, HashMap::default()),
21        }
22    }
23
24    /// Get a resource from the pool with the given entry map, or create a new
25    /// one if it doesn't exist using the given constructor.
26    ///
27    /// Behaves such that only one resource will be created for each unique
28    /// entry map at any one time.
29    pub fn get_or_init<F, E>(&self, key: K, constructor: F) -> Result<Arc<V>, E>
30    where
31        F: FnOnce(K) -> Result<Arc<V>, E>,
32    {
33        // We can't prove at compile time that these will only ever be consumed once,
34        // so we need to do the check at runtime.
35        let mut key = Some(key);
36        let mut constructor = Some(constructor);
37
38        'race: loop {
39            let mut map_guard = self.inner.lock();
40
41            let entry = match map_guard.entry(key.clone().unwrap()) {
42                // An entry exists for this resource.
43                //
44                // We know that either:
45                // - The resource is still alive, and Weak::upgrade will succeed.
46                // - The resource is in the process of being dropped, and Weak::upgrade will fail.
47                //
48                // The entry will never be empty while the BGL is still alive.
49                Entry::Occupied(entry) => Arc::clone(entry.get()),
50                // No entry exists for this resource.
51                //
52                // We know that the resource is not alive, so we can create a new entry.
53                Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(OnceCell::new()))),
54            };
55
56            drop(map_guard);
57
58            // Some other thread may beat us to initializing the entry, but OnceCell guarantees that only one thread
59            // will actually initialize the entry.
60            //
61            // We pass the strong reference outside of the closure to keep it alive while we're the only one keeping a reference to it.
62            let mut strong = None;
63            let weak = entry.get_or_try_init(|| {
64                let strong_inner = constructor.take().unwrap()(key.take().unwrap())?;
65                let weak = Arc::downgrade(&strong_inner);
66                strong = Some(strong_inner);
67                Ok(weak)
68            })?;
69
70            // If strong is Some, that means we just initialized the entry, so we can just return it.
71            if let Some(strong) = strong {
72                return Ok(strong);
73            }
74
75            // The entry was already initialized by someone else, so we need to try to upgrade it.
76            if let Some(strong) = weak.upgrade() {
77                // We succeed, the resource is still alive, just return that.
78                return Ok(strong);
79            }
80
81            // The resource is in the process of being dropped, because upgrade failed.
82            // The entry still exists in the map, but it points to nothing.
83            //
84            // We're in a race with the drop implementation of the resource,
85            //  so lets just go around again. When we go around again:
86            // - If the entry exists, we might need to go around a few more times.
87            // - If the entry doesn't exist, we'll create a new one.
88            continue 'race;
89        }
90    }
91
92    /// Remove the given entry map from the pool.
93    ///
94    /// Must *only* be called in the Drop impl of [`BindGroupLayout`].
95    ///
96    /// [`BindGroupLayout`]: crate::binding_model::BindGroupLayout
97    pub fn remove(&self, key: &K) {
98        let mut map_guard = self.inner.lock();
99
100        // Weak::upgrade will be failing long before this code is called. All threads trying to access the resource will be spinning,
101        // waiting for the entry to be removed. It is safe to remove the entry from the map.
102        map_guard.remove(key);
103    }
104}
105
#[cfg(test)]
mod tests {
    use core::{
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };
    use std::{eprintln, sync::Barrier, thread};

    use super::*;

    #[test]
    fn deduplication() {
        let pool = ResourcePool::<u32, u32>::new();

        let mut counter = 0_u32;

        let arc1 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert_eq!(*arc1, 0);
        assert_eq!(counter, 1);

        let arc2 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert!(Arc::ptr_eq(&arc1, &arc2));
        assert_eq!(*arc2, 0);
        assert_eq!(counter, 1);

        drop(arc1);
        drop(arc2);
        pool.remove(&0);

        let arc3 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert_eq!(*arc3, 0);
        assert_eq!(counter, 2);
    }

    #[test]
    fn constructor_error_is_propagated_and_retryable() {
        let pool = ResourcePool::<u32, u32>::new();

        let err = pool
            .get_or_init::<_, &str>(0, |_key| Err("construction failed"))
            .unwrap_err();
        assert_eq!(err, "construction failed");

        // A failed construction must not poison the key: the next request runs its
        // constructor and registers the resource.
        let mut counter = 0_u32;

        let arc1 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert_eq!(*arc1, 0);
        assert_eq!(counter, 1);

        // ...and that resource is then shared as usual.
        let arc2 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert!(Arc::ptr_eq(&arc1, &arc2));
        assert_eq!(counter, 1);
    }

    // Test name has "2_threads" in the name so nextest reserves two threads for it.
    #[test]
    fn concurrent_creation_2_threads() {
        struct Resources {
            pool: ResourcePool<u32, u32>,
            counter: AtomicU32,
            barrier: Barrier,
        }

        let resources = Arc::new(Resources {
            pool: ResourcePool::<u32, u32>::new(),
            counter: AtomicU32::new(0),
            barrier: Barrier::new(2),
        });

        // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
        //
        // To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point.
        // The output will look something like this if the test is working as expected:
        //
        // ```
        // 0: prewait
        // 1: prewait
        // 1: postwait
        // 0: postwait
        // 1: init
        // 1: postget
        // 0: postget
        // ```
        fn thread_inner(idx: u8, resources: &Resources) -> Arc<u32> {
            eprintln!("{idx}: prewait");

            // Once this returns, both threads should hit get_or_init at about the same time,
            // allowing us to actually test concurrent creation.
            //
            // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
            resources.barrier.wait();

            eprintln!("{idx}: postwait");

            let ret = resources
                .pool
                .get_or_init::<_, ()>(0, |key| {
                    eprintln!("{idx}: init");

                    // Simulate long running constructor, ensuring that both threads will be in get_or_init.
                    thread::sleep(Duration::from_millis(250));

                    resources.counter.fetch_add(1, Ordering::SeqCst);

                    Ok(Arc::new(key))
                })
                .unwrap();

            eprintln!("{idx}: postget");

            ret
        }

        let thread1 = thread::spawn({
            let resource_clone = Arc::clone(&resources);
            move || thread_inner(1, &resource_clone)
        });

        let arc0 = thread_inner(0, &resources);

        assert_eq!(resources.counter.load(Ordering::Acquire), 1);

        let arc1 = thread1.join().unwrap();

        assert!(Arc::ptr_eq(&arc0, &arc1));
    }

    // Test name has "2_threads" in the name so nextest reserves two threads for it.
    #[test]
    fn create_while_drop_2_threads() {
        struct Resources {
            pool: ResourcePool<u32, u32>,
            barrier: Barrier,
        }

        let resources = Arc::new(Resources {
            pool: ResourcePool::<u32, u32>::new(),
            barrier: Barrier::new(2),
        });

        // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
        //
        // To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point.
        // The output will look something like this if the test is working as expected:
        //
        // ```
        // 0: prewait
        // 1: prewait
        // 1: postwait
        // 0: postwait
        // 1: postsleep
        // 1: removal
        // 0: postget
        // ```
        //
        // The last two _may_ be flipped.

        let existing_entry = resources
            .pool
            .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
            .unwrap();

        // Drop the entry, but do _not_ remove it from the pool.
        // This simulates the situation where the resource arc has been dropped, but the Drop implementation
        // has not yet run, which calls remove.
        drop(existing_entry);

        fn thread0_inner(resources: &Resources) {
            eprintln!("0: prewait");
            resources.barrier.wait();

            eprintln!("0: postwait");
            // We try to create a new entry, but the entry already exists.
            //
            // As `Weak::upgrade` keeps failing, we keep retrying until `remove` is called.
            let recreated = resources
                .pool
                .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
                .unwrap();
            eprintln!("0: postget");

            assert_eq!(*recreated, 0);
        }

        fn thread1_inner(resources: &Resources) {
            eprintln!("1: prewait");
            resources.barrier.wait();

            eprintln!("1: postwait");
            // We wait a little bit, making sure that thread0_inner has started spinning.
            thread::sleep(Duration::from_millis(250));
            eprintln!("1: postsleep");

            // We remove the entry from the pool, allowing thread0_inner to re-create.
            resources.pool.remove(&0);
            eprintln!("1: removal");
        }

        let thread1 = thread::spawn({
            let resource_clone = Arc::clone(&resources);
            move || thread1_inner(&resource_clone)
        });

        thread0_inner(&resources);

        thread1.join().unwrap();
    }
}
309}