tuwunel_core/utils/
mutex_map.rs1use std::{
2 borrow::ToOwned,
3 fmt::Debug,
4 hash::Hash,
5 sync::{Arc, TryLockError::WouldBlock},
6};
7
8use tokio::sync::OwnedMutexGuard as Omg;
9
10use crate::{Result, err};
11
12#[derive(Debug)]
14pub struct MutexMap<Key, Val> {
15 map: Map<Key, Val>,
16}
17
18#[derive(Debug)]
19#[clippy::has_significant_drop]
20pub struct Guard<Key, Val> {
21 map: Map<Key, Val>,
22 val: Omg<Val>,
23}
24
25type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
26type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
27type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
28type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
29
30impl<Key, Val> MutexMap<Key, Val>
31where
32 Key: Clone + Eq + Hash + Send,
33 Val: Default + Send,
34{
35 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 map: Map::new(MapMutex::new(HashMap::new())),
39 }
40 }
41
42 #[tracing::instrument(level = "trace", skip(self))]
43 pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val>
44 where
45 K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
46 {
47 let val = self
48 .map
49 .lock()
50 .expect("locked")
51 .entry(k.to_owned())
52 .or_default()
53 .clone();
54
55 Guard::<Key, Val> {
56 map: Arc::clone(&self.map),
57 val: val.lock_owned().await,
58 }
59 }
60
61 #[tracing::instrument(level = "trace", skip(self))]
62 pub fn try_lock<K>(&self, k: &K) -> Result<Guard<Key, Val>>
63 where
64 K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
65 {
66 let val = self
67 .map
68 .lock()
69 .expect("locked")
70 .entry(k.to_owned())
71 .or_default()
72 .clone();
73
74 Ok(Guard::<Key, Val> {
75 map: Arc::clone(&self.map),
76 val: val
77 .try_lock_owned()
78 .map_err(|_| err!("would yield"))?,
79 })
80 }
81
82 #[tracing::instrument(level = "trace", skip(self))]
83 pub fn try_try_lock<K>(&self, k: &K) -> Result<Guard<Key, Val>>
84 where
85 K: Debug + Send + ?Sized + Sync + ToOwned<Owned = Key>,
86 {
87 let val = self
88 .map
89 .try_lock()
90 .map_err(|e| match e {
91 | WouldBlock => err!("would block"),
92 | _ => panic!("{e:?}"),
93 })?
94 .entry(k.to_owned())
95 .or_default()
96 .clone();
97
98 Ok(Guard::<Key, Val> {
99 map: Arc::clone(&self.map),
100 val: val
101 .try_lock_owned()
102 .map_err(|_| err!("would yield"))?,
103 })
104 }
105
106 #[must_use]
107 pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
108
109 #[must_use]
110 pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
111
112 #[must_use]
113 pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
114}
115
116impl<Key, Val> Default for MutexMap<Key, Val>
117where
118 Key: Clone + Eq + Hash + Send,
119 Val: Default + Send,
120{
121 fn default() -> Self { Self::new() }
122}
123
124impl<Key, Val> Drop for Guard<Key, Val> {
125 #[tracing::instrument(name = "unlock", level = "trace", skip_all)]
126 fn drop(&mut self) {
127 if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
128 self.map.lock().expect("locked").retain(|_, val| {
129 !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2
130 });
131 }
132 }
133}