Skip to main content

tuwunel_database/map/
watch.rs

1use std::{
2	collections::{BTreeMap, btree_map::Entry},
3	future::Future,
4	ops::RangeToInclusive,
5	sync::Mutex,
6};
7
8use futures::pin_mut;
9use serde::Serialize;
10use tokio::sync::watch::{Receiver, Sender, channel};
11use tuwunel_core::{debug, defer, implement, smallvec::SmallVec};
12
13use crate::keyval::{KeyBuf, serialize_key};
14
15type Watchers = Mutex<BTreeMap<KeyBuf, Sender<()>>>;
16type KeyVec = SmallVec<[KeyBuf; 1]>;
17
18#[derive(Default)]
19pub(super) struct Watch {
20	watchers: Watchers,
21}
22
23#[implement(super::Map)]
24pub fn watch_prefix<K>(&self, prefix: K) -> impl Future<Output = ()> + Send + '_
25where
26	K: Serialize,
27{
28	let prefix = serialize_key(prefix).expect("failed to serialize watch prefix key");
29	self.watch_raw_prefix(&prefix)
30}
31
32#[implement(super::Map)]
33pub fn watch_raw_prefix_once<K>(&self, prefix: K) -> impl Future<Output = ()> + Send + '_
34where
35	K: AsRef<[u8]>,
36{
37	let key: KeyBuf = prefix.as_ref().into();
38	let rx = self.subscribe(key.clone());
39
40	async move {
41		pin_mut!(rx);
42
43		// We are still subscribed, so a receiver count of one means we are the last.
44		defer! {{
45			let mut watchers = self.watch.watchers.lock().expect("locked");
46			if watchers.get(&key).is_some_and(|tx| tx.receiver_count() == 1) {
47				watchers.remove(&key);
48			}
49		}}
50
51		rx.changed()
52			.await
53			.expect("watcher sender dropped");
54	}
55}
56
57#[implement(super::Map)]
58pub fn watch_raw_prefix<'a, K>(&self, prefix: &'a K) -> impl Future<Output = ()> + Send + use<K>
59where
60	K: AsRef<[u8]> + ?Sized + 'a,
61{
62	let rx = self.subscribe(prefix.as_ref().into());
63
64	async move {
65		pin_mut!(rx);
66		rx.changed()
67			.await
68			.expect("watcher sender dropped");
69	}
70}
71
72#[implement(super::Map)]
73fn subscribe(&self, key: KeyBuf) -> Receiver<()> {
74	match self
75		.watch
76		.watchers
77		.lock()
78		.expect("locked")
79		.entry(key)
80	{
81		| Entry::Occupied(node) => node.get().subscribe(),
82		| Entry::Vacant(node) => {
83			let (tx, rx) = channel(());
84			node.insert(tx);
85			rx
86		},
87	}
88}
89
90#[implement(super::Map)]
91#[tracing::instrument(
92	level = "trace",
93	skip_all,
94	fields(
95		map = self.name(),
96		key = str::from_utf8(key.as_ref()).unwrap_or("<binary>"),
97	)
98)]
99pub(crate) fn notify<K>(&self, key: &K)
100where
101	K: AsRef<[u8]> + Ord + ?Sized,
102{
103	let range = RangeToInclusive::<KeyBuf> { end: key.as_ref().into() };
104
105	let mut watchers = self.watch.watchers.lock().expect("locked");
106
107	let num_notified = watchers
108		.range(range)
109		.rev()
110		.take_while(|(k, _)| key.as_ref().starts_with(k))
111		.filter_map(|(k, tx)| tx.send(()).is_err().then_some(k))
112		.cloned()
113		.collect::<KeyVec>()
114		.into_iter()
115		.fold(0_usize, |num_notified, key| {
116			watchers.remove(&key);
117			num_notified.saturating_add(1)
118		});
119
120	if num_notified > 0 {
121		debug!(watchers = watchers.len(), num_notified, "notified");
122	}
123}
124
#[cfg(test)]
mod tests {
	use tokio::sync::watch::channel;

	/// Pins the tokio contract the reaper relies on. `receiver_count()` reflects
	/// a just-dropped `Receiver`, and `send()` fails once no receiver remains.
	#[test]
	fn receiver_count_reaps_at_last_drop() {
		let (sender, original) = channel(());
		assert_eq!(sender.receiver_count(), 1, "fresh channel has one receiver");

		let extra = sender.subscribe();
		assert_eq!(sender.receiver_count(), 2, "subscribe adds a receiver");

		drop(extra);
		assert_eq!(sender.receiver_count(), 1, "drop is reflected synchronously");

		drop(original);
		assert_eq!(sender.receiver_count(), 0, "last drop leaves no receiver");

		let sent = sender.send(());
		assert!(sent.is_err(), "send fails with zero receivers");
	}
}