Skip to main content

tuwunel_database/map/
watch.rs

1use std::{
2	collections::{BTreeMap, btree_map::Entry},
3	future::Future,
4	ops::RangeToInclusive,
5	sync::Mutex,
6};
7
8use futures::pin_mut;
9use serde::Serialize;
10use tokio::sync::watch::{Sender, channel};
11use tuwunel_core::{debug, implement, smallvec::SmallVec};
12
13use crate::keyval::{KeyBuf, serialize_key};
14
15type Watchers = Mutex<BTreeMap<KeyBuf, Sender<()>>>;
16type KeyVec = SmallVec<[KeyBuf; 1]>;
17
18#[derive(Default)]
19pub(super) struct Watch {
20	watchers: Watchers,
21}
22
23#[implement(super::Map)]
24pub fn watch_prefix<K>(&self, prefix: K) -> impl Future<Output = ()> + Send + '_
25where
26	K: Serialize,
27{
28	let prefix = serialize_key(prefix).expect("failed to serialize watch prefix key");
29	self.watch_raw_prefix(&prefix)
30}
31
32#[implement(super::Map)]
33pub fn watch_raw_prefix<'a, K>(&self, prefix: &'a K) -> impl Future<Output = ()> + Send + use<K>
34where
35	K: AsRef<[u8]> + ?Sized + 'a,
36{
37	let rx = match self
38		.watch
39		.watchers
40		.lock()
41		.expect("locked")
42		.entry(prefix.as_ref().into())
43	{
44		| Entry::Occupied(node) => node.get().subscribe(),
45		| Entry::Vacant(node) => {
46			let (tx, rx) = channel(());
47			node.insert(tx);
48			rx
49		},
50	};
51
52	async move {
53		pin_mut!(rx);
54		rx.changed()
55			.await
56			.expect("watcher sender dropped");
57	}
58}
59
60#[implement(super::Map)]
61#[tracing::instrument(
62	level = "trace",
63	skip_all,
64	fields(
65		map = self.name(),
66		key = str::from_utf8(key.as_ref()).unwrap_or("<binary>"),
67	)
68)]
69pub(crate) fn notify<K>(&self, key: &K)
70where
71	K: AsRef<[u8]> + Ord + ?Sized,
72{
73	let range = RangeToInclusive::<KeyBuf> { end: key.as_ref().into() };
74
75	let mut watchers = self.watch.watchers.lock().expect("locked");
76
77	let num_notified = watchers
78		.range(range)
79		.rev()
80		.take_while(|(k, _)| key.as_ref().starts_with(k))
81		.filter_map(|(k, tx)| tx.send(()).is_err().then_some(k))
82		.cloned()
83		.collect::<KeyVec>()
84		.into_iter()
85		.fold(0_usize, |num_notified, key| {
86			watchers.remove(&key);
87			num_notified.saturating_add(1)
88		});
89
90	if num_notified > 0 {
91		debug!(watchers = watchers.len(), num_notified, "notified");
92	}
93}