1use core::{cell::UnsafeCell, fmt, mem::ManuallyDrop};
2
3use crate::lock::{rank, RankData, RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5pub struct SnatchGuard<'a>(RwLockReadGuard<'a, ()>);
7pub struct ExclusiveSnatchGuard<'a>(#[expect(dead_code)] RwLockWriteGuard<'a, ()>);
9
/// Storage for a value that is normally only read, but that can be "snatched" (taken out)
/// while other code may hold `&self`.
///
/// Access through `&self` is gated by guards obtained from a [`SnatchLock`].
pub struct SnatchableInner<T> {
    /// The protected value. It is only reached through the guard-taking accessors or
    /// through `&mut self`.
    value: UnsafeCell<T>,
}

/// A [`SnatchableInner`] holding an optional value, which becomes `None` once snatched.
pub type Snatchable<T> = SnatchableInner<Option<T>>;
21
22impl<T> Snatchable<T> {
23 pub fn new(val: T) -> Self {
24 SnatchableInner {
25 value: UnsafeCell::new(Some(val)),
26 }
27 }
28
29 #[allow(dead_code)]
30 pub fn empty() -> Self {
31 SnatchableInner {
32 value: UnsafeCell::new(None),
33 }
34 }
35
36 pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> {
38 unsafe { (*self.value.get()).as_ref() }
39 }
40
41 pub fn snatch(&self, _guard: &mut ExclusiveSnatchGuard) -> Option<T> {
43 unsafe { (*self.value.get()).take() }
44 }
45
46 pub fn take(&mut self) -> Option<T> {
51 self.value.get_mut().take()
52 }
53}
54
55impl<T> fmt::Debug for SnatchableInner<T> {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 write!(f, "<snatchable>")
60 }
61}
62
63unsafe impl<T> Sync for SnatchableInner<T> {}
64
65use trace::LockTrace;
#[cfg(all(debug_assertions, feature = "std"))]
mod trace {
    use core::{cell::Cell, fmt, panic::Location};
    use std::{backtrace::Backtrace, thread};

    /// Record of one snatch-lock acquisition by the current thread.
    pub(super) struct LockTrace {
        /// Kind of lock taken, as passed to `LockTrace::enter`.
        purpose: &'static str,
        /// Caller location reported by `Location::caller()` inside the `#[track_caller]`
        /// `enter`.
        caller: &'static Location<'static>,
        /// Backtrace taken with `Backtrace::capture`. It is only populated when backtraces
        /// are enabled through the environment.
        backtrace: Backtrace,
    }

    impl fmt::Display for LockTrace {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let Self {
                purpose,
                caller,
                backtrace,
            } = self;
            write!(f, "a {} lock at {}\n{}", purpose, caller, backtrace)
        }
    }

    impl LockTrace {
        /// Records that the current thread is acquiring a snatch lock.
        ///
        /// # Panics
        ///
        /// Panics if this thread already has an acquisition recorded, i.e. it is trying to
        /// take a snatch lock while still holding one. Both acquisitions are reported.
        #[track_caller]
        pub(super) fn enter(purpose: &'static str) {
            let attempt = LockTrace {
                purpose,
                caller: Location::caller(),
                backtrace: Backtrace::capture(),
            };

            match SNATCH_LOCK_TRACE.take() {
                None => SNATCH_LOCK_TRACE.set(Some(attempt)),
                Some(prev) => {
                    let current = thread::current();
                    let name = current.name().unwrap_or("<unnamed>");
                    panic!(
                        "thread '{name}' attempted to acquire a snatch lock recursively.\n\
                         - Currently trying to acquire {new}\n\
                         - Previously acquired {prev}",
                        new = attempt,
                    );
                }
            }
        }

        /// Clears the current thread's recorded acquisition, if there is one.
        pub(super) fn exit() {
            drop(SNATCH_LOCK_TRACE.take());
        }
    }

    std::thread_local! {
        /// The snatch-lock acquisition currently recorded for this thread, if any.
        static SNATCH_LOCK_TRACE: Cell<Option<LockTrace>> = const { Cell::new(None) };
    }
}
#[cfg(not(all(debug_assertions, feature = "std")))]
mod trace {
    /// No-op stand-in for the recursion tracker. It is used when `debug_assertions` or the
    /// `std` feature is disabled.
    pub(super) struct LockTrace {
        _private: (),
    }

    impl LockTrace {
        /// Does nothing in this configuration.
        pub(super) fn enter(purpose: &'static str) {
            let _ = purpose;
        }

        /// Does nothing in this configuration.
        pub(super) fn exit() {}
    }
}
129
130pub struct SnatchLock {
132 lock: RwLock<()>,
133}
134
135impl SnatchLock {
136 pub unsafe fn new(rank: rank::LockRank) -> Self {
141 SnatchLock {
142 lock: RwLock::new(rank, ()),
143 }
144 }
145
146 #[track_caller]
148 pub fn read(&self) -> SnatchGuard<'_> {
149 LockTrace::enter("read");
150 SnatchGuard(self.lock.read())
151 }
152
153 #[track_caller]
159 pub fn write(&self) -> ExclusiveSnatchGuard<'_> {
160 LockTrace::enter("write");
161 ExclusiveSnatchGuard(self.lock.write())
162 }
163
164 #[track_caller]
165 pub unsafe fn force_unlock_read(&self, data: RankData) {
166 LockTrace::exit();
170 unsafe { self.lock.force_unlock_read(data) };
171 }
172}
173
174impl SnatchGuard<'_> {
175 pub fn forget(this: Self) -> RankData {
180 let manually_drop = ManuallyDrop::new(this);
182
183 let inner_guard = unsafe { core::ptr::read(&manually_drop.0) };
187
188 RwLockReadGuard::forget(inner_guard)
189 }
190}
191
192impl Drop for SnatchGuard<'_> {
193 fn drop(&mut self) {
194 LockTrace::exit();
195 }
196}
197
198impl Drop for ExclusiveSnatchGuard<'_> {
199 fn drop(&mut self) {
200 LockTrace::exit();
201 }
202}