Skip to main content

tuwunel_router/
serve.rs

1mod plain;
2#[cfg(feature = "direct_tls")]
3mod tls;
4mod unix;
5
6use std::{
7	net::{SocketAddr, TcpListener},
8	os::unix::net::UnixListener,
9	path::Path,
10	sync::{Arc, atomic::Ordering},
11};
12
13use tokio::task::JoinSet;
14use tuwunel_core::{Err, Result, debug_info, info};
15use tuwunel_service::Services;
16
17use super::layers;
18use crate::handle::ServerHandle;
19
20/// Serve clients
21pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Result {
22	let server = &services.server;
23	let config = &server.config;
24
25	let (app, _guard) = layers::build(&services)?;
26
27	let mut join_set = JoinSet::new();
28
29	let socket_path = &config.unix_socket_path;
30
31	let (passed_tcp_listeners, passed_unix_listeners) = systemd_listeners()?;
32
33	let addrs = config.get_bind_addrs();
34
35	let log_addrs = make_log_addrs(
36		&addrs,
37		socket_path.as_deref(),
38		&passed_tcp_listeners,
39		&passed_unix_listeners,
40	)?;
41
42	let mut futures = vec![];
43
44	#[cfg(unix)]
45	{
46		let socket_perms = config.get_unix_socket_perms()?;
47
48		let unix_futures = unix::serve(
49			&app,
50			&handle.handle_unix,
51			passed_unix_listeners.into_iter(),
52			socket_path.as_deref(),
53			socket_perms,
54		)
55		.await?;
56
57		futures.extend(unix_futures);
58	};
59
60	#[cfg_attr(
61		not(feature = "direct_tls"),
62		expect(clippy::redundant_else, unused_variables)
63	)]
64	if let Some((cert, key)) = config.tls.get_tls_cert_key() {
65		#[cfg(feature = "direct_tls")]
66		{
67			services.globals.init_rustls_provider()?;
68
69			let tls_futures = tls::serve(
70				&app,
71				&handle.handle_ip,
72				cert,
73				key,
74				config.tls.dual_protocol,
75				passed_tcp_listeners.into_iter(),
76				&addrs,
77			)
78			.await?;
79
80			futures.extend(tls_futures);
81		}
82
83		#[cfg(not(feature = "direct_tls"))]
84		return tuwunel_core::Err!(Config(
85			"tls",
86			"tuwunel was not built with direct TLS support (\"direct_tls\")"
87		));
88	} else {
89		let plain_futures =
90			plain::serve(&app, &handle.handle_ip, passed_tcp_listeners.into_iter(), &addrs)?;
91
92		futures.extend(plain_futures);
93	}
94
95	for future in futures {
96		join_set.spawn_on(future, server.runtime());
97	}
98
99	if join_set.is_empty() {
100		return Err!("at least one listener should be installed");
101	}
102
103	info!("Listening on {log_addrs:?}");
104
105	join_set.join_all().await;
106
107	let handle_active = server
108		.metrics
109		.requests_handle_active
110		.load(Ordering::Acquire);
111
112	debug_info!(
113		handle_finished = server
114			.metrics
115			.requests_handle_finished
116			.load(Ordering::Acquire),
117		panics = server
118			.metrics
119			.requests_panic
120			.load(Ordering::Acquire),
121		handle_active,
122		"Stopped listening on {addrs:?}",
123	);
124
125	debug_assert_eq!(0, handle_active, "active request handles still pending");
126
127	Ok(())
128}
129
130fn make_log_addrs(
131	tcp_addrs: &[SocketAddr],
132	unix_path: Option<&Path>,
133	tcp_listeners: &[TcpListener],
134	unix_listeners: &[UnixListener],
135) -> Result<Vec<String>> {
136	let tcp_log_addrs = tcp_addrs.iter().map(|addr| format!("tcp:{addr}"));
137
138	let unix_log_addr = unix_path.as_ref().map(|socket_path| {
139		let path = socket_path.to_string_lossy();
140		format!("unix:{path}")
141	});
142
143	let passed_tcp_log_addrs = tcp_listeners.iter().map(|listener| {
144		let addr = listener.local_addr()?;
145		Ok(format!("passed:tcp:{addr}"))
146	});
147
148	let passed_unix_log_addrs = unix_listeners.iter().map(|listener| {
149		let addr = listener.local_addr()?;
150		let path = addr.as_pathname();
151		let log_path = if let Some(path) = path {
152			&path.to_string_lossy()
153		} else {
154			"?"
155		};
156		Ok(format!("passed:unix:{log_path}"))
157	});
158
159	tcp_log_addrs
160		.map(Ok)
161		.chain(unix_log_addr.into_iter().map(Ok))
162		.chain(passed_tcp_log_addrs)
163		.chain(passed_unix_log_addrs)
164		.collect()
165}
166
/// Adopt the sockets handed to this process by systemd socket activation.
///
/// Each descriptor reported by `sd_notify::listen_fds()` is classified with
/// `get_socket_family`. Inet sockets are wrapped as [`TcpListener`]s and
/// unix-domain sockets as [`UnixListener`]s. Every adopted listener is
/// switched to non-blocking mode before being returned.
///
/// # Errors
///
/// Fails if the passed descriptors cannot be enumerated, a descriptor's socket
/// family cannot be determined, or a listener cannot be made non-blocking.
#[cfg(all(feature = "systemd", target_os = "linux"))]
fn systemd_listeners() -> Result<(Vec<TcpListener>, Vec<UnixListener>)> {
	use std::os::fd::FromRawFd;

	use tuwunel_core::utils::sys::{SocketFamily, get_socket_family};

	let mut tcp_listeners = Vec::new();
	let mut unix_listeners = Vec::new();

	let passed_fds = sd_notify::listen_fds()?;
	for raw_fd in passed_fds {
		// systemd starts activation sockets at fd 3 (SD_LISTEN_FDS_START);
		// lower numbers are stdio and should never show up here.
		debug_assert!(raw_fd >= 3, "fdno probably not a listener socket");

		match get_socket_family(raw_fd)? {
			| SocketFamily::Inet => {
				// SAFETY: the descriptor was passed to us by systemd and was just
				// confirmed to be an inet socket; ownership is taken here.
				let tcp = unsafe { TcpListener::from_raw_fd(raw_fd) };

				tcp.set_nonblocking(true)?;

				tcp_listeners.push(tcp);
			},
			| SocketFamily::Unix => {
				// SAFETY: the descriptor was passed to us by systemd and was just
				// confirmed to be a unix-domain socket; ownership is taken here.
				let uds = unsafe { UnixListener::from_raw_fd(raw_fd) };

				uds.set_nonblocking(true)?;

				unix_listeners.push(uds);
			},
		}
	}

	Ok((tcp_listeners, unix_listeners))
}
205
206#[cfg(any(not(feature = "systemd"), not(target_os = "linux")))]
207fn systemd_listeners() -> Result<(Vec<TcpListener>, Vec<UnixListener>)> { Ok((vec![], vec![])) }