Skip to main content

tuwunel_router/
layers.rs

1#[cfg(test)]
2mod tests;
3
4use std::{any::Any, sync::Arc, time::Duration};
5
6use axum::{
7	Extension, Router,
8	extract::{DefaultBodyLimit, MatchedPath},
9};
10use axum_client_ip::SecureClientIpSource;
11use http::{
12	HeaderValue, Method, StatusCode,
13	header::{self, HeaderName},
14	uri::PathAndQuery,
15};
16use tower::{
17	ServiceBuilder,
18	layer::util::Identity,
19	util::{Either, option_layer},
20};
21use tower_http::{
22	catch_panic::CatchPanicLayer,
23	cors::{AllowOrigin, CorsLayer},
24	sensitive_headers::SetSensitiveHeadersLayer,
25	set_header::SetResponseHeaderLayer,
26	timeout::{RequestBodyTimeoutLayer, ResponseBodyTimeoutLayer, TimeoutLayer},
27	trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
28};
29use tracing::Level;
30use tuwunel_api::router::{ConfiguredIpSource, state::Guard};
31use tuwunel_core::{Result, Server, config::IpSource, debug, error};
32use tuwunel_service::Services;
33
34use crate::{request, router};
35
/// Content-Security-Policy directives for responses that are not HTML.
///
/// `build` joins these with `;` and sets the result as the
/// `Content-Security-Policy` header when the response has not already set
/// one and its `Content-Type` does not contain `text/html`.
const TUWUNEL_CSP: &[&str; 5] = &[
	"default-src 'none'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];
43
/// Content-Security-Policy directives for HTML responses.
///
/// Same as the non-HTML policy, with inline scripts and inline styles
/// additionally permitted. `build` joins these with `;` when the response
/// `Content-Type` contains `text/html`.
const TUWUNEL_HTML_CSP: &[&str; 7] = &[
	"default-src 'none'",
	"script-src 'unsafe-inline'",
	"style-src 'unsafe-inline'",
	"frame-ancestors 'none'",
	"form-action 'none'",
	"base-uri 'none'",
	"sandbox",
];
53
/// Directives for the `permissions-policy` response header.
///
/// `build` joins these with `,`. Each entry has an empty allowlist `()`.
const TUWUNEL_PERMISSIONS_POLICY: &[&str; 2] = &[
	"interest-cohort=()",
	"browsing-topics=()",
];
55
56pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> {
57	let server = &services.server;
58	let layers = ServiceBuilder::new();
59
60	#[cfg(feature = "sentry_telemetry")]
61	let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
62
63	#[cfg(any(
64		feature = "zstd_compression",
65		feature = "gzip_compression",
66		feature = "brotli_compression"
67	))]
68	let layers = layers.layer(compression_layer(server));
69
70	let services_ = services.clone();
71	let layers = layers
72		.layer(SetSensitiveHeadersLayer::new([header::AUTHORIZATION]))
73		.layer(
74			TraceLayer::new_for_http()
75				.make_span_with(tracing_span::<_>)
76				.on_failure(DefaultOnFailure::new().level(Level::ERROR))
77				.on_request(DefaultOnRequest::new().level(Level::TRACE))
78				.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
79		)
80		.layer(axum::middleware::from_fn_with_state(Arc::clone(services), request::handle))
81		.layer(ip_source_layer(server.config.ip_source))
82		.layer(ResponseBodyTimeoutLayer::new(Duration::from_secs(
83			server.config.client_response_timeout,
84		)))
85		.layer(RequestBodyTimeoutLayer::new(Duration::from_secs(
86			server.config.client_receive_timeout,
87		)))
88		.layer(TimeoutLayer::with_status_code(
89			StatusCode::REQUEST_TIMEOUT,
90			Duration::from_secs(server.config.client_request_timeout),
91		))
92		.layer(SetResponseHeaderLayer::if_not_present(
93			// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
94			HeaderName::from_static("origin-agent-cluster"),
95			HeaderValue::from_static("?1"),
96		))
97		.layer(SetResponseHeaderLayer::if_not_present(
98			header::X_CONTENT_TYPE_OPTIONS,
99			HeaderValue::from_static("nosniff"),
100		))
101		.layer(SetResponseHeaderLayer::if_not_present(
102			header::X_XSS_PROTECTION,
103			HeaderValue::from_static("0"),
104		))
105		.layer(SetResponseHeaderLayer::if_not_present(
106			header::X_FRAME_OPTIONS,
107			HeaderValue::from_static("DENY"),
108		))
109		.layer(SetResponseHeaderLayer::if_not_present(
110			HeaderName::from_static("permissions-policy"),
111			HeaderValue::from_str(&TUWUNEL_PERMISSIONS_POLICY.join(","))?,
112		))
113		.layer(SetResponseHeaderLayer::if_not_present(
114			header::CONTENT_SECURITY_POLICY,
115			|res: &http::Response<_>| {
116				let csp = res
117					.headers()
118					.get(header::CONTENT_TYPE)
119					.map(HeaderValue::to_str)
120					.and_then(Result::ok)
121					.is_some_and(|val| val.contains("text/html"))
122					.then(|| TUWUNEL_HTML_CSP.join(";"))
123					.unwrap_or_else(|| TUWUNEL_CSP.join(";"));
124
125				HeaderValue::from_str(&csp).ok()
126			},
127		))
128		.layer(cors_layer(server))
129		.layer(body_limit_layer(server))
130		.layer(CatchPanicLayer::custom(move |panic| catch_panic(panic, services_.clone())));
131
132	let (router, guard) = router::build(services);
133	Ok((router.layer(layers), guard))
134}
135
/// Builds the response compression layer.
///
/// For each algorithm compiled in via its cargo feature, the matching
/// boolean in the server configuration decides whether that algorithm is
/// enabled or explicitly disabled on the layer.
#[cfg(any(
	feature = "zstd_compression",
	feature = "gzip_compression",
	feature = "brotli_compression"
))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
	use tower_http::compression::CompressionLayer;

	let layer = CompressionLayer::new();

	#[cfg(feature = "zstd_compression")]
	let layer = if server.config.zstd_compression {
		layer.zstd(true)
	} else {
		layer.no_zstd()
	};

	#[cfg(feature = "gzip_compression")]
	let layer = if server.config.gzip_compression {
		layer.gzip(true)
	} else {
		layer.no_gzip()
	};

	#[cfg(feature = "brotli_compression")]
	let layer = if server.config.brotli_compression {
		layer.br(true)
	} else {
		layer.no_br()
	};

	layer
}
173
174fn cors_layer(server: &Server) -> CorsLayer {
175	const METHODS: [Method; 7] = [
176		Method::DELETE,
177		Method::GET,
178		Method::HEAD,
179		Method::OPTIONS,
180		Method::PATCH,
181		Method::POST,
182		Method::PUT,
183	];
184
185	let headers: [HeaderName; 5] = [
186		header::ACCEPT,
187		header::AUTHORIZATION,
188		header::CONTENT_TYPE,
189		header::ORIGIN,
190		HeaderName::from_lowercase(b"x-requested-with")
191			.expect("valid HTTP HeaderName from lowercase."),
192	];
193
194	let allow_origin_list = server
195		.config
196		.access_control_allow_origin
197		.iter()
198		.map(AsRef::as_ref)
199		.map(HeaderValue::from_str)
200		.filter_map(Result::ok);
201
202	let allow_origin = if !server
203		.config
204		.access_control_allow_origin
205		.is_empty()
206	{
207		AllowOrigin::list(allow_origin_list)
208	} else {
209		AllowOrigin::any()
210	};
211
212	CorsLayer::new()
213		.max_age(Duration::from_hours(24))
214		.allow_methods(METHODS)
215		.allow_headers(headers)
216		.allow_origin(allow_origin)
217}
218
219fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
220	DefaultBodyLimit::max(server.config.max_request_size)
221}
222
223fn configured_ip_source(source: IpSource) -> SecureClientIpSource {
224	match source {
225		| IpSource::ConnectInfo => SecureClientIpSource::ConnectInfo,
226		| IpSource::RightmostXForwardedFor => SecureClientIpSource::RightmostXForwardedFor,
227		| IpSource::RightmostForwarded => SecureClientIpSource::RightmostForwarded,
228		| IpSource::XRealIp => SecureClientIpSource::XRealIp,
229		| IpSource::CfConnectingIp => SecureClientIpSource::CfConnectingIp,
230		| IpSource::TrueClientIp => SecureClientIpSource::TrueClientIp,
231		| IpSource::FlyClientIp => SecureClientIpSource::FlyClientIp,
232		| IpSource::CloudFrontViewerAddress => SecureClientIpSource::CloudFrontViewerAddress,
233	}
234}
235
236fn ip_source_layer(source: Option<IpSource>) -> Either<Extension<ConfiguredIpSource>, Identity> {
237	option_layer(source.map(|source| Extension(ConfiguredIpSource(configured_ip_source(source)))))
238}
239
240#[tracing::instrument(name = "panic", level = "error", skip_all)]
241#[expect(clippy::needless_pass_by_value)]
242fn catch_panic(
243	err: Box<dyn Any + Send + 'static>,
244	services: Arc<Services>,
245) -> http::Response<http_body_util::Full<bytes::Bytes>> {
246	services
247		.server
248		.metrics
249		.requests_panic
250		.fetch_add(1, std::sync::atomic::Ordering::Release);
251
252	let details = match err.downcast_ref::<String>() {
253		| Some(s) => s.clone(),
254		| _ => match err.downcast_ref::<&str>() {
255			| Some(s) => (*s).to_owned(),
256			| _ => "Unknown internal server error occurred.".to_owned(),
257		},
258	};
259
260	error!("{details:#}");
261	let body = serde_json::json!({
262		"errcode": "M_UNKNOWN",
263		"error": "M_UNKNOWN: Internal server error occurred",
264		"details": details,
265	});
266
267	http::Response::builder()
268		.status(StatusCode::INTERNAL_SERVER_ERROR)
269		.header(header::CONTENT_TYPE, "application/json")
270		.body(http_body_util::Full::from(body.to_string()))
271		.expect("Failed to create response for our panic catcher?")
272}
273
274fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
275	let path = request
276		.extensions()
277		.get::<MatchedPath>()
278		.map_or_else(|| request_path_str(request), truncated_matched_path);
279
280	tracing::span! {
281		parent: None,
282		debug::INFO_SPAN_LEVEL,
283		"router",
284		method = %request.method(),
285		%path,
286	}
287}
288
289fn request_path_str<T>(request: &http::Request<T>) -> &str {
290	request
291		.uri()
292		.path_and_query()
293		.map(PathAndQuery::as_str)
294		.unwrap_or("/")
295}
296
297fn truncated_matched_path(path: &MatchedPath) -> &str {
298	path.as_str()
299		.rsplit_once('{')
300		.map_or(path.as_str(), |path| path.0.strip_suffix('/').unwrap_or(path.0))
301}