Skip to main content

tuwunel_core/utils/
mutex_map.rs

1use std::{
2	borrow::ToOwned,
3	fmt::Debug,
4	hash::Hash,
5	sync::{Arc, TryLockError::WouldBlock},
6};
7
8use tokio::sync::OwnedMutexGuard as Omg;
9
10use crate::{Result, err};
11
12/// Map of Mutexes
13#[derive(Debug)]
14pub struct MutexMap<Key, Val> {
15	map: Map<Key, Val>,
16}
17
18#[derive(Debug)]
19#[clippy::has_significant_drop]
20pub struct Guard<Key, Val> {
21	map: Map<Key, Val>,
22	val: Omg<Val>,
23}
24
25type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
26type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
27type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
28type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
29
30impl<Key, Val> MutexMap<Key, Val>
31where
32	Key: Clone + Eq + Hash + Send,
33	Val: Default + Send,
34{
35	#[must_use]
36	pub fn new() -> Self {
37		Self {
38			map: Map::new(MapMutex::new(HashMap::new())),
39		}
40	}
41
42	#[tracing::instrument(level = "trace", skip(self))]
43	pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val>
44	where
45		K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
46	{
47		let val = self
48			.map
49			.lock()
50			.expect("locked")
51			.entry(k.to_owned())
52			.or_default()
53			.clone();
54
55		Guard::<Key, Val> {
56			map: Arc::clone(&self.map),
57			val: val.lock_owned().await,
58		}
59	}
60
61	#[tracing::instrument(level = "trace", skip(self))]
62	pub fn try_lock<K>(&self, k: &K) -> Result<Guard<Key, Val>>
63	where
64		K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
65	{
66		let val = self
67			.map
68			.lock()
69			.expect("locked")
70			.entry(k.to_owned())
71			.or_default()
72			.clone();
73
74		Ok(Guard::<Key, Val> {
75			map: Arc::clone(&self.map),
76			val: val
77				.try_lock_owned()
78				.map_err(|_| err!("would yield"))?,
79		})
80	}
81
82	#[tracing::instrument(level = "trace", skip(self))]
83	pub fn try_try_lock<K>(&self, k: &K) -> Result<Guard<Key, Val>>
84	where
85		K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
86	{
87		let val = self
88			.map
89			.try_lock()
90			.map_err(|e| match e {
91				| WouldBlock => err!("would block"),
92				| _ => panic!("{e:?}"),
93			})?
94			.entry(k.to_owned())
95			.or_default()
96			.clone();
97
98		Ok(Guard::<Key, Val> {
99			map: Arc::clone(&self.map),
100			val: val
101				.try_lock_owned()
102				.map_err(|_| err!("would yield"))?,
103		})
104	}
105
106	#[must_use]
107	pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
108
109	#[must_use]
110	pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
111
112	#[must_use]
113	pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
114}
115
116impl<Key, Val> Default for MutexMap<Key, Val>
117where
118	Key: Clone + Eq + Hash + Send,
119	Val: Default + Send,
120{
121	fn default() -> Self { Self::new() }
122}
123
124impl<Key, Val> Drop for Guard<Key, Val> {
125	#[tracing::instrument(name = "unlock", level = "trace", skip_all)]
126	fn drop(&mut self) {
127		if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
128			self.map.lock().expect("locked").retain(|_, val| {
129				!Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2
130			});
131		}
132	}
133}