1mod configure;
2
3use std::{
4 mem::take,
5 sync::{
6 Arc, Mutex,
7 atomic::{AtomicUsize, Ordering},
8 },
9 thread,
10 thread::JoinHandle,
11};
12
13use async_channel::{QueueStrategy, Receiver, RecvError, Sender};
14use futures::{TryFutureExt, channel::oneshot};
15use oneshot::Sender as ResultSender;
16use rocksdb::Direction;
17use tuwunel_core::{
18 Error, Result, Server, debug, err, error, implement,
19 result::DebugInspect,
20 smallvec::SmallVec,
21 trace,
22 utils::sys::compute::{get_affinity, set_affinity},
23};
24
25use self::configure::configure;
26use crate::{Handle, Map, keyval::KeyBuf, stream};
27
28pub(crate) struct Pool {
32 server: Arc<Server>,
33 queues: Vec<Sender<Cmd>>,
34 workers: Mutex<Vec<JoinHandle<()>>>,
35 topology: Vec<usize>,
36 busy: AtomicUsize,
37 queued_max: AtomicUsize,
38}
39
/// A unit of work submitted to the pool's worker threads.
pub(crate) enum Cmd {
    /// Lookup of one or more keys; see [`Get`].
    Get(Get),
    /// Initialization of an iteration state; see [`Seek`].
    Iter(Seek),
}
45
/// Lookup request for one or more keys of a single map.
pub(crate) struct Get {
    /// Map to query.
    pub(crate) map: Arc<Map>,
    /// Keys to look up. Exactly one key is served by the single-get handler,
    /// any other count by the batch handler.
    pub(crate) key: BatchQuery<'static>,
    /// One-shot channel for the results. Installed by `Pool::execute_get` and
    /// taken by the worker which handles the command.
    pub(crate) res: Option<ResultSender<BatchResult<'static>>>,
}
52
/// Request to initialize an iteration state on a worker thread.
pub(crate) struct Seek {
    /// Map associated with the seek; within this file it only labels the
    /// tracing span of the handler.
    pub(crate) map: Arc<Map>,
    /// Iteration state to initialize; the initialized state is sent back
    /// through `res`.
    pub(crate) state: stream::State<'static>,
    /// Selects between `init_fwd` and `init_rev` on the state.
    pub(crate) dir: Direction,
    /// Optional key handed to the init call as the starting position.
    pub(crate) key: Option<KeyBuf>,
    /// One-shot channel for the initialized state. Installed by
    /// `Pool::execute_iter` and taken by the worker which handles the command.
    pub(crate) res: Option<ResultSender<stream::State<'static>>>,
}
63
/// Keys of a [`Get`] request. The lifetime parameter is not used by the
/// aliased type.
pub(crate) type BatchQuery<'a> = SmallVec<[KeyBuf; BATCH_INLINE]>;
/// Results of a [`Get`] request, one [`ResultHandle`] per element.
pub(crate) type BatchResult<'a> = SmallVec<[ResultHandle<'a>; BATCH_INLINE]>;
/// Outcome of a single key lookup.
pub(crate) type ResultHandle<'a> = Result<Handle<'a>>;
67
// Bounds on the number of workers. Not referenced in this part of the file.
// NOTE(review): presumably (min, max) and consumed by the `configure`
// submodule -- confirm.
const WORKER_LIMIT: (usize, usize) = (1, 4096);
// Bounds on queue capacity; element `.0` is enforced as the minimum channel
// capacity in `Pool::new`. NOTE(review): `.1` is presumably the maximum and
// applied elsewhere -- confirm.
const QUEUE_LIMIT: (usize, usize) = (1, 1024);
// Inline capacity of the `BatchQuery` / `BatchResult` small-vectors.
const BATCH_INLINE: usize = 1;

// Stack size requested for each worker thread, in bytes.
const WORKER_STACK_SIZE: usize = 1_048_576;
// Thread name given to every worker thread.
const WORKER_NAME: &str = "tuwunel:db";
74
75#[implement(Pool)]
76pub(crate) fn new(server: &Arc<Server>) -> Result<Arc<Self>> {
77 const CHAN_SCHED: (QueueStrategy, QueueStrategy) = (QueueStrategy::Fifo, QueueStrategy::Lifo);
78
79 let (topology, workers, queues) = configure(server);
80
81 let (senders, receivers): (Vec<_>, Vec<_>) = queues
82 .into_iter()
83 .map(|cap| cap.max(QUEUE_LIMIT.0))
84 .map(|cap| async_channel::bounded_with_queue_strategy(cap, CHAN_SCHED))
85 .unzip();
86
87 let pool = Arc::new(Self {
88 server: server.clone(),
89 queues: senders,
90 workers: Vec::new().into(),
91 topology,
92 busy: AtomicUsize::default(),
93 queued_max: AtomicUsize::default(),
94 });
95
96 for (chan_id, &count) in workers.iter().enumerate() {
97 pool.spawn_group(&receivers, chan_id, count)?;
98 }
99
100 Ok(pool)
101}
102
103impl Drop for Pool {
104 fn drop(&mut self) {
105 self.close();
106
107 debug_assert!(
108 self.queues.iter().all(Sender::is_empty),
109 "channel must should not have requests queued on drop"
110 );
111 debug_assert!(
112 self.queues.iter().all(Sender::is_closed),
113 "channel should be closed on drop"
114 );
115 }
116}
117
118#[implement(Pool)]
119#[tracing::instrument(skip_all)]
120pub(crate) fn close(&self) {
121 let workers = take(&mut *self.workers.lock().expect("locked"));
122
123 let senders = self
124 .queues
125 .iter()
126 .map(Sender::sender_count)
127 .sum::<usize>();
128
129 let receivers = self
130 .queues
131 .iter()
132 .map(Sender::receiver_count)
133 .sum::<usize>();
134
135 for queue in &self.queues {
136 queue.close();
137 }
138
139 if workers.is_empty() {
140 return;
141 }
142
143 debug!(
144 senders,
145 receivers,
146 queues = self.queues.len(),
147 workers = workers.len(),
148 "Closing pool. Waiting for workers to join..."
149 );
150
151 workers
152 .into_iter()
153 .map(JoinHandle::join)
154 .map(|result| result.map_err(Error::from_panic))
155 .enumerate()
156 .for_each(|(id, result)| match result {
157 | Ok(()) => trace!(?id, "worker joined"),
158 | Err(error) => error!(?id, "worker joined with error: {error}"),
159 });
160}
161
162#[implement(Pool)]
163fn spawn_group(self: &Arc<Self>, recv: &[Receiver<Cmd>], chan_id: usize, count: usize) -> Result {
164 let mut workers = self.workers.lock().expect("locked");
165 for _ in 0..count {
166 self.clone()
167 .spawn_one(&mut workers, recv, chan_id)?;
168 }
169
170 Ok(())
171}
172
173#[implement(Pool)]
174#[tracing::instrument(
175 name = "spawn",
176 level = "trace",
177 skip_all,
178 fields(id = %workers.len())
179)]
180fn spawn_one(
181 self: Arc<Self>,
182 workers: &mut Vec<JoinHandle<()>>,
183 recv: &[Receiver<Cmd>],
184 chan_id: usize,
185) -> Result {
186 debug_assert!(!self.queues.is_empty(), "Must have at least one queue");
187 debug_assert!(!recv.is_empty(), "Must have at least one receiver");
188
189 let id = workers.len();
190 let recv = recv[chan_id].clone();
191
192 let handle = thread::Builder::new()
193 .name(WORKER_NAME.into())
194 .stack_size(WORKER_STACK_SIZE)
195 .spawn(move || self.worker(id, chan_id, &recv))?;
196
197 workers.push(handle);
198
199 Ok(())
200}
201
202#[implement(Pool)]
203#[tracing::instrument(level = "trace", name = "get", skip(self, cmd))]
204pub(crate) async fn execute_get(self: &Arc<Self>, mut cmd: Get) -> Result<BatchResult<'_>> {
205 let (send, recv) = oneshot::channel();
206 _ = cmd.res.insert(send);
207
208 let queue = self.select_queue();
209 self.execute(queue, Cmd::Get(cmd))
210 .and_then(move |()| {
211 recv.map_ok(into_recv_get)
212 .map_err(|e| err!(error!("recv failed {e:?}")))
213 })
214 .await
215}
216
217#[implement(Pool)]
218#[tracing::instrument(level = "trace", name = "iter", skip(self, cmd))]
219pub(crate) async fn execute_iter(self: &Arc<Self>, mut cmd: Seek) -> Result<stream::State<'_>> {
220 let (send, recv) = oneshot::channel();
221 _ = cmd.res.insert(send);
222
223 let queue = self.select_queue();
224 self.execute(queue, Cmd::Iter(cmd))
225 .and_then(|()| {
226 recv.map_ok(into_recv_seek)
227 .map_err(|e| err!(error!("recv failed {e:?}")))
228 })
229 .await
230}
231
232#[implement(Pool)]
233fn select_queue(&self) -> &Sender<Cmd> {
234 let core_id = get_affinity()
235 .next()
236 .expect("Affinity mask should be available.");
237
238 let chan_id = self.topology[core_id];
239
240 self.queues
241 .get(chan_id)
242 .unwrap_or_else(|| &self.queues[0])
243}
244
245#[implement(Pool)]
246#[tracing::instrument(
247 level = "trace",
248 name = "execute",
249 skip(self, cmd),
250 fields(
251 task = ?tokio::task::try_id(),
252 receivers = queue.receiver_count(),
253 queued = queue.len(),
254 queued_max = self.queued_max.load(Ordering::Relaxed),
255 ),
256)]
257async fn execute(&self, queue: &Sender<Cmd>, cmd: Cmd) -> Result {
258 if cfg!(debug_assertions) {
259 self.queued_max
260 .fetch_max(queue.len(), Ordering::Relaxed);
261 }
262
263 queue
264 .send(cmd)
265 .await
266 .map_err(|e| err!(error!("send failed {e:?}")))
267}
268
/// Entry point executed on each worker thread: performs the one-time setup,
/// then serves commands from `recv` until the receive loop ends.
///
/// The function runs in a span without a parent which carries the `id`,
/// `chan_id` and `thread_id` fields.
#[implement(Pool)]
#[tracing::instrument(
    parent = None,
    level = "debug",
    skip_all,
    fields(
        id,
        chan_id,
        thread_id = ?thread::current().id(),
    ),
)]
fn worker(self: Arc<Self>, id: usize, chan_id: usize, recv: &Receiver<Cmd>) {
    self.worker_init(id, chan_id);
    self.worker_loop(recv);
}
284
285#[implement(Pool)]
286fn worker_init(&self, id: usize, chan_id: usize) {
287 let affinity = self
288 .topology
289 .iter()
290 .enumerate()
291 .filter(|_| self.server.config.db_pool_affinity)
292 .filter_map(|(core_id, &queue_id)| (chan_id == queue_id).then_some(core_id));
293
294 set_affinity(affinity.clone());
296
297 trace!(
298 ?id,
299 ?chan_id,
300 affinity = ?affinity.collect::<Vec<_>>(),
301 "worker ready"
302 );
303}
304
305#[implement(Pool)]
306fn worker_loop(self: &Arc<Self>, recv: &Receiver<Cmd>) {
307 self.busy.fetch_add(1, Ordering::Relaxed);
309
310 while let Ok(cmd) = self.worker_wait(recv) {
311 worker_handle(cmd);
312 }
313}
314
/// Blocks the calling worker thread until a command is received from `recv`,
/// or returns the receive error (which ends the worker loop).
///
/// Bookkeeping: `busy` is decremented as a side effect of evaluating the
/// span's `busy` field and incremented again inside the `debug_inspect`
/// callback applied to the result of `recv_blocking`.
/// NOTE(review): both halves look conditional (span field evaluation depends
/// on the tracing configuration; `debug_inspect` presumably depends on the
/// build profile) -- confirm the counter stays balanced in every
/// configuration. The decrement also uses `AcqRel` while every other access
/// to `busy` in this file is `Relaxed`.
#[implement(Pool)]
#[tracing::instrument(
    name = "wait",
    level = "trace",
    skip_all,
    fields(
        receivers = recv.receiver_count(),
        queued = recv.len(),
        busy = self.busy.fetch_sub(1, Ordering::AcqRel) - 1,
    ),
)]
fn worker_wait(self: &Arc<Self>, recv: &Receiver<Cmd>) -> Result<Cmd, RecvError> {
    recv.recv_blocking().debug_inspect(|_| {
        self.busy.fetch_add(1, Ordering::Relaxed);
    })
}
331
332fn worker_handle(cmd: Cmd) {
333 match cmd {
334 | Cmd::Get(cmd) if cmd.key.len() == 1 => handle_get(cmd),
335 | Cmd::Get(cmd) => handle_batch(cmd),
336 | Cmd::Iter(cmd) => handle_iter(cmd),
337 }
338}
339
340#[tracing::instrument(
341 name = "iter",
342 level = "trace",
343 skip_all,
344 fields(%cmd.map),
345)]
346fn handle_iter(mut cmd: Seek) {
347 let chan = cmd.res.take().expect("missing result channel");
348
349 if chan.is_canceled() {
350 return;
351 }
352
353 let from = cmd.key.as_deref();
354
355 let result = match cmd.dir {
356 | Direction::Forward => cmd.state.init_fwd(from),
357 | Direction::Reverse => cmd.state.init_rev(from),
358 };
359
360 let chan_result = chan.send(into_send_seek(result));
361
362 let _chan_sent = chan_result.is_ok();
363}
364
365#[tracing::instrument(
366 name = "batch",
367 level = "trace",
368 skip_all,
369 fields(
370 %cmd.map,
371 keys = %cmd.key.len(),
372 ),
373)]
374fn handle_batch(mut cmd: Get) {
375 debug_assert!(cmd.key.len() > 1, "should have more than one key");
376 debug_assert!(!cmd.key.iter().any(SmallVec::is_empty), "querying for empty key");
377
378 let chan = cmd.res.take().expect("missing result channel");
379
380 if chan.is_canceled() {
381 return;
382 }
383
384 let keys = cmd.key.iter();
385
386 let result: SmallVec<_> = cmd.map.get_batch_blocking(keys).collect();
387
388 let chan_result = chan.send(into_send_get(result));
389
390 let _chan_sent = chan_result.is_ok();
391}
392
393#[tracing::instrument(
394 name = "get",
395 level = "trace",
396 skip_all,
397 fields(%cmd.map),
398)]
399fn handle_get(mut cmd: Get) {
400 debug_assert!(!cmd.key[0].is_empty(), "querying for empty key");
401
402 let chan = cmd.res.take().expect("missing result channel");
404
405 if chan.is_canceled() {
408 return;
409 }
410
411 let result = cmd.map.get_blocking(&cmd.key[0]);
415
416 let chan_result = chan.send(into_send_get([result].into()));
418
419 let _chan_sent = chan_result.is_ok();
421}
422
423fn into_send_get(result: BatchResult<'_>) -> BatchResult<'static> {
424 unsafe { std::mem::transmute(result) }
429}
430
431fn into_recv_get<'a>(result: BatchResult<'static>) -> BatchResult<'a> {
432 unsafe { std::mem::transmute(result) }
434}
435
436pub(crate) fn into_send_seek(result: stream::State<'_>) -> stream::State<'static> {
437 unsafe { std::mem::transmute(result) }
439}
440
441fn into_recv_seek<'a>(result: stream::State<'static>) -> stream::State<'a> {
442 unsafe { std::mem::transmute(result) }
444}