Skip to main content

tuwunel_core/utils/
two_phase_counter.rs

1//! Two-Phase Counter.
2
3use std::{
4	collections::VecDeque,
5	ops::{Deref, Range},
6	sync::{Arc, RwLock},
7};
8
9use crate::{Result, checked, is_equal_to};
10
11/// Two-Phase Counter.
12///
13/// This device solves the problem of a One-Phase Counter (or just a counter)
14/// which is incremented to provide unique sequence numbers (or index numbers)
15/// fundamental to server operation. For example, let's say a new Matrix Pdu
16/// is received: the counter is incremented and its value becomes the PduId
17/// used as a key for the Pdu value when writing to the database.
18///
19/// Problem: With a single counter shared by both writers and readers, pending
20/// writes might still be in-flight and not visible to readers after the writer
21/// incremented it. For example, client-sync sees the counter at a certain
22/// value, but that value has no Pdu found because its write has not been
23/// completed with global visibility. Client-sync will then move on to the next
24/// counter value having missed the data from the current one.
25pub struct Counter<F: Fn(u64) -> Result + Send + Sync> {
26	/// Self is intended to be `Arc<Counter>` with inner state mutable via Lock.
27	inner: RwLock<State<F>>,
28}
29
30/// Inner protected state for Two-Phase Counter.
31pub struct State<F: Fn(u64) -> Result + Send + Sync> {
32	/// Monotonic counter. The next sequence number is drawn by adding one to
33	/// this value. That number will be persisted and added to `pending`.
34	dispatched: u64,
35
36	/// Callback to persist the next sequence number drawn from `dispatched`.
37	/// This prevents pending numbers from being reused after server restart.
38	commit: F,
39
40	/// List of pending sequence numbers. One less than the minimum value in
41	/// this list is the "retirement" sequence number where all writes have
42	/// completed and all reads are globally visible.
43	pending: VecDeque<u64>,
44
45	/// Callback to notify updates of the retirement value. This is likely
46	/// called from the destructor of a permit/guard; try not to panic.
47	release: F,
48}
49
50#[clippy::has_significant_drop]
51pub struct Permit<F: Fn(u64) -> Result + Send + Sync> {
52	/// Link back to the shared-state.
53	state: Arc<Counter<F>>,
54
55	/// The retirement value computed as a courtesy when this permit was
56	/// created.
57	retired: u64,
58
59	/// Sequence number of this permit.
60	id: u64,
61}
62
63impl<F: Fn(u64) -> Result + Send + Sync> Counter<F> {
64	/// Construct a new Two-Phase counter state. The value of `init` is
65	/// considered retired, and the next sequence number dispatched will be one
66	/// greater.
67	pub fn new(init: u64, commit: F, release: F) -> Arc<Self> {
68		Arc::new(Self {
69			inner: State::new(init, commit, release).into(),
70		})
71	}
72
73	/// Obtain a sequence number to conduct write operations for the scope.
74	pub fn next(self: &Arc<Self>) -> Result<Permit<F>> {
75		let (retired, id) = self.inner.write()?.dispatch()?;
76
77		Ok(Permit::<F> { state: self.clone(), retired, id })
78	}
79
80	/// Load the current and dispatched values simultaneously
81	#[inline]
82	pub fn range(&self) -> Range<u64> {
83		let inner = self.inner.read().expect("locked for reading");
84
85		Range {
86			start: inner.retired(),
87			end: inner.dispatched,
88		}
89	}
90
91	/// Load the highest sequence number safe for reading, also known as the
92	/// retirement value with writes "globally visible."
93	#[inline]
94	pub fn current(&self) -> u64 {
95		self.inner
96			.read()
97			.expect("locked for reading")
98			.retired()
99	}
100
101	/// Load the highest sequence number (dispatched); may still be pending or
102	/// may be retired.
103	#[inline]
104	pub fn dispatched(&self) -> u64 {
105		self.inner
106			.read()
107			.expect("locked for reading")
108			.dispatched
109	}
110}
111
112impl<F: Fn(u64) -> Result + Send + Sync> State<F> {
113	/// Create new state, starting from `init`. The next sequence number
114	/// dispatched will be one greater than `init`.
115	fn new(dispatched: u64, commit: F, release: F) -> Self {
116		Self {
117			dispatched,
118			commit,
119			pending: VecDeque::new(),
120			release,
121		}
122	}
123
124	/// Dispatch the next sequence number as pending. The retired value is
125	/// calculated as a courtesy while the state is under lock.
126	fn dispatch(&mut self) -> Result<(u64, u64)> {
127		let prev = self.dispatched;
128		let retired = self.retired();
129		let dispatched = checked!(prev + 1)?;
130		debug_assert!(
131			!self.check_pending(dispatched),
132			"sequence number cannot already be pending",
133		);
134
135		(self.commit)(dispatched)?;
136		self.dispatched = dispatched;
137		self.pending.push_back(self.dispatched);
138		Ok((retired, self.dispatched))
139	}
140
141	/// Retire the sequence number `id`.
142	fn retire(&mut self, id: u64) {
143		debug_assert!(self.check_pending(id), "sequence number must be currently pending");
144
145		let index = self
146			.pending_index(id)
147			.expect("sequence number must be found as pending");
148
149		let removed = self
150			.pending
151			.remove(index)
152			.expect("sequence number at index must be removed");
153
154		debug_assert_eq!(removed, id, "sequence number removed must match id");
155
156		// release only occurs when the oldest value retires
157		if index != 0 {
158			return;
159		}
160
161		// release occurs for the maximum retired value
162		let release = if self.pending.is_empty() { self.dispatched } else { id };
163
164		debug_assert!(release >= id, "sequence number released must not be less than id");
165
166		(self.release)(release).expect("release callback should not error");
167	}
168
169	/// Calculate the retired sequence number, one less than the lowest pending
170	/// sequence number. If nothing is pending the value of `dispatched` has
171	/// been previously retired and is returned.
172	fn retired(&self) -> u64 {
173		debug_assert!(
174			self.pending.iter().is_sorted(),
175			"Pending values should be naturally sorted"
176		);
177
178		self.pending
179			.front()
180			.map(|val| val.saturating_sub(1))
181			.unwrap_or(self.dispatched)
182	}
183
184	/// Get the position of `id` in the pending list.
185	fn pending_index(&self, id: u64) -> Option<usize> {
186		debug_assert!(
187			self.pending.iter().is_sorted(),
188			"Pending values should be naturally sorted"
189		);
190
191		self.pending.binary_search(&id).ok()
192	}
193
194	/// Check for `id` in the pending list sequentially (for debug and assertion
195	/// purposes only)
196	fn check_pending(&self, id: u64) -> bool { self.pending.iter().any(is_equal_to!(&id)) }
197}
198
199impl<F: Fn(u64) -> Result + Send + Sync> Permit<F> {
200	/// Access the retired sequence number sampled at this permit's creation.
201	/// This may be outdated prior to access. Obtained as a courtesy under lock.
202	#[inline]
203	#[must_use]
204	pub fn retired(&self) -> &u64 { &self.retired }
205
206	/// Access the sequence number obtained by this permit; a unique value
207	#[inline]
208	#[must_use]
209	pub fn id(&self) -> &u64 { &self.id }
210}
211
212impl<F: Fn(u64) -> Result + Send + Sync> Deref for Permit<F> {
213	type Target = u64;
214
215	#[inline]
216	fn deref(&self) -> &Self::Target { self.id() }
217}
218
219impl<F: Fn(u64) -> Result + Send + Sync> Drop for Permit<F> {
220	fn drop(&mut self) {
221		self.state
222			.inner
223			.write()
224			.expect("locked for writing")
225			.retire(self.id);
226	}
227}