// tuwunel_service/globals/data.rs

use std::{ops::Range, sync::Arc};

use futures::TryFutureExt;
use tokio::sync::watch::Sender;
use tuwunel_core::{
	Result, err, utils,
	utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit},
};
use tuwunel_database::{Database, Deserialized, Map};
10
11pub struct Data {
12	global: Arc<Map>,
13	retires: Sender<u64>,
14	counter: Arc<Counter>,
15	pub(super) db: Arc<Database>,
16}
17
18pub(super) type Permit = TwoPhasePermit<Callback>;
19type Counter = TwoPhaseCounter<Callback>;
20type Callback = Box<dyn Fn(u64) -> Result + Send + Sync>;
21
22const COUNTER: &[u8] = b"c";
23
24impl Data {
25	pub(super) fn new(args: &crate::Args<'_>) -> Self {
26		let db = args.db.clone();
27		let count = Self::stored_count(&args.db["global"]).expect("initialize global counter");
28		let retires = Sender::new(count);
29		Self {
30			db: args.db.clone(),
31			global: args.db["global"].clone(),
32			retires: retires.clone(),
33			counter: Counter::new(
34				count,
35				Box::new(move |count| Self::store_count(&db, &db["global"], count)),
36				Box::new(move |count| Self::handle_retire(&retires, count)),
37			),
38		}
39	}
40
41	#[inline]
42	pub(super) async fn wait_pending(&self) -> Result<u64> {
43		let count = self.counter.dispatched();
44		self.wait_count(&count).await.inspect(|retired| {
45			debug_assert!(
46				*retired >= count,
47				"Expecting retired sequence number >= snapshotted dispatch number"
48			);
49		})
50	}
51
52	#[inline]
53	pub(super) async fn wait_count(&self, count: &u64) -> Result<u64> {
54		self.retires
55			.subscribe()
56			.wait_for(|retired| retired.ge(count))
57			.map_ok(|retired| *retired)
58			.map_err(|e| err!(debug_error!("counter channel error {e:?}")))
59			.await
60	}
61
62	#[inline]
63	pub(super) fn next_count(&self) -> Permit {
64		self.counter
65			.next()
66			.expect("failed to obtain next sequence number")
67	}
68
69	#[inline]
70	pub(super) fn current_count(&self) -> u64 { self.counter.current() }
71
72	#[inline]
73	pub(super) fn pending_count(&self) -> Range<u64> { self.counter.range() }
74
75	#[tracing::instrument(name = "retire", level = "debug", skip(sender))]
76	fn handle_retire(sender: &Sender<u64>, count: u64) -> Result {
77		let _prev = sender.send_replace(count);
78
79		Ok(())
80	}
81
82	#[tracing::instrument(name = "dispatch", level = "debug", skip(db, global))]
83	fn store_count(db: &Arc<Database>, global: &Arc<Map>, count: u64) -> Result {
84		let _cork = db.cork();
85		global.insert(COUNTER, count.to_be_bytes());
86
87		Ok(())
88	}
89
90	fn stored_count(global: &Arc<Map>) -> Result<u64> {
91		global
92			.get_blocking(COUNTER)
93			.as_deref()
94			.map_or(Ok(0_u64), utils::u64_from_bytes)
95	}
96}
97
98impl Data {
99	pub fn bump_database_version(&self, new_version: u64) {
100		self.global.raw_put(b"version", new_version);
101	}
102
103	pub async fn database_version(&self) -> u64 {
104		self.global
105			.get(b"version")
106			.await
107			.deserialized()
108			.unwrap_or(0)
109	}
110}