tuwunel_router/serve/
tls.rs1use std::{
2 net::{SocketAddr, TcpListener},
3 path::Path,
4};
5
6use axum::{Router, extract::connect_info::IntoMakeServiceWithConnectInfo};
7use axum_server::Handle;
8use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig};
9use futures::{FutureExt, future::BoxFuture};
10use tuwunel_core::{Result, debug, err, info, itertools::Itertools};
11
12pub(super) async fn serve<'a>(
13 app: &Router,
14 handle: &Handle<SocketAddr>,
15 cert: &Path,
16 key: &Path,
17 dual_protocol: bool,
18 listeners: impl Iterator<Item = TcpListener>,
19 addrs: &[SocketAddr],
20) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
21 info!(
22 "Note: It is strongly recommended that you use a reverse proxy instead of running \
23 tuwunel directly with TLS."
24 );
25
26 debug!(
27 "Using direct TLS. Certificate path {cert:?} and certificate private key path {key:?}"
28 );
29
30 let conf = RustlsConfig::from_pem_file(cert, key)
31 .await
32 .map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
33
34 let app = app
35 .clone()
36 .into_make_service_with_connect_info::<SocketAddr>();
37
38 if dual_protocol {
39 serve_dual_protocol(&app, &conf, handle, listeners, addrs)
40 } else {
41 serve_tls(&app, &conf, handle, listeners, addrs)
42 }
43}
44
45fn serve_dual_protocol<'a>(
46 app: &IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
47 conf: &RustlsConfig,
48 handle: &Handle<SocketAddr>,
49 listeners: impl Iterator<Item = TcpListener>,
50 addrs: &[SocketAddr],
51) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
52 let bound_servers = addrs.iter().map(|addr| -> Result<_> {
53 Ok(axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()))
54 });
55
56 let passed_servers = listeners.map(|listener| -> Result<_> {
57 Ok(axum_server_dual_protocol::from_tcp_dual_protocol(
58 listener.try_clone()?,
59 conf.clone(),
60 )?
61 .set_upgrade(false))
62 });
63
64 bound_servers
65 .chain(passed_servers)
66 .map_ok(|server| {
67 server
68 .handle(handle.clone())
69 .serve(app.clone())
70 .boxed()
71 })
72 .collect()
73}
74
75fn serve_tls<'a>(
76 app: &IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
77 conf: &RustlsConfig,
78 handle: &Handle<SocketAddr>,
79 listeners: impl Iterator<Item = TcpListener>,
80 addrs: &[SocketAddr],
81) -> Result<Vec<BoxFuture<'a, Result<(), std::io::Error>>>> {
82 let bound_servers = addrs
83 .iter()
84 .map(|addr| -> Result<_> { Ok(axum_server::bind_rustls(*addr, conf.clone())) });
85
86 let passed_servers = listeners.map(|listener| -> Result<_> {
87 Ok(axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())?)
88 });
89
90 bound_servers
91 .chain(passed_servers)
92 .map_ok(|server| {
93 server
94 .handle(handle.clone())
95 .serve(app.clone())
96 .boxed()
97 })
98 .collect()
99}