1#[cfg(test)]
2mod tests;
3
4use std::{any::Any, sync::Arc, time::Duration};
5
6use axum::{
7 Extension, Router,
8 extract::{DefaultBodyLimit, MatchedPath},
9};
10use axum_client_ip::SecureClientIpSource;
11use http::{
12 HeaderValue, Method, StatusCode,
13 header::{self, HeaderName},
14 uri::PathAndQuery,
15};
16use tower::{
17 ServiceBuilder,
18 layer::util::Identity,
19 util::{Either, option_layer},
20};
21use tower_http::{
22 catch_panic::CatchPanicLayer,
23 cors::{AllowOrigin, CorsLayer},
24 sensitive_headers::SetSensitiveHeadersLayer,
25 set_header::SetResponseHeaderLayer,
26 timeout::{RequestBodyTimeoutLayer, ResponseBodyTimeoutLayer, TimeoutLayer},
27 trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
28};
29use tracing::Level;
30use tuwunel_api::router::{ConfiguredIpSource, state::Guard};
31use tuwunel_core::{Result, Server, config::IpSource, debug, error};
32use tuwunel_service::Services;
33
34use crate::{request, router};
35
/// Content-Security-Policy directives for responses whose `Content-Type` does
/// not contain `text/html`.
///
/// `build` joins these with `;` and sets the result only on responses that do
/// not already carry a `Content-Security-Policy` header.
const TUWUNEL_CSP: &[&str; 5] = &[
	"default-src 'none'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];
43
/// Content-Security-Policy directives for responses whose `Content-Type`
/// contains `text/html`.
///
/// Identical to `TUWUNEL_CSP` except that inline scripts and styles are
/// permitted (`'unsafe-inline'`). Joined with `;` in `build`.
const TUWUNEL_HTML_CSP: &[&str; 7] = &[
	"default-src 'none'",
	"script-src 'unsafe-inline'",
	"style-src 'unsafe-inline'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];
53
/// `Permissions-Policy` entries; `build` joins them with `,` and sets the
/// header only on responses that do not already carry one.
const TUWUNEL_PERMISSIONS_POLICY: &[&str; 2] = &["interest-cohort=()", "browsing-topics=()"];
55
56pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> {
57 let server = &services.server;
58 let layers = ServiceBuilder::new();
59
60 #[cfg(feature = "sentry_telemetry")]
61 let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
62
63 #[cfg(any(
64 feature = "zstd_compression",
65 feature = "gzip_compression",
66 feature = "brotli_compression"
67 ))]
68 let layers = layers.layer(compression_layer(server));
69
70 let services_ = services.clone();
71 let layers = layers
72 .layer(SetSensitiveHeadersLayer::new([header::AUTHORIZATION]))
73 .layer(
74 TraceLayer::new_for_http()
75 .make_span_with(tracing_span::<_>)
76 .on_failure(DefaultOnFailure::new().level(Level::ERROR))
77 .on_request(DefaultOnRequest::new().level(Level::TRACE))
78 .on_response(DefaultOnResponse::new().level(Level::DEBUG)),
79 )
80 .layer(axum::middleware::from_fn_with_state(Arc::clone(services), request::handle))
81 .layer(ip_source_layer(server.config.ip_source))
82 .layer(ResponseBodyTimeoutLayer::new(Duration::from_secs(
83 server.config.client_response_timeout,
84 )))
85 .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(
86 server.config.client_receive_timeout,
87 )))
88 .layer(TimeoutLayer::with_status_code(
89 StatusCode::REQUEST_TIMEOUT,
90 Duration::from_secs(server.config.client_request_timeout),
91 ))
92 .layer(SetResponseHeaderLayer::if_not_present(
93 HeaderName::from_static("origin-agent-cluster"),
95 HeaderValue::from_static("?1"),
96 ))
97 .layer(SetResponseHeaderLayer::if_not_present(
98 header::X_CONTENT_TYPE_OPTIONS,
99 HeaderValue::from_static("nosniff"),
100 ))
101 .layer(SetResponseHeaderLayer::if_not_present(
102 header::X_XSS_PROTECTION,
103 HeaderValue::from_static("0"),
104 ))
105 .layer(SetResponseHeaderLayer::if_not_present(
106 header::X_FRAME_OPTIONS,
107 HeaderValue::from_static("DENY"),
108 ))
109 .layer(SetResponseHeaderLayer::if_not_present(
110 HeaderName::from_static("permissions-policy"),
111 HeaderValue::from_str(&TUWUNEL_PERMISSIONS_POLICY.join(","))?,
112 ))
113 .layer(SetResponseHeaderLayer::if_not_present(
114 header::CONTENT_SECURITY_POLICY,
115 |res: &http::Response<_>| {
116 let csp = res
117 .headers()
118 .get(header::CONTENT_TYPE)
119 .map(HeaderValue::to_str)
120 .and_then(Result::ok)
121 .is_some_and(|val| val.contains("text/html"))
122 .then(|| TUWUNEL_HTML_CSP.join(";"))
123 .unwrap_or_else(|| TUWUNEL_CSP.join(";"));
124
125 HeaderValue::from_str(&csp).ok()
126 },
127 ))
128 .layer(cors_layer(server))
129 .layer(body_limit_layer(server))
130 .layer(CatchPanicLayer::custom(move |panic| catch_panic(panic, services_.clone())));
131
132 let (router, guard) = router::build(services);
133 Ok((router.layer(layers), guard))
134}
135
/// Build the response compression layer.
///
/// Each algorithm compiled in through its cargo feature is switched on or off
/// according to the matching `*_compression` flag in the server config.
#[cfg(any(
	feature = "zstd_compression",
	feature = "gzip_compression",
	feature = "brotli_compression"
))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
	let config = &server.config;
	let layer = tower_http::compression::CompressionLayer::new();

	#[cfg(feature = "zstd_compression")]
	let layer = if config.zstd_compression { layer.zstd(true) } else { layer.no_zstd() };

	#[cfg(feature = "gzip_compression")]
	let layer = if config.gzip_compression { layer.gzip(true) } else { layer.no_gzip() };

	#[cfg(feature = "brotli_compression")]
	let layer = if config.brotli_compression { layer.br(true) } else { layer.no_br() };

	layer
}
173
174fn cors_layer(server: &Server) -> CorsLayer {
175 const METHODS: [Method; 7] = [
176 Method::DELETE,
177 Method::GET,
178 Method::HEAD,
179 Method::OPTIONS,
180 Method::PATCH,
181 Method::POST,
182 Method::PUT,
183 ];
184
185 let headers: [HeaderName; 5] = [
186 header::ACCEPT,
187 header::AUTHORIZATION,
188 header::CONTENT_TYPE,
189 header::ORIGIN,
190 HeaderName::from_lowercase(b"x-requested-with")
191 .expect("valid HTTP HeaderName from lowercase."),
192 ];
193
194 let allow_origin_list = server
195 .config
196 .access_control_allow_origin
197 .iter()
198 .map(AsRef::as_ref)
199 .map(HeaderValue::from_str)
200 .filter_map(Result::ok);
201
202 let allow_origin = if !server
203 .config
204 .access_control_allow_origin
205 .is_empty()
206 {
207 AllowOrigin::list(allow_origin_list)
208 } else {
209 AllowOrigin::any()
210 };
211
212 CorsLayer::new()
213 .max_age(Duration::from_hours(24))
214 .allow_methods(METHODS)
215 .allow_headers(headers)
216 .allow_origin(allow_origin)
217}
218
219fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
220 DefaultBodyLimit::max(server.config.max_request_size)
221}
222
223fn configured_ip_source(source: IpSource) -> SecureClientIpSource {
224 match source {
225 | IpSource::ConnectInfo => SecureClientIpSource::ConnectInfo,
226 | IpSource::RightmostXForwardedFor => SecureClientIpSource::RightmostXForwardedFor,
227 | IpSource::RightmostForwarded => SecureClientIpSource::RightmostForwarded,
228 | IpSource::XRealIp => SecureClientIpSource::XRealIp,
229 | IpSource::CfConnectingIp => SecureClientIpSource::CfConnectingIp,
230 | IpSource::TrueClientIp => SecureClientIpSource::TrueClientIp,
231 | IpSource::FlyClientIp => SecureClientIpSource::FlyClientIp,
232 | IpSource::CloudFrontViewerAddress => SecureClientIpSource::CloudFrontViewerAddress,
233 }
234}
235
236fn ip_source_layer(source: Option<IpSource>) -> Either<Extension<ConfiguredIpSource>, Identity> {
237 option_layer(source.map(|source| Extension(ConfiguredIpSource(configured_ip_source(source)))))
238}
239
240#[tracing::instrument(name = "panic", level = "error", skip_all)]
241#[expect(clippy::needless_pass_by_value)]
242fn catch_panic(
243 err: Box<dyn Any + Send + 'static>,
244 services: Arc<Services>,
245) -> http::Response<http_body_util::Full<bytes::Bytes>> {
246 services
247 .server
248 .metrics
249 .requests_panic
250 .fetch_add(1, std::sync::atomic::Ordering::Release);
251
252 let details = match err.downcast_ref::<String>() {
253 | Some(s) => s.clone(),
254 | _ => match err.downcast_ref::<&str>() {
255 | Some(s) => (*s).to_owned(),
256 | _ => "Unknown internal server error occurred.".to_owned(),
257 },
258 };
259
260 error!("{details:#}");
261 let body = serde_json::json!({
262 "errcode": "M_UNKNOWN",
263 "error": "M_UNKNOWN: Internal server error occurred",
264 "details": details,
265 });
266
267 http::Response::builder()
268 .status(StatusCode::INTERNAL_SERVER_ERROR)
269 .header(header::CONTENT_TYPE, "application/json")
270 .body(http_body_util::Full::from(body.to_string()))
271 .expect("Failed to create response for our panic catcher?")
272}
273
274fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
275 let path = request
276 .extensions()
277 .get::<MatchedPath>()
278 .map_or_else(|| request_path_str(request), truncated_matched_path);
279
280 tracing::span! {
281 parent: None,
282 debug::INFO_SPAN_LEVEL,
283 "router",
284 method = %request.method(),
285 %path,
286 }
287}
288
289fn request_path_str<T>(request: &http::Request<T>) -> &str {
290 request
291 .uri()
292 .path_and_query()
293 .map(PathAndQuery::as_str)
294 .unwrap_or("/")
295}
296
297fn truncated_matched_path(path: &MatchedPath) -> &str {
298 path.as_str()
299 .rsplit_once('{')
300 .map_or(path.as_str(), |path| path.0.strip_suffix('/').unwrap_or(path.0))
301}