1#[cfg(test)]
2mod tests;
3
4use std::{any::Any, sync::Arc, time::Duration};
5
6use axum::{
7 Extension, Router,
8 extract::{DefaultBodyLimit, MatchedPath},
9};
10use http::{
11 HeaderValue, Method, StatusCode,
12 header::{self, HeaderName},
13 uri::PathAndQuery,
14};
15use ipnet::IpNet;
16use tower::{
17 ServiceBuilder,
18 layer::util::Identity,
19 util::{Either, option_layer},
20};
21use tower_http::{
22 catch_panic::CatchPanicLayer,
23 cors::{AllowOrigin, CorsLayer},
24 sensitive_headers::SetSensitiveHeadersLayer,
25 set_header::SetResponseHeaderLayer,
26 timeout::{RequestBodyTimeoutLayer, ResponseBodyTimeoutLayer, TimeoutLayer},
27 trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
28};
29use tracing::Level;
30use tuwunel_api::router::{ConfiguredIpSource, TrustedPeerSubnets, state::Guard};
31use tuwunel_core::{Result, Server, config::IpSource, debug, error};
32use tuwunel_service::Services;
33
34use crate::{request, router};
35
/// `Content-Security-Policy` directives for responses that are not HTML.
///
/// Everything is denied by default (`default-src 'none'`), the response may
/// not be framed, submit forms or change its base URI, and it is sandboxed.
/// The directives are joined with `;` when the header is emitted. Declared as
/// a slice so directives can be added without maintaining an element count.
const TUWUNEL_CSP: &[&str] = &[
	"default-src 'none'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];

/// `Content-Security-Policy` directives for responses whose `Content-Type`
/// contains `text/html`: the non-HTML policy plus inline scripts and styles.
///
/// NOTE(review): a bare `sandbox` directive (without an `allow-scripts` token)
/// disables script execution, so the `script-src 'unsafe-inline'` allowance
/// may have no effect — confirm whether that is intended.
const TUWUNEL_HTML_CSP: &[&str] = &[
	"default-src 'none'",
	"script-src 'unsafe-inline'",
	"style-src 'unsafe-inline'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];

/// `Permissions-Policy` entries (joined with `,`) giving the `interest-cohort`
/// and `browsing-topics` features an empty allowlist, i.e. disabling them.
const TUWUNEL_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"];
55
56pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> {
57 let server = &services.server;
58 let layers = ServiceBuilder::new();
59
60 #[cfg(feature = "sentry_telemetry")]
61 let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
62
63 #[cfg(any(
64 feature = "zstd_compression",
65 feature = "gzip_compression",
66 feature = "brotli_compression"
67 ))]
68 let layers = layers.layer(compression_layer(server));
69
70 let services_ = services.clone();
71 let layers = layers
72 .layer(SetSensitiveHeadersLayer::new([header::AUTHORIZATION]))
73 .layer(
74 TraceLayer::new_for_http()
75 .make_span_with(tracing_span::<_>)
76 .on_failure(DefaultOnFailure::new().level(Level::ERROR))
77 .on_request(DefaultOnRequest::new().level(Level::TRACE))
78 .on_response(DefaultOnResponse::new().level(Level::DEBUG)),
79 )
80 .layer(axum::middleware::from_fn_with_state(Arc::clone(services), request::handle))
81 .layer(trusted_peer_subnets_layer(&server.config.ip_source_trusted_subnets))
82 .layer(ip_source_layer(server.config.ip_source))
83 .layer(ResponseBodyTimeoutLayer::new(Duration::from_secs(
84 server.config.client_response_timeout,
85 )))
86 .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(
87 server.config.client_receive_timeout,
88 )))
89 .layer(TimeoutLayer::with_status_code(
90 StatusCode::REQUEST_TIMEOUT,
91 Duration::from_secs(server.config.client_request_timeout),
92 ))
93 .layer(SetResponseHeaderLayer::if_not_present(
94 HeaderName::from_static("origin-agent-cluster"),
96 HeaderValue::from_static("?1"),
97 ))
98 .layer(SetResponseHeaderLayer::if_not_present(
99 header::X_CONTENT_TYPE_OPTIONS,
100 HeaderValue::from_static("nosniff"),
101 ))
102 .layer(SetResponseHeaderLayer::if_not_present(
103 header::X_XSS_PROTECTION,
104 HeaderValue::from_static("0"),
105 ))
106 .layer(SetResponseHeaderLayer::if_not_present(
107 header::X_FRAME_OPTIONS,
108 HeaderValue::from_static("DENY"),
109 ))
110 .layer(SetResponseHeaderLayer::if_not_present(
111 HeaderName::from_static("permissions-policy"),
112 HeaderValue::from_str(&TUWUNEL_PERMISSIONS_POLICY.join(","))?,
113 ))
114 .layer(SetResponseHeaderLayer::if_not_present(
115 header::CONTENT_SECURITY_POLICY,
116 |res: &http::Response<_>| {
117 let csp = res
118 .headers()
119 .get(header::CONTENT_TYPE)
120 .map(HeaderValue::to_str)
121 .and_then(Result::ok)
122 .is_some_and(|val| val.contains("text/html"))
123 .then(|| TUWUNEL_HTML_CSP.join(";"))
124 .unwrap_or_else(|| TUWUNEL_CSP.join(";"));
125
126 HeaderValue::from_str(&csp).ok()
127 },
128 ))
129 .layer(cors_layer(server))
130 .layer(body_limit_layer(server))
131 .layer(CatchPanicLayer::custom(move |panic| catch_panic(panic, services_.clone())));
132
133 let (router, guard) = router::build(services);
134 Ok((router.layer(layers), guard))
135}
136
/// Builds the response compression layer.
///
/// Starts from `CompressionLayer::new()` and then, for each algorithm whose
/// cargo feature is compiled in, enables or disables it according to the
/// matching `*_compression` config flag.
#[cfg(any(
	feature = "zstd_compression",
	feature = "gzip_compression",
	feature = "brotli_compression"
))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
	let config = &server.config;
	let layer = tower_http::compression::CompressionLayer::new();

	#[cfg(feature = "zstd_compression")]
	let layer = match config.zstd_compression {
		| true => layer.zstd(true),
		| false => layer.no_zstd(),
	};

	#[cfg(feature = "gzip_compression")]
	let layer = match config.gzip_compression {
		| true => layer.gzip(true),
		| false => layer.no_gzip(),
	};

	#[cfg(feature = "brotli_compression")]
	let layer = match config.brotli_compression {
		| true => layer.br(true),
		| false => layer.no_br(),
	};

	layer
}
174
175fn cors_layer(server: &Server) -> CorsLayer {
176 const METHODS: [Method; 7] = [
177 Method::DELETE,
178 Method::GET,
179 Method::HEAD,
180 Method::OPTIONS,
181 Method::PATCH,
182 Method::POST,
183 Method::PUT,
184 ];
185
186 let headers: [HeaderName; 5] = [
187 header::ACCEPT,
188 header::AUTHORIZATION,
189 header::CONTENT_TYPE,
190 header::ORIGIN,
191 HeaderName::from_lowercase(b"x-requested-with")
192 .expect("valid HTTP HeaderName from lowercase."),
193 ];
194
195 let allow_origin_list = server
196 .config
197 .access_control_allow_origin
198 .iter()
199 .map(AsRef::as_ref)
200 .map(HeaderValue::from_str)
201 .filter_map(Result::ok);
202
203 let allow_origin = if !server
204 .config
205 .access_control_allow_origin
206 .is_empty()
207 {
208 AllowOrigin::list(allow_origin_list)
209 } else {
210 AllowOrigin::any()
211 };
212
213 CorsLayer::new()
214 .max_age(Duration::from_hours(24))
215 .allow_methods(METHODS)
216 .allow_headers(headers)
217 .allow_origin(allow_origin)
218}
219
220fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
221 DefaultBodyLimit::max(server.config.max_request_size)
222}
223
224fn trusted_peer_subnets_layer(
225 subnets: &[IpNet],
226) -> Either<Extension<TrustedPeerSubnets>, Identity> {
227 option_layer((!subnets.is_empty()).then(|| Extension(TrustedPeerSubnets(Arc::from(subnets)))))
228}
229
230fn ip_source_layer(source: Option<IpSource>) -> Either<Extension<ConfiguredIpSource>, Identity> {
231 option_layer(source.map(|source| Extension(ConfiguredIpSource(source))))
232}
233
234#[tracing::instrument(name = "panic", level = "error", skip_all)]
235#[expect(clippy::needless_pass_by_value)]
236fn catch_panic(
237 err: Box<dyn Any + Send + 'static>,
238 services: Arc<Services>,
239) -> http::Response<http_body_util::Full<bytes::Bytes>> {
240 services
241 .server
242 .metrics
243 .requests_panic
244 .fetch_add(1, std::sync::atomic::Ordering::Release);
245
246 let details = match err.downcast_ref::<String>() {
247 | Some(s) => s.clone(),
248 | _ => match err.downcast_ref::<&str>() {
249 | Some(s) => (*s).to_owned(),
250 | _ => "Unknown internal server error occurred.".to_owned(),
251 },
252 };
253
254 error!("{details:#}");
255 let body = serde_json::json!({
256 "errcode": "M_UNKNOWN",
257 "error": "M_UNKNOWN: Internal server error occurred",
258 "details": details,
259 });
260
261 http::Response::builder()
262 .status(StatusCode::INTERNAL_SERVER_ERROR)
263 .header(header::CONTENT_TYPE, "application/json")
264 .body(http_body_util::Full::from(body.to_string()))
265 .expect("Failed to create response for our panic catcher?")
266}
267
268fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
269 let path = request
270 .extensions()
271 .get::<MatchedPath>()
272 .map_or_else(|| request_path_str(request), truncated_matched_path);
273
274 tracing::span! {
275 parent: None,
276 debug::INFO_SPAN_LEVEL,
277 "router",
278 method = %request.method(),
279 %path,
280 }
281}
282
283fn request_path_str<T>(request: &http::Request<T>) -> &str {
284 request
285 .uri()
286 .path_and_query()
287 .map(PathAndQuery::as_str)
288 .unwrap_or("/")
289}
290
291fn truncated_matched_path(path: &MatchedPath) -> &str {
292 path.as_str()
293 .rsplit_once('{')
294 .map_or(path.as_str(), |path| path.0.strip_suffix('/').unwrap_or(path.0))
295}