Skip to main content

tuwunel_router/serve/
tls.rs

1use std::{
2	net::{SocketAddr, TcpListener},
3	path::Path,
4};
5
6use axum::{Router, extract::connect_info::IntoMakeServiceWithConnectInfo};
7use axum_server::Handle;
8use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig};
9use futures::{FutureExt, future::BoxFuture};
10use tuwunel_core::{Result, debug, err, info, itertools::Itertools};
11
12pub(super) async fn serve<'a>(
13	app: &Router,
14	handle: &Handle<SocketAddr>,
15	cert: &Path,
16	key: &Path,
17	dual_protocol: bool,
18	listeners: impl Iterator<Item = TcpListener>,
19	addrs: &[SocketAddr],
20) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
21	info!(
22		"Note: It is strongly recommended that you use a reverse proxy instead of running \
23		 tuwunel directly with TLS."
24	);
25
26	debug!(
27		"Using direct TLS. Certificate path {cert:?} and certificate private key path {key:?}"
28	);
29
30	let conf = RustlsConfig::from_pem_file(cert, key)
31		.await
32		.map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
33
34	let app = app
35		.clone()
36		.into_make_service_with_connect_info::<SocketAddr>();
37
38	if dual_protocol {
39		serve_dual_protocol(&app, &conf, handle, listeners, addrs)
40	} else {
41		serve_tls(&app, &conf, handle, listeners, addrs)
42	}
43}
44
45fn serve_dual_protocol<'a>(
46	app: &IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
47	conf: &RustlsConfig,
48	handle: &Handle<SocketAddr>,
49	listeners: impl Iterator<Item = TcpListener>,
50	addrs: &[SocketAddr],
51) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
52	let bound_servers = addrs.iter().map(|addr| -> Result<_> {
53		Ok(axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()))
54	});
55
56	let passed_servers = listeners.map(|listener| -> Result<_> {
57		Ok(axum_server_dual_protocol::from_tcp_dual_protocol(
58			listener.try_clone()?,
59			conf.clone(),
60		)?
61		.set_upgrade(false))
62	});
63
64	bound_servers
65		.chain(passed_servers)
66		.map_ok(|server| {
67			server
68				.handle(handle.clone())
69				.serve(app.clone())
70				.boxed()
71		})
72		.collect()
73}
74
75fn serve_tls<'a>(
76	app: &IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
77	conf: &RustlsConfig,
78	handle: &Handle<SocketAddr>,
79	listeners: impl Iterator<Item = TcpListener>,
80	addrs: &[SocketAddr],
81) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
82	let bound_servers = addrs
83		.iter()
84		.map(|addr| -> Result<_> { Ok(axum_server::bind_rustls(*addr, conf.clone())) });
85
86	let passed_servers = listeners.map(|listener| -> Result<_> {
87		Ok(axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())?)
88	});
89
90	bound_servers
91		.chain(passed_servers)
92		.map_ok(|server| {
93			server
94				.handle(handle.clone())
95				.serve(app.clone())
96				.boxed()
97		})
98		.collect()
99}