Skip to main content

tuwunel_database/
pool.rs

1mod configure;
2
3use std::{
4	mem::take,
5	sync::{
6		Arc, Mutex,
7		atomic::{AtomicUsize, Ordering},
8	},
9	thread,
10	thread::JoinHandle,
11};
12
13use async_channel::{QueueStrategy, Receiver, RecvError, Sender};
14use futures::{TryFutureExt, channel::oneshot};
15use oneshot::Sender as ResultSender;
16use rocksdb::Direction;
17use tuwunel_core::{
18	Error, Result, Server, debug, err, error, implement,
19	result::DebugInspect,
20	smallvec::SmallVec,
21	trace,
22	utils::sys::compute::{get_affinity, set_affinity},
23};
24
25use self::configure::configure;
26use crate::{Handle, Map, keyval::KeyBuf, stream};
27
28/// Frontend thread-pool. Operating system threads are used to make database
29/// requests which are not cached. These thread-blocking requests are offloaded
30/// from the tokio async workers and executed on this threadpool.
31pub(crate) struct Pool {
32	server: Arc<Server>,
33	queues: Vec<Sender<Cmd>>,
34	workers: Mutex<Vec<JoinHandle<()>>>,
35	topology: Vec<usize>,
36	busy: AtomicUsize,
37	queued_max: AtomicUsize,
38}
39
40/// Operations which can be submitted to the pool.
41pub(crate) enum Cmd {
42	Get(Get),
43	Iter(Seek),
44}
45
46/// Multi-point-query
47pub(crate) struct Get {
48	pub(crate) map: Arc<Map>,
49	pub(crate) key: BatchQuery<'static>,
50	pub(crate) res: Option<ResultSender<BatchResult<'static>>>,
51}
52
53/// Iterator-seek.
54/// Note: only initial seek is supported at this time on the assumption rocksdb
55/// prefetching prevents mid-iteration polls from blocking on I/O.
56pub(crate) struct Seek {
57	pub(crate) map: Arc<Map>,
58	pub(crate) state: stream::State<'static>,
59	pub(crate) dir: Direction,
60	pub(crate) key: Option<KeyBuf>,
61	pub(crate) res: Option<ResultSender<stream::State<'static>>>,
62}
63
64pub(crate) type BatchQuery<'a> = SmallVec<[KeyBuf; BATCH_INLINE]>;
65pub(crate) type BatchResult<'a> = SmallVec<[ResultHandle<'a>; BATCH_INLINE]>;
66pub(crate) type ResultHandle<'a> = Result<Handle<'a>>;
67
/// Bounds on the number of workers; presumably `(min, max)` -- confirm at the
/// use site, which is not in this file section.
const WORKER_LIMIT: (usize, usize) = (1, 4096);

/// Bounds on queue capacity. The first element is applied as the minimum
/// capacity when the pool creates its channels.
const QUEUE_LIMIT: (usize, usize) = (1, 1024);

/// Inline capacity of the `SmallVec` types carrying batch keys and results.
const BATCH_INLINE: usize = 1;

/// Stack size requested for each worker thread, in bytes (1 MiB).
const WORKER_STACK_SIZE: usize = 1024 * 1024;

/// Name given to every worker thread.
const WORKER_NAME: &str = "tuwunel:db";
74
75#[implement(Pool)]
76pub(crate) fn new(server: &Arc<Server>) -> Result<Arc<Self>> {
77	const CHAN_SCHED: (QueueStrategy, QueueStrategy) = (QueueStrategy::Fifo, QueueStrategy::Lifo);
78
79	let (topology, workers, queues) = configure(server);
80
81	let (senders, receivers): (Vec<_>, Vec<_>) = queues
82		.into_iter()
83		.map(|cap| cap.max(QUEUE_LIMIT.0))
84		.map(|cap| async_channel::bounded_with_queue_strategy(cap, CHAN_SCHED))
85		.unzip();
86
87	let pool = Arc::new(Self {
88		server: server.clone(),
89		queues: senders,
90		workers: Vec::new().into(),
91		topology,
92		busy: AtomicUsize::default(),
93		queued_max: AtomicUsize::default(),
94	});
95
96	for (chan_id, &count) in workers.iter().enumerate() {
97		pool.spawn_group(&receivers, chan_id, count)?;
98	}
99
100	Ok(pool)
101}
102
103impl Drop for Pool {
104	fn drop(&mut self) {
105		self.close();
106
107		debug_assert!(
108			self.queues.iter().all(Sender::is_empty),
109			"channel must should not have requests queued on drop"
110		);
111		debug_assert!(
112			self.queues.iter().all(Sender::is_closed),
113			"channel should be closed on drop"
114		);
115	}
116}
117
118#[implement(Pool)]
119#[tracing::instrument(skip_all)]
120pub(crate) fn close(&self) {
121	let workers = take(&mut *self.workers.lock().expect("locked"));
122
123	let senders = self
124		.queues
125		.iter()
126		.map(Sender::sender_count)
127		.sum::<usize>();
128
129	let receivers = self
130		.queues
131		.iter()
132		.map(Sender::receiver_count)
133		.sum::<usize>();
134
135	for queue in &self.queues {
136		queue.close();
137	}
138
139	if workers.is_empty() {
140		return;
141	}
142
143	debug!(
144		senders,
145		receivers,
146		queues = self.queues.len(),
147		workers = workers.len(),
148		"Closing pool. Waiting for workers to join..."
149	);
150
151	workers
152		.into_iter()
153		.map(JoinHandle::join)
154		.map(|result| result.map_err(Error::from_panic))
155		.enumerate()
156		.for_each(|(id, result)| match result {
157			| Ok(()) => trace!(?id, "worker joined"),
158			| Err(error) => error!(?id, "worker joined with error: {error}"),
159		});
160}
161
162#[implement(Pool)]
163fn spawn_group(self: &Arc<Self>, recv: &[Receiver<Cmd>], chan_id: usize, count: usize) -> Result {
164	let mut workers = self.workers.lock().expect("locked");
165	for _ in 0..count {
166		self.clone()
167			.spawn_one(&mut workers, recv, chan_id)?;
168	}
169
170	Ok(())
171}
172
173#[implement(Pool)]
174#[tracing::instrument(
175	name = "spawn",
176	level = "trace",
177	skip_all,
178	fields(id = %workers.len())
179)]
180fn spawn_one(
181	self: Arc<Self>,
182	workers: &mut Vec<JoinHandle<()>>,
183	recv: &[Receiver<Cmd>],
184	chan_id: usize,
185) -> Result {
186	debug_assert!(!self.queues.is_empty(), "Must have at least one queue");
187	debug_assert!(!recv.is_empty(), "Must have at least one receiver");
188
189	let id = workers.len();
190	let recv = recv[chan_id].clone();
191
192	let handle = thread::Builder::new()
193		.name(WORKER_NAME.into())
194		.stack_size(WORKER_STACK_SIZE)
195		.spawn(move || self.worker(id, chan_id, &recv))?;
196
197	workers.push(handle);
198
199	Ok(())
200}
201
202#[implement(Pool)]
203#[tracing::instrument(level = "trace", name = "get", skip(self, cmd))]
204pub(crate) async fn execute_get(self: &Arc<Self>, mut cmd: Get) -> Result<BatchResult<'_>> {
205	let (send, recv) = oneshot::channel();
206	_ = cmd.res.insert(send);
207
208	let queue = self.select_queue();
209	self.execute(queue, Cmd::Get(cmd))
210		.and_then(move |()| {
211			recv.map_ok(into_recv_get)
212				.map_err(|e| err!(error!("recv failed {e:?}")))
213		})
214		.await
215}
216
217#[implement(Pool)]
218#[tracing::instrument(level = "trace", name = "iter", skip(self, cmd))]
219pub(crate) async fn execute_iter(self: &Arc<Self>, mut cmd: Seek) -> Result<stream::State<'_>> {
220	let (send, recv) = oneshot::channel();
221	_ = cmd.res.insert(send);
222
223	let queue = self.select_queue();
224	self.execute(queue, Cmd::Iter(cmd))
225		.and_then(|()| {
226			recv.map_ok(into_recv_seek)
227				.map_err(|e| err!(error!("recv failed {e:?}")))
228		})
229		.await
230}
231
232#[implement(Pool)]
233fn select_queue(&self) -> &Sender<Cmd> {
234	let core_id = get_affinity()
235		.next()
236		.expect("Affinity mask should be available.");
237
238	let chan_id = self.topology[core_id];
239
240	self.queues
241		.get(chan_id)
242		.unwrap_or_else(|| &self.queues[0])
243}
244
245#[implement(Pool)]
246#[tracing::instrument(
247	level = "trace",
248	name = "execute",
249	skip(self, cmd),
250	fields(
251		task = ?tokio::task::try_id(),
252		receivers = queue.receiver_count(),
253		queued = queue.len(),
254		queued_max = self.queued_max.load(Ordering::Relaxed),
255	),
256)]
257async fn execute(&self, queue: &Sender<Cmd>, cmd: Cmd) -> Result {
258	if cfg!(debug_assertions) {
259		self.queued_max
260			.fetch_max(queue.len(), Ordering::Relaxed);
261	}
262
263	queue
264		.send(cmd)
265		.await
266		.map_err(|e| err!(error!("send failed {e:?}")))
267}
268
269#[implement(Pool)]
270#[tracing::instrument(
271	parent = None,
272	level = "debug",
273	skip_all,
274	fields(
275		id,
276		chan_id,
277		thread_id = ?thread::current().id(),
278	),
279)]
280fn worker(self: Arc<Self>, id: usize, chan_id: usize, recv: &Receiver<Cmd>) {
281	self.worker_init(id, chan_id);
282	self.worker_loop(recv);
283}
284
285#[implement(Pool)]
286fn worker_init(&self, id: usize, chan_id: usize) {
287	let affinity = self
288		.topology
289		.iter()
290		.enumerate()
291		.filter(|_| self.server.config.db_pool_affinity)
292		.filter_map(|(core_id, &queue_id)| (chan_id == queue_id).then_some(core_id));
293
294	// affinity is empty (no-op) if there's only one queue
295	set_affinity(affinity.clone());
296
297	trace!(
298		?id,
299		?chan_id,
300		affinity = ?affinity.collect::<Vec<_>>(),
301		"worker ready"
302	);
303}
304
305#[implement(Pool)]
306fn worker_loop(self: &Arc<Self>, recv: &Receiver<Cmd>) {
307	// initial +1 needed prior to entering wait
308	self.busy.fetch_add(1, Ordering::Relaxed);
309
310	while let Ok(cmd) = self.worker_wait(recv) {
311		worker_handle(cmd);
312	}
313}
314
315#[implement(Pool)]
316#[tracing::instrument(
317	name = "wait",
318	level = "trace",
319	skip_all,
320	fields(
321		receivers = recv.receiver_count(),
322		queued = recv.len(),
323		busy = self.busy.fetch_sub(1, Ordering::AcqRel) - 1,
324	),
325)]
326fn worker_wait(self: &Arc<Self>, recv: &Receiver<Cmd>) -> Result<Cmd, RecvError> {
327	recv.recv_blocking().debug_inspect(|_| {
328		self.busy.fetch_add(1, Ordering::Relaxed);
329	})
330}
331
332fn worker_handle(cmd: Cmd) {
333	match cmd {
334		| Cmd::Get(cmd) if cmd.key.len() == 1 => handle_get(cmd),
335		| Cmd::Get(cmd) => handle_batch(cmd),
336		| Cmd::Iter(cmd) => handle_iter(cmd),
337	}
338}
339
340#[tracing::instrument(
341	name = "iter",
342	level = "trace",
343	skip_all,
344	fields(%cmd.map),
345)]
346fn handle_iter(mut cmd: Seek) {
347	let chan = cmd.res.take().expect("missing result channel");
348
349	if chan.is_canceled() {
350		return;
351	}
352
353	let from = cmd.key.as_deref();
354
355	let result = match cmd.dir {
356		| Direction::Forward => cmd.state.init_fwd(from),
357		| Direction::Reverse => cmd.state.init_rev(from),
358	};
359
360	let chan_result = chan.send(into_send_seek(result));
361
362	let _chan_sent = chan_result.is_ok();
363}
364
365#[tracing::instrument(
366	name = "batch",
367	level = "trace",
368	skip_all,
369	fields(
370		%cmd.map,
371		keys = %cmd.key.len(),
372	),
373)]
374fn handle_batch(mut cmd: Get) {
375	debug_assert!(cmd.key.len() > 1, "should have more than one key");
376	debug_assert!(!cmd.key.iter().any(SmallVec::is_empty), "querying for empty key");
377
378	let chan = cmd.res.take().expect("missing result channel");
379
380	if chan.is_canceled() {
381		return;
382	}
383
384	let keys = cmd.key.iter();
385
386	let result: SmallVec<_> = cmd.map.get_batch_blocking(keys).collect();
387
388	let chan_result = chan.send(into_send_get(result));
389
390	let _chan_sent = chan_result.is_ok();
391}
392
393#[tracing::instrument(
394	name = "get",
395	level = "trace",
396	skip_all,
397	fields(%cmd.map),
398)]
399fn handle_get(mut cmd: Get) {
400	debug_assert!(!cmd.key[0].is_empty(), "querying for empty key");
401
402	// Obtain the result channel.
403	let chan = cmd.res.take().expect("missing result channel");
404
405	// It is worth checking if the future was dropped while the command was queued
406	// so we can bail without paying for any query.
407	if chan.is_canceled() {
408		return;
409	}
410
411	// Perform the actual database query. We reuse our database::Map interface but
412	// limited to the blocking calls, rather than creating another surface directly
413	// with rocksdb here.
414	let result = cmd.map.get_blocking(&cmd.key[0]);
415
416	// Send the result back to the submitter.
417	let chan_result = chan.send(into_send_get([result].into()));
418
419	// If the future was dropped during the query this will fail acceptably.
420	let _chan_sent = chan_result.is_ok();
421}
422
423fn into_send_get(result: BatchResult<'_>) -> BatchResult<'static> {
424	// SAFETY: Necessary to send the Handle (rust_rocksdb::PinnableSlice) through
425	// the channel. The lifetime on the handle is a device by rust-rocksdb to
426	// associate a database lifetime with its assets. The Handle must be dropped
427	// before the database is dropped.
428	unsafe { std::mem::transmute(result) }
429}
430
431fn into_recv_get<'a>(result: BatchResult<'static>) -> BatchResult<'a> {
432	// SAFETY: This is to receive the Handle from the channel.
433	unsafe { std::mem::transmute(result) }
434}
435
436pub(crate) fn into_send_seek(result: stream::State<'_>) -> stream::State<'static> {
437	// SAFETY: Necessary to send the State through the channel; see above.
438	unsafe { std::mem::transmute(result) }
439}
440
441fn into_recv_seek<'a>(result: stream::State<'static>) -> stream::State<'a> {
442	// SAFETY: This is to receive the State from the channel; see above.
443	unsafe { std::mem::transmute(result) }
444}