1use std::{
2 panic::AssertUnwindSafe,
3 sync::{
4 Arc,
5 atomic::{AtomicUsize, Ordering},
6 },
7 time::Duration,
8};
9
10use futures::{FutureExt, TryFutureExt};
11use tokio::{
12 sync::{Mutex, MutexGuard},
13 task::{JoinHandle, JoinSet, yield_now},
14 time::sleep,
15};
16use tuwunel_core::{
17 Err, Error, Result, Server, debug, debug::INFO_SPAN_LEVEL, debug_warn, defer, error,
18 utils::time, warn,
19};
20
21use crate::{Services, service::Service};
22
23pub(crate) struct Manager {
24 manager: Mutex<Option<JoinHandle<Result>>>,
25 workers: Mutex<Workers>,
26 active: AtomicUsize,
27 server: Arc<Server>,
28 services: Arc<Services>,
29}
30
31type Workers = JoinSet<WorkerResult>;
32type WorkerResult = (Arc<dyn Service>, Result);
33type WorkersLocked<'a> = MutexGuard<'a, Workers>;
34
35const RESTART_DELAY_MS: u64 = 2500;
36
37impl Manager {
38 pub(super) fn new(services: &Arc<Services>) -> Arc<Self> {
39 Arc::new(Self {
40 manager: Mutex::new(None),
41 workers: Mutex::new(JoinSet::new()),
42 active: 0.into(),
43 server: services.server.clone(),
44 services: services.clone(),
45 })
46 }
47
48 pub(super) async fn poll(&self) -> Result {
49 if let Some(manager) = &mut *self.manager.lock().await {
50 debug!("Polling service manager...");
51 return manager.await?;
52 }
53
54 Ok(())
55 }
56
57 #[tracing::instrument(
58 name = "manager",
59 level = INFO_SPAN_LEVEL,
60 skip_all,
61 fields(
62 active = self.active.load(Ordering::Relaxed),
63 ),
64 )]
65 pub(super) async fn stop(&self) {
66 if let Some(manager) = self.manager.lock().await.take() {
67 debug!("Waiting for service manager...");
68 if let Err(e) = manager.await {
69 error!("Manager shutdown error: {e:?}");
70 }
71 }
72 }
73
74 #[tracing::instrument(name = "manager", level = "trace", skip_all)]
75 pub(super) async fn start(self: Arc<Self>) -> Result {
76 let mut workers = self.workers.lock().await;
77
78 debug!("Starting service manager...");
79 let self_ = self.clone();
80 _ = self.manager.lock().await.insert(
81 self.server
82 .runtime()
83 .spawn(async move { self_.worker().await }),
84 );
85
86 debug!("Starting service workers...");
87 for service in self.services.services() {
88 self.start_worker(&mut workers, &service)?;
89 }
90
91 yield_now().await;
92 debug_assert!(
93 self.manager.lock().await.is_some(),
94 "Service manager's task must have been installed."
95 );
96
97 debug!(
98 workers = workers.len(),
99 active = self.active.load(Ordering::Relaxed),
100 "Spawned service workers...",
101 );
102
103 Ok(())
104 }
105
106 #[tracing::instrument(
107 name = "manager",
108 level = INFO_SPAN_LEVEL,
109 skip_all,
110 ret,
111 err,
112 )]
113 async fn worker(self: &Arc<Self>) -> Result {
114 loop {
115 let mut workers = self.workers.lock().await;
116 tokio::select! {
117 result = workers.join_next() => match result {
118 Some(Ok(result)) => self.handle_result(&mut workers, result).await?,
119 Some(Err(error)) => self.handle_abort(&mut workers, &Error::from(error))?,
120 None => break,
121 }
122 }
123 }
124
125 debug!("Worker manager finished");
126 Ok(())
127 }
128
129 #[allow(clippy::unused_self)]
130 fn handle_abort(&self, _workers: &mut WorkersLocked<'_>, error: &Error) -> Result {
131 unimplemented!("unexpected worker task abort {error:?}");
133 }
134
135 async fn handle_result(
136 self: &Arc<Self>,
137 workers: &mut WorkersLocked<'_>,
138 result: WorkerResult,
139 ) -> Result {
140 let (service, result) = result;
141 match result {
142 | Ok(()) => self.handle_finished(workers, &service),
143 | Err(error) => self.handle_error(workers, &service, error).await,
144 }
145 }
146
147 #[tracing::instrument(
148 name = "finished",
149 level = "trace",
150 skip_all,
151 fields(
152 service = ?service.name(),
153 active = self.active.load(Ordering::Acquire),
154 ),
155 )]
156 fn handle_finished(
157 self: &Arc<Self>,
158 _workers: &mut WorkersLocked<'_>,
159 service: &Arc<dyn Service>,
160 ) -> Result {
161 debug!(name = service.name(), "Service worker finished...");
162
163 Ok(())
164 }
165
166 #[tracing::instrument(
167 name = "error",
168 level = "error",
169 skip_all,
170 fields(
171 service = ?service.name(),
172 active = self.active.load(Ordering::Acquire),
173 ),
174 )]
175 async fn handle_error(
176 self: &Arc<Self>,
177 workers: &mut WorkersLocked<'_>,
178 service: &Arc<dyn Service>,
179 error: Error,
180 ) -> Result {
181 let name = service.name();
182 error!("service {name:?} aborted: {error}");
183
184 if !self.server.is_running() {
185 debug_warn!("service {name:?} error ignored on shutdown.");
186 return Ok(());
187 }
188
189 if !error.is_panic() {
190 return Err(error);
191 }
192
193 let delay = Duration::from_millis(RESTART_DELAY_MS);
194 warn!(
195 delay = ?time::pretty(delay),
196 "service {name:?} worker restarting after delay..."
197 );
198
199 sleep(delay).await;
200 self.start_worker(workers, service)
201 }
202
203 fn start_worker(
205 self: &Arc<Self>,
206 workers: &mut WorkersLocked<'_>,
207 service: &Arc<dyn Service>,
208 ) -> Result {
209 if !self.server.is_running() {
210 return Err!(
211 "Service {:?} worker not starting during server shutdown.",
212 service.name()
213 );
214 }
215
216 debug!(name = service.name(), "Service worker starting...");
217 workers.spawn_on(worker(service.clone(), self.clone()), self.server.runtime());
218
219 Ok(())
220 }
221}
222
223#[tracing::instrument(
229 parent = None,
230 level = "trace",
231 skip_all,
232 fields(
233 service = ?service.name(),
234 active = mgr.active.load(Ordering::Relaxed),
235 ),
236)]
237async fn worker(service: Arc<dyn Service>, mgr: Arc<Manager>) -> WorkerResult {
238 mgr.active.fetch_add(1, Ordering::Relaxed);
239 defer! {{
240 mgr.active.fetch_sub(1, Ordering::Release);
241 }};
242
243 let service_ = Arc::clone(&service);
244 let result = AssertUnwindSafe(service_.worker())
245 .catch_unwind()
246 .map_err(Error::from_panic);
247
248 let result = if service.unconstrained() {
249 tokio::task::unconstrained(result).await
250 } else {
251 result.await
252 };
253
254 (service, result.unwrap_or_else(Err))
256}