tuwunel_core/utils/
two_phase_counter.rs1use std::{
4 collections::VecDeque,
5 ops::{Deref, Range},
6 sync::{Arc, RwLock},
7};
8
9use crate::{Result, checked, is_equal_to};
10
11pub struct Counter<F: Fn(u64) -> Result + Send + Sync> {
26 inner: RwLock<State<F>>,
28}
29
30pub struct State<F: Fn(u64) -> Result + Send + Sync> {
32 dispatched: u64,
35
36 commit: F,
39
40 pending: VecDeque<u64>,
44
45 release: F,
48}
49
50#[clippy::has_significant_drop]
51pub struct Permit<F: Fn(u64) -> Result + Send + Sync> {
52 state: Arc<Counter<F>>,
54
55 retired: u64,
58
59 id: u64,
61}
62
63impl<F: Fn(u64) -> Result + Send + Sync> Counter<F> {
64 pub fn new(init: u64, commit: F, release: F) -> Arc<Self> {
68 Arc::new(Self {
69 inner: State::new(init, commit, release).into(),
70 })
71 }
72
73 pub fn next(self: &Arc<Self>) -> Result<Permit<F>> {
75 let (retired, id) = self.inner.write()?.dispatch()?;
76
77 Ok(Permit::<F> { state: self.clone(), retired, id })
78 }
79
80 #[inline]
82 pub fn range(&self) -> Range<u64> {
83 let inner = self.inner.read().expect("locked for reading");
84
85 Range {
86 start: inner.retired(),
87 end: inner.dispatched,
88 }
89 }
90
91 #[inline]
94 pub fn current(&self) -> u64 {
95 self.inner
96 .read()
97 .expect("locked for reading")
98 .retired()
99 }
100
101 #[inline]
104 pub fn dispatched(&self) -> u64 {
105 self.inner
106 .read()
107 .expect("locked for reading")
108 .dispatched
109 }
110}
111
112impl<F: Fn(u64) -> Result + Send + Sync> State<F> {
113 fn new(dispatched: u64, commit: F, release: F) -> Self {
116 Self {
117 dispatched,
118 commit,
119 pending: VecDeque::new(),
120 release,
121 }
122 }
123
124 fn dispatch(&mut self) -> Result<(u64, u64)> {
127 let prev = self.dispatched;
128 let retired = self.retired();
129 let dispatched = checked!(prev + 1)?;
130 debug_assert!(
131 !self.check_pending(dispatched),
132 "sequence number cannot already be pending",
133 );
134
135 (self.commit)(dispatched)?;
136 self.dispatched = dispatched;
137 self.pending.push_back(self.dispatched);
138 Ok((retired, self.dispatched))
139 }
140
141 fn retire(&mut self, id: u64) {
143 debug_assert!(self.check_pending(id), "sequence number must be currently pending");
144
145 let index = self
146 .pending_index(id)
147 .expect("sequence number must be found as pending");
148
149 let removed = self
150 .pending
151 .remove(index)
152 .expect("sequence number at index must be removed");
153
154 debug_assert_eq!(removed, id, "sequence number removed must match id");
155
156 if index != 0 {
158 return;
159 }
160
161 let release = if self.pending.is_empty() { self.dispatched } else { id };
163
164 debug_assert!(release >= id, "sequence number released must not be less than id");
165
166 (self.release)(release).expect("release callback should not error");
167 }
168
169 fn retired(&self) -> u64 {
173 debug_assert!(
174 self.pending.iter().is_sorted(),
175 "Pending values should be naturally sorted"
176 );
177
178 self.pending
179 .front()
180 .map(|val| val.saturating_sub(1))
181 .unwrap_or(self.dispatched)
182 }
183
184 fn pending_index(&self, id: u64) -> Option<usize> {
186 debug_assert!(
187 self.pending.iter().is_sorted(),
188 "Pending values should be naturally sorted"
189 );
190
191 self.pending.binary_search(&id).ok()
192 }
193
194 fn check_pending(&self, id: u64) -> bool { self.pending.iter().any(is_equal_to!(&id)) }
197}
198
199impl<F: Fn(u64) -> Result + Send + Sync> Permit<F> {
200 #[inline]
203 #[must_use]
204 pub fn retired(&self) -> &u64 { &self.retired }
205
206 #[inline]
208 #[must_use]
209 pub fn id(&self) -> &u64 { &self.id }
210}
211
212impl<F: Fn(u64) -> Result + Send + Sync> Deref for Permit<F> {
213 type Target = u64;
214
215 #[inline]
216 fn deref(&self) -> &Self::Target { self.id() }
217}
218
219impl<F: Fn(u64) -> Result + Send + Sync> Drop for Permit<F> {
220 fn drop(&mut self) {
221 self.state
222 .inner
223 .write()
224 .expect("locked for writing")
225 .retire(self.id);
226 }
227}