tuwunel_database/map/
watch.rs1use std::{
2 collections::{BTreeMap, btree_map::Entry},
3 future::Future,
4 ops::RangeToInclusive,
5 sync::Mutex,
6};
7
8use futures::pin_mut;
9use serde::Serialize;
10use tokio::sync::watch::{Receiver, Sender, channel};
11use tuwunel_core::{debug, defer, implement, smallvec::SmallVec};
12
13use crate::keyval::{KeyBuf, serialize_key};
14
15type Watchers = Mutex<BTreeMap<KeyBuf, Sender<()>>>;
16type KeyVec = SmallVec<[KeyBuf; 1]>;
17
18#[derive(Default)]
19pub(super) struct Watch {
20 watchers: Watchers,
21}
22
23#[implement(super::Map)]
24pub fn watch_prefix<K>(&self, prefix: K) -> impl Future<Output = ()> + Send + '_
25where
26 K: Serialize,
27{
28 let prefix = serialize_key(prefix).expect("failed to serialize watch prefix key");
29 self.watch_raw_prefix(&prefix)
30}
31
32#[implement(super::Map)]
33pub fn watch_raw_prefix_once<K>(&self, prefix: K) -> impl Future<Output = ()> + Send + '_
34where
35 K: AsRef<[u8]>,
36{
37 let key: KeyBuf = prefix.as_ref().into();
38 let rx = self.subscribe(key.clone());
39
40 async move {
41 pin_mut!(rx);
42
43 defer! {{
45 let mut watchers = self.watch.watchers.lock().expect("locked");
46 if watchers.get(&key).is_some_and(|tx| tx.receiver_count() == 1) {
47 watchers.remove(&key);
48 }
49 }}
50
51 rx.changed()
52 .await
53 .expect("watcher sender dropped");
54 }
55}
56
57#[implement(super::Map)]
58pub fn watch_raw_prefix<'a, K>(&self, prefix: &'a K) -> impl Future<Output = ()> + Send + use<K>
59where
60 K: AsRef<[u8]> + ?Sized + 'a,
61{
62 let rx = self.subscribe(prefix.as_ref().into());
63
64 async move {
65 pin_mut!(rx);
66 rx.changed()
67 .await
68 .expect("watcher sender dropped");
69 }
70}
71
72#[implement(super::Map)]
73fn subscribe(&self, key: KeyBuf) -> Receiver<()> {
74 match self
75 .watch
76 .watchers
77 .lock()
78 .expect("locked")
79 .entry(key)
80 {
81 | Entry::Occupied(node) => node.get().subscribe(),
82 | Entry::Vacant(node) => {
83 let (tx, rx) = channel(());
84 node.insert(tx);
85 rx
86 },
87 }
88}
89
90#[implement(super::Map)]
91#[tracing::instrument(
92 level = "trace",
93 skip_all,
94 fields(
95 map = self.name(),
96 key = str::from_utf8(key.as_ref()).unwrap_or("<binary>"),
97 )
98)]
99pub(crate) fn notify<K>(&self, key: &K)
100where
101 K: AsRef<[u8]> + Ord + ?Sized,
102{
103 let range = RangeToInclusive::<KeyBuf> { end: key.as_ref().into() };
104
105 let mut watchers = self.watch.watchers.lock().expect("locked");
106
107 let num_notified = watchers
108 .range(range)
109 .rev()
110 .take_while(|(k, _)| key.as_ref().starts_with(k))
111 .filter_map(|(k, tx)| tx.send(()).is_err().then_some(k))
112 .cloned()
113 .collect::<KeyVec>()
114 .into_iter()
115 .fold(0_usize, |num_notified, key| {
116 watchers.remove(&key);
117 num_notified.saturating_add(1)
118 });
119
120 if num_notified > 0 {
121 debug!(watchers = watchers.len(), num_notified, "notified");
122 }
123}
124
#[cfg(test)]
mod tests {
	use tokio::sync::watch::channel;

	/// Pins down the `tokio::sync::watch` behavior that the registry cleanup in
	/// this module relies on. Receiver counts change synchronously on subscribe
	/// and drop, and `send` fails once no receiver remains.
	#[test]
	fn receiver_count_reaps_at_last_drop() {
		let (sender, original) = channel(());
		assert_eq!(sender.receiver_count(), 1, "fresh channel has one receiver");

		let extra = sender.subscribe();
		assert_eq!(sender.receiver_count(), 2, "subscribe adds a receiver");

		drop(extra);
		assert_eq!(sender.receiver_count(), 1, "drop is reflected synchronously");

		drop(original);
		assert_eq!(sender.receiver_count(), 0, "last drop leaves no receiver");
		assert!(sender.send(()).is_err(), "send fails with zero receivers");
	}
}