Skip to main content

tuwunel_service/
manager.rs

1use std::{
2	panic::AssertUnwindSafe,
3	sync::{
4		Arc,
5		atomic::{AtomicUsize, Ordering},
6	},
7	time::Duration,
8};
9
10use futures::{FutureExt, TryFutureExt};
11use tokio::{
12	sync::{Mutex, MutexGuard},
13	task::{JoinHandle, JoinSet, yield_now},
14	time::sleep,
15};
16use tuwunel_core::{
17	Err, Error, Result, Server, debug, debug::INFO_SPAN_LEVEL, debug_warn, defer, error,
18	utils::time, warn,
19};
20
21use crate::{Services, service::Service};
22
/// Supervises the per-service worker tasks.
///
/// `start()` spawns one task per service into `workers`, plus a manager task
/// (`Manager::worker`) that joins those tasks as they finish and, while the
/// server is running, restarts workers whose error is a caught panic.
pub(crate) struct Manager {
	/// Handle of the spawned manager task; `None` before `start()` and after
	/// the handle has been taken.
	manager: Mutex<Option<JoinHandle<Result>>>,
	/// Set of spawned service worker tasks.
	workers: Mutex<Workers>,
	/// Number of worker frames currently executing (incremented on entry to
	/// the free `worker` function, decremented when it ends).
	active: AtomicUsize,
	server: Arc<Server>,
	services: Arc<Services>,
}
30
/// Set of spawned service worker tasks.
type Workers = JoinSet<WorkerResult>;
/// Output of one worker task: the service it ran and that service's result.
type WorkerResult = (Arc<dyn Service>, Result);
/// Guard over the locked worker set; handed to the result handlers so they can
/// respawn workers while the lock is held.
type WorkersLocked<'a> = MutexGuard<'a, Workers>;
34
/// Delay, in milliseconds, before restarting a service worker that panicked.
const RESTART_DELAY_MS: u64 = 2500;
36
37impl Manager {
38	pub(super) fn new(services: &Arc<Services>) -> Arc<Self> {
39		Arc::new(Self {
40			manager: Mutex::new(None),
41			workers: Mutex::new(JoinSet::new()),
42			active: 0.into(),
43			server: services.server.clone(),
44			services: services.clone(),
45		})
46	}
47
48	pub(super) async fn poll(&self) -> Result {
49		if let Some(manager) = &mut *self.manager.lock().await {
50			debug!("Polling service manager...");
51			return manager.await?;
52		}
53
54		Ok(())
55	}
56
57	#[tracing::instrument(
58		name = "manager",
59		level = INFO_SPAN_LEVEL,
60		skip_all,
61		fields(
62			active = self.active.load(Ordering::Relaxed),
63		),
64	)]
65	pub(super) async fn stop(&self) {
66		if let Some(manager) = self.manager.lock().await.take() {
67			debug!("Waiting for service manager...");
68			if let Err(e) = manager.await {
69				error!("Manager shutdown error: {e:?}");
70			}
71		}
72	}
73
74	#[tracing::instrument(name = "manager", level = "trace", skip_all)]
75	pub(super) async fn start(self: Arc<Self>) -> Result {
76		let mut workers = self.workers.lock().await;
77
78		debug!("Starting service manager...");
79		let self_ = self.clone();
80		_ = self.manager.lock().await.insert(
81			self.server
82				.runtime()
83				.spawn(async move { self_.worker().await }),
84		);
85
86		debug!("Starting service workers...");
87		for service in self.services.services() {
88			self.start_worker(&mut workers, &service)?;
89		}
90
91		yield_now().await;
92		debug_assert!(
93			self.manager.lock().await.is_some(),
94			"Service manager's task must have been installed."
95		);
96
97		debug!(
98			workers = workers.len(),
99			active = self.active.load(Ordering::Relaxed),
100			"Spawned service workers...",
101		);
102
103		Ok(())
104	}
105
106	#[tracing::instrument(
107		name = "manager",
108		level = INFO_SPAN_LEVEL,
109		skip_all,
110		ret,
111		err,
112	)]
113	async fn worker(self: &Arc<Self>) -> Result {
114		loop {
115			let mut workers = self.workers.lock().await;
116			tokio::select! {
117				result = workers.join_next() => match result {
118					Some(Ok(result)) => self.handle_result(&mut workers, result).await?,
119					Some(Err(error)) => self.handle_abort(&mut workers, &Error::from(error))?,
120					None => break,
121				}
122			}
123		}
124
125		debug!("Worker manager finished");
126		Ok(())
127	}
128
129	#[allow(clippy::unused_self)]
130	fn handle_abort(&self, _workers: &mut WorkersLocked<'_>, error: &Error) -> Result {
131		// not supported until service can be associated with abort
132		unimplemented!("unexpected worker task abort {error:?}");
133	}
134
135	async fn handle_result(
136		self: &Arc<Self>,
137		workers: &mut WorkersLocked<'_>,
138		result: WorkerResult,
139	) -> Result {
140		let (service, result) = result;
141		match result {
142			| Ok(()) => self.handle_finished(workers, &service),
143			| Err(error) => self.handle_error(workers, &service, error).await,
144		}
145	}
146
147	#[tracing::instrument(
148		name = "finished",
149		level = "trace",
150		skip_all,
151		fields(
152			service = ?service.name(),
153			active = self.active.load(Ordering::Acquire),
154		),
155	)]
156	fn handle_finished(
157		self: &Arc<Self>,
158		_workers: &mut WorkersLocked<'_>,
159		service: &Arc<dyn Service>,
160	) -> Result {
161		debug!(name = service.name(), "Service worker finished...");
162
163		Ok(())
164	}
165
166	#[tracing::instrument(
167		name = "error",
168		level = "error",
169		skip_all,
170		fields(
171			service = ?service.name(),
172			active = self.active.load(Ordering::Acquire),
173		),
174	)]
175	async fn handle_error(
176		self: &Arc<Self>,
177		workers: &mut WorkersLocked<'_>,
178		service: &Arc<dyn Service>,
179		error: Error,
180	) -> Result {
181		let name = service.name();
182		error!("service {name:?} aborted: {error}");
183
184		if !self.server.is_running() {
185			debug_warn!("service {name:?} error ignored on shutdown.");
186			return Ok(());
187		}
188
189		if !error.is_panic() {
190			return Err(error);
191		}
192
193		let delay = Duration::from_millis(RESTART_DELAY_MS);
194		warn!(
195			delay = ?time::pretty(delay),
196			"service {name:?} worker restarting after delay..."
197		);
198
199		sleep(delay).await;
200		self.start_worker(workers, service)
201	}
202
203	/// Start the worker in a task for the service.
204	fn start_worker(
205		self: &Arc<Self>,
206		workers: &mut WorkersLocked<'_>,
207		service: &Arc<dyn Service>,
208	) -> Result {
209		if !self.server.is_running() {
210			return Err!(
211				"Service {:?} worker not starting during server shutdown.",
212				service.name()
213			);
214		}
215
216		debug!(name = service.name(), "Service worker starting...");
217		workers.spawn_on(worker(service.clone(), self.clone()), self.server.runtime());
218
219		Ok(())
220	}
221}
222
223/// Base frame for service worker. This runs in a tokio::task. All errors and
224/// panics from the worker are caught and returned cleanly. The JoinHandle
225/// should never error with a panic, and if so it should propagate, but it may
226/// error with an Abort which the manager should handle along with results to
227/// determine if the worker should be restarted.
228#[tracing::instrument(
229	parent = None,
230	level = "trace",
231	skip_all,
232	fields(
233		service = ?service.name(),
234		active = mgr.active.load(Ordering::Relaxed),
235	),
236)]
237async fn worker(service: Arc<dyn Service>, mgr: Arc<Manager>) -> WorkerResult {
238	mgr.active.fetch_add(1, Ordering::Relaxed);
239	defer! {{
240		mgr.active.fetch_sub(1, Ordering::Release);
241	}};
242
243	let service_ = Arc::clone(&service);
244	let result = AssertUnwindSafe(service_.worker())
245		.catch_unwind()
246		.map_err(Error::from_panic);
247
248	let result = if service.unconstrained() {
249		tokio::task::unconstrained(result).await
250	} else {
251		result.await
252	};
253
254	// flattens JoinError for panic into worker's Error
255	(service, result.unwrap_or_else(Err))
256}