Skip to main content

tuwunel_router/
layers.rs

1#[cfg(test)]
2mod tests;
3
4use std::{any::Any, sync::Arc, time::Duration};
5
6use axum::{
7	Extension, Router,
8	extract::{DefaultBodyLimit, MatchedPath},
9};
10use http::{
11	HeaderValue, Method, StatusCode,
12	header::{self, HeaderName},
13	uri::PathAndQuery,
14};
15use ipnet::IpNet;
16use tower::{
17	ServiceBuilder,
18	layer::util::Identity,
19	util::{Either, option_layer},
20};
21use tower_http::{
22	catch_panic::CatchPanicLayer,
23	cors::{AllowOrigin, CorsLayer},
24	sensitive_headers::SetSensitiveHeadersLayer,
25	set_header::SetResponseHeaderLayer,
26	timeout::{RequestBodyTimeoutLayer, ResponseBodyTimeoutLayer, TimeoutLayer},
27	trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
28};
29use tracing::Level;
30use tuwunel_api::router::{ConfiguredIpSource, TrustedPeerSubnets, state::Guard};
31use tuwunel_core::{Result, Server, config::IpSource, debug, error};
32use tuwunel_service::Services;
33
34use crate::{request, router};
35
/// Directives of the default `Content-Security-Policy` header.
///
/// `build` joins these with `;` and applies the result to every response
/// whose `Content-Type` does not mention `text/html` and which does not
/// already carry its own policy.
const TUWUNEL_CSP: &[&str; 5] = &[
	// Nothing may be loaded unless another directive allows it.
	"default-src 'none'",
	// The response may not be embedded in a frame.
	"frame-ancestors 'none'",
	// Forms may not be submitted anywhere.
	"form-action 'none'",
	// `<base>` may not redirect relative URLs.
	"base-uri 'none'",
	// Sandbox the document with no allowances.
	"sandbox",
];
43
/// Directives of the `Content-Security-Policy` header used for HTML responses.
///
/// `build` joins these with `;` and applies the result to responses whose
/// `Content-Type` mentions `text/html` and which do not already carry their
/// own policy. Compared to the default policy this one additionally permits
/// inline scripts and inline styles.
const TUWUNEL_HTML_CSP: &[&str; 7] = &[
	// Nothing may be loaded unless another directive allows it.
	"default-src 'none'",
	// Inline `<script>` blocks are permitted.
	"script-src 'unsafe-inline'",
	// Inline styles are permitted.
	"style-src 'unsafe-inline'",
	// The response may not be embedded in a frame.
	"frame-ancestors 'none'",
	// Forms may not be submitted anywhere.
	"form-action 'none'",
	// `<base>` may not redirect relative URLs.
	"base-uri 'none'",
	// Sandbox the document with no allowances.
	"sandbox",
];
53
/// Entries of the `Permissions-Policy` header; `build` joins them with `,`.
///
/// Each entry assigns the empty allowlist `()` to its feature.
const TUWUNEL_PERMISSIONS_POLICY: &[&str; 2] = &["interest-cohort=()", "browsing-topics=()"];
55
56pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> {
57	let server = &services.server;
58	let layers = ServiceBuilder::new();
59
60	#[cfg(feature = "sentry_telemetry")]
61	let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
62
63	#[cfg(any(
64		feature = "zstd_compression",
65		feature = "gzip_compression",
66		feature = "brotli_compression"
67	))]
68	let layers = layers.layer(compression_layer(server));
69
70	let services_ = services.clone();
71	let layers = layers
72		.layer(SetSensitiveHeadersLayer::new([header::AUTHORIZATION]))
73		.layer(
74			TraceLayer::new_for_http()
75				.make_span_with(tracing_span::<_>)
76				.on_failure(DefaultOnFailure::new().level(Level::ERROR))
77				.on_request(DefaultOnRequest::new().level(Level::TRACE))
78				.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
79		)
80		.layer(axum::middleware::from_fn_with_state(Arc::clone(services), request::handle))
81		.layer(trusted_peer_subnets_layer(&server.config.ip_source_trusted_subnets))
82		.layer(ip_source_layer(server.config.ip_source))
83		.layer(ResponseBodyTimeoutLayer::new(Duration::from_secs(
84			server.config.client_response_timeout,
85		)))
86		.layer(RequestBodyTimeoutLayer::new(Duration::from_secs(
87			server.config.client_receive_timeout,
88		)))
89		.layer(TimeoutLayer::with_status_code(
90			StatusCode::REQUEST_TIMEOUT,
91			Duration::from_secs(server.config.client_request_timeout),
92		))
93		.layer(SetResponseHeaderLayer::if_not_present(
94			// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
95			HeaderName::from_static("origin-agent-cluster"),
96			HeaderValue::from_static("?1"),
97		))
98		.layer(SetResponseHeaderLayer::if_not_present(
99			header::X_CONTENT_TYPE_OPTIONS,
100			HeaderValue::from_static("nosniff"),
101		))
102		.layer(SetResponseHeaderLayer::if_not_present(
103			header::X_XSS_PROTECTION,
104			HeaderValue::from_static("0"),
105		))
106		.layer(SetResponseHeaderLayer::if_not_present(
107			header::X_FRAME_OPTIONS,
108			HeaderValue::from_static("DENY"),
109		))
110		.layer(SetResponseHeaderLayer::if_not_present(
111			HeaderName::from_static("permissions-policy"),
112			HeaderValue::from_str(&TUWUNEL_PERMISSIONS_POLICY.join(","))?,
113		))
114		.layer(SetResponseHeaderLayer::if_not_present(
115			header::CONTENT_SECURITY_POLICY,
116			|res: &http::Response<_>| {
117				let csp = res
118					.headers()
119					.get(header::CONTENT_TYPE)
120					.map(HeaderValue::to_str)
121					.and_then(Result::ok)
122					.is_some_and(|val| val.contains("text/html"))
123					.then(|| TUWUNEL_HTML_CSP.join(";"))
124					.unwrap_or_else(|| TUWUNEL_CSP.join(";"));
125
126				HeaderValue::from_str(&csp).ok()
127			},
128		))
129		.layer(cors_layer(server))
130		.layer(body_limit_layer(server))
131		.layer(CatchPanicLayer::custom(move |panic| catch_panic(panic, services_.clone())));
132
133	let (router, guard) = router::build(services);
134	Ok((router.layer(layers), guard))
135}
136
/// Builds the response compression layer.
///
/// Every algorithm that is compiled in through its cargo feature is switched
/// on or off according to the matching boolean in the server configuration.
#[cfg(any(
	feature = "zstd_compression",
	feature = "gzip_compression",
	feature = "brotli_compression"
))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
	let layer = tower_http::compression::CompressionLayer::new();

	// Passing `false` to a setter is the same as calling its `no_*` twin.
	#[cfg(feature = "zstd_compression")]
	let layer = layer.zstd(server.config.zstd_compression);

	#[cfg(feature = "gzip_compression")]
	let layer = layer.gzip(server.config.gzip_compression);

	#[cfg(feature = "brotli_compression")]
	let layer = layer.br(server.config.brotli_compression);

	layer
}
174
175fn cors_layer(server: &Server) -> CorsLayer {
176	const METHODS: [Method; 7] = [
177		Method::DELETE,
178		Method::GET,
179		Method::HEAD,
180		Method::OPTIONS,
181		Method::PATCH,
182		Method::POST,
183		Method::PUT,
184	];
185
186	let headers: [HeaderName; 5] = [
187		header::ACCEPT,
188		header::AUTHORIZATION,
189		header::CONTENT_TYPE,
190		header::ORIGIN,
191		HeaderName::from_lowercase(b"x-requested-with")
192			.expect("valid HTTP HeaderName from lowercase."),
193	];
194
195	let allow_origin_list = server
196		.config
197		.access_control_allow_origin
198		.iter()
199		.map(AsRef::as_ref)
200		.map(HeaderValue::from_str)
201		.filter_map(Result::ok);
202
203	let allow_origin = if !server
204		.config
205		.access_control_allow_origin
206		.is_empty()
207	{
208		AllowOrigin::list(allow_origin_list)
209	} else {
210		AllowOrigin::any()
211	};
212
213	CorsLayer::new()
214		.max_age(Duration::from_hours(24))
215		.allow_methods(METHODS)
216		.allow_headers(headers)
217		.allow_origin(allow_origin)
218}
219
220fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
221	DefaultBodyLimit::max(server.config.max_request_size)
222}
223
224fn trusted_peer_subnets_layer(
225	subnets: &[IpNet],
226) -> Either<Extension<TrustedPeerSubnets>, Identity> {
227	option_layer((!subnets.is_empty()).then(|| Extension(TrustedPeerSubnets(Arc::from(subnets)))))
228}
229
230fn ip_source_layer(source: Option<IpSource>) -> Either<Extension<ConfiguredIpSource>, Identity> {
231	option_layer(source.map(|source| Extension(ConfiguredIpSource(source))))
232}
233
234#[tracing::instrument(name = "panic", level = "error", skip_all)]
235#[expect(clippy::needless_pass_by_value)]
236fn catch_panic(
237	err: Box<dyn Any + Send + 'static>,
238	services: Arc<Services>,
239) -> http::Response<http_body_util::Full<bytes::Bytes>> {
240	services
241		.server
242		.metrics
243		.requests_panic
244		.fetch_add(1, std::sync::atomic::Ordering::Release);
245
246	let details = match err.downcast_ref::<String>() {
247		| Some(s) => s.clone(),
248		| _ => match err.downcast_ref::<&str>() {
249			| Some(s) => (*s).to_owned(),
250			| _ => "Unknown internal server error occurred.".to_owned(),
251		},
252	};
253
254	error!("{details:#}");
255	let body = serde_json::json!({
256		"errcode": "M_UNKNOWN",
257		"error": "M_UNKNOWN: Internal server error occurred",
258		"details": details,
259	});
260
261	http::Response::builder()
262		.status(StatusCode::INTERNAL_SERVER_ERROR)
263		.header(header::CONTENT_TYPE, "application/json")
264		.body(http_body_util::Full::from(body.to_string()))
265		.expect("Failed to create response for our panic catcher?")
266}
267
268fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
269	let path = request
270		.extensions()
271		.get::<MatchedPath>()
272		.map_or_else(|| request_path_str(request), truncated_matched_path);
273
274	tracing::span! {
275		parent: None,
276		debug::INFO_SPAN_LEVEL,
277		"router",
278		method = %request.method(),
279		%path,
280	}
281}
282
283fn request_path_str<T>(request: &http::Request<T>) -> &str {
284	request
285		.uri()
286		.path_and_query()
287		.map(PathAndQuery::as_str)
288		.unwrap_or("/")
289}
290
291fn truncated_matched_path(path: &MatchedPath) -> &str {
292	path.as_str()
293		.rsplit_once('{')
294		.map_or(path.as_str(), |path| path.0.strip_suffix('/').unwrap_or(path.0))
295}