Skip to main content

tuwunel_api/router/
client_ip.rs

1//! Tuwunel's client-IP extractor.
2//!
3//! Wraps `axum_client_ip` with a two-mode fallback:
4//!
5//! * If the operator configured `ip_source`, a [`ConfiguredIpSource`] marker is
6//!   installed in request extensions and we delegate to
7//!   [`axum_client_ip::SecureClientIp`] with that source.
8//! * Otherwise we fall back to [`axum_client_ip::InsecureClientIp`], preserving
9//!   existing behavior exactly -- including the header scan chain and the
10//!   socket-address fallback that matters for Unix-socket deployments (see
11//!   matrix-construct/tuwunel#310).
12//!
13//! The plain `SecureClientIpSource::ConnectInfo` extension already
14//! installed by `src/router/layers.rs` is intentionally ignored here;
15//! only the [`ConfiguredIpSource`] marker participates in the secure
16//! path. This avoids flipping behavior for deployments that never opted
17//! in.
18
19use std::{fmt, marker::Sync, net::IpAddr};
20
21use axum::extract::FromRequestParts;
22use axum_client_ip::{InsecureClientIp, SecureClientIp, SecureClientIpSource};
23use http::{StatusCode, request::Parts};
24
/// Client IP address resolved for the current request.
///
/// Produced by this module's `FromRequestParts` implementation; see the
/// module docs for how the source of the address is chosen. Equality and
/// hashing are derived so addresses can be compared, deduplicated, or used
/// as map keys without unwrapping the inner [`IpAddr`] first.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct ClientIp(pub(crate) IpAddr);
28
29/// Marker wrapper around [`SecureClientIpSource`] placed into request
30/// extensions only when an operator has explicitly configured
31/// `ip_source`.
32#[derive(Clone, Debug)]
33pub struct ConfiguredIpSource(pub SecureClientIpSource);
34
35impl<S> FromRequestParts<S> for ClientIp
36where
37	S: Sync,
38{
39	type Rejection = (StatusCode, &'static str);
40
41	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
42		const ERROR: StatusCode = StatusCode::INTERNAL_SERVER_ERROR;
43
44		if let Some(ConfiguredIpSource(source)) = parts.extensions.get::<ConfiguredIpSource>() {
45			SecureClientIp::from(source, &parts.headers, &parts.extensions)
46				.map(|SecureClientIp(ip)| Self(ip))
47				.map_err(|_| (ERROR, "Can't extract client IP from configured ip_source"))
48		} else {
49			InsecureClientIp::from(&parts.headers, &parts.extensions)
50				.map(|InsecureClientIp(ip)| Self(ip))
51		}
52	}
53}
54
55impl fmt::Display for ClientIp {
56	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) }
57}
58
#[cfg(test)]
mod tests {
	use std::net::{IpAddr, SocketAddr};

	use axum::{
		extract::{ConnectInfo, FromRequestParts},
		http::{Request, StatusCode, request::Parts},
	};
	use axum_client_ip::SecureClientIpSource;

	use super::{ClientIp, ConfiguredIpSource};

	// Build `Parts` for a request to `/` carrying the given headers, added in
	// order via the request builder (repeated names become repeated values).
	fn parts(headers: impl IntoIterator<Item = (&'static str, &'static str)>) -> Parts {
		let mut request = Request::builder().uri("/");
		for (name, value) in headers {
			request = request.header(name, value);
		}
		let (parts, ()) = request.body(()).unwrap().into_parts();
		parts
	}

	// Run the extractor with a unit state, as the router would.
	async fn extract_client_ip(
		parts: &mut Parts,
	) -> Result<ClientIp, (StatusCode, &'static str)> {
		ClientIp::from_request_parts(parts, &()).await
	}

	#[tokio::test]
	async fn x_forwarded_for_uses_leftmost_ip() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_takes_priority_over_x_real_ip() {
		let mut parts =
			parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2"), ("X-Real-Ip", "3.3.3.3")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_accepts_ipv6() {
		let mut parts = parts([("X-Forwarded-For", "2001:db8::1, 2001:db8::2")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "2001:db8::1");
	}

	#[tokio::test]
	async fn x_real_ip_works() {
		let mut parts = parts([("X-Real-Ip", "1.2.3.4")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.2.3.4");
	}

	#[tokio::test]
	async fn malformed_headers_fall_through_to_next_valid_source() {
		let mut parts = parts([
			("X-Forwarded-For", "foo"),
			("X-Real-Ip", "foo"),
			("Forwarded", "foo"),
			("Forwarded", "for=1.1.1.1;proto=https;by=2.2.2.2"),
		]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn no_headers_or_connect_info_rejects() {
		let mut parts = parts(std::iter::empty());
		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert!(err.1.contains("ConnectInfo"), "{err:?}");
	}

	#[tokio::test]
	async fn configured_source_uses_secure_extraction() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		parts
			.extensions
			.insert(ConfiguredIpSource(SecureClientIpSource::RightmostXForwardedFor));
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "2.2.2.2");
	}

	#[tokio::test]
	async fn configured_source_without_matching_header_rejects() {
		let mut parts = parts(std::iter::empty());
		parts
			.extensions
			.insert(ConfiguredIpSource(SecureClientIpSource::RightmostXForwardedFor));
		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(err.1, "Can't extract client IP from configured ip_source");
	}

	// A configured `ConnectInfo` source must take the socket address and
	// ignore client-controlled forwarding headers entirely.
	#[tokio::test]
	async fn configured_connect_info_ignores_forwarding_headers() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 11], 4567));
		let mut parts =
			parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2"), ("X-Real-Ip", "3.3.3.3")]);
		parts.extensions.insert(ConnectInfo(socket_addr));
		parts
			.extensions
			.insert(ConfiguredIpSource(SecureClientIpSource::ConnectInfo));

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip, socket_addr.ip());
	}

	// With a configured source, a missing `ConnectInfo` must reject rather
	// than silently falling back to the permissive header scan.
	#[tokio::test]
	async fn configured_connect_info_without_socket_addr_rejects() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1")]);
		parts
			.extensions
			.insert(ConfiguredIpSource(SecureClientIpSource::ConnectInfo));

		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(err.1, "Can't extract client IP from configured ip_source");
	}

	#[tokio::test]
	async fn secure_client_ip_source_extension_does_not_hijack() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		parts
			.extensions
			.insert(SecureClientIpSource::ConnectInfo);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn connect_info_fallback_uses_real_socket_addr_without_config() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 9], 4567));
		let mut parts = parts(std::iter::empty());
		parts.extensions.insert(ConnectInfo(socket_addr));

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip, socket_addr.ip());
	}

	#[tokio::test]
	async fn bare_secure_client_ip_source_connect_info_does_not_hijack() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 10], 4567));
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		parts.extensions.insert(ConnectInfo(socket_addr));
		parts
			.extensions
			.insert(SecureClientIpSource::ConnectInfo);

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[test]
	fn display_matches_inner_ip() {
		let ip: IpAddr = "2001:db8::1".parse().unwrap();
		assert_eq!(ClientIp(ip).to_string(), ip.to_string());
	}
}