1use std::{
18 fmt,
19 marker::Sync,
20 net::{IpAddr, SocketAddr},
21 sync::Arc,
22};
23
24use axum::extract::{ConnectInfo, FromRequestParts};
25use http::{Extensions, HeaderMap, StatusCode, request::Parts};
26use ipnet::IpNet;
27use tuwunel_core::config::IpSource;
28
/// The client IP address resolved for a request by the [`FromRequestParts`] impl below.
///
/// It is a thin newtype over [`IpAddr`]. It derives equality and hashing so resolved addresses
/// can be compared directly or used as map/set keys, e.g. for per-client bookkeeping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct ClientIp(pub(crate) IpAddr);
32
/// The operator-configured `ip_source`, read from the request extensions.
///
/// When this extension is present and the connecting peer is not trusted (see
/// `TrustedPeerSubnets`), `ClientIp` extraction consults only this source. It does not try the
/// permissive header fallback chain, and fails if the source yields no address.
#[derive(Clone, Debug)]
pub struct ConfiguredIpSource(pub IpSource);
37
/// Peer networks allowed to bypass `ConfiguredIpSource`, read from the request extensions.
///
/// The `ConnectInfo` peer address is first canonicalized, so IPv4-mapped IPv6 becomes IPv4.
/// If it falls inside any of these networks, `ClientIp` extraction uses the permissive fallback
/// chain instead of the configured source. Loopback peers are treated the same way even when
/// this extension is absent.
#[derive(Clone, Debug)]
pub struct TrustedPeerSubnets(pub Arc<[IpNet]>);
43
44impl<S> FromRequestParts<S> for ClientIp
45where
46 S: Sync,
47{
48 type Rejection = (StatusCode, &'static str);
49
50 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
51 const ERROR: StatusCode = StatusCode::INTERNAL_SERVER_ERROR;
52
53 if let Some(&ConfiguredIpSource(source)) = parts.extensions.get::<ConfiguredIpSource>()
54 && !peer_is_trusted(&parts.extensions)
55 {
56 return secure_extract(source, &parts.headers, &parts.extensions)
57 .map(Self)
58 .ok_or((ERROR, "Can't extract client IP from configured ip_source"));
59 }
60
61 insecure_fallback(&parts.headers, &parts.extensions)
62 .map(Self)
63 .ok_or((ERROR, "Can't extract `ClientIp`, provide `axum::extract::ConnectInfo`"))
64 }
65}
66
67impl fmt::Display for ClientIp {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) }
69}
70
71fn peer_is_trusted(extensions: &Extensions) -> bool {
72 let Some(ConnectInfo(addr)) = extensions.get::<ConnectInfo<SocketAddr>>() else {
73 return false;
74 };
75
76 let peer = addr.ip().to_canonical();
77
78 peer.is_loopback()
79 || extensions
80 .get::<TrustedPeerSubnets>()
81 .is_some_and(|TrustedPeerSubnets(nets)| nets.iter().any(|net| net.contains(&peer)))
82}
83
84fn secure_extract(
85 source: IpSource,
86 headers: &HeaderMap,
87 extensions: &Extensions,
88) -> Option<IpAddr> {
89 match source {
90 | IpSource::ConnectInfo => extensions
91 .get::<ConnectInfo<SocketAddr>>()
92 .map(|ConnectInfo(addr)| addr.ip()),
93 | IpSource::RightmostXForwardedFor => rightmost_x_forwarded_for(headers),
94 | IpSource::RightmostForwarded => rightmost_forwarded(headers),
95 | IpSource::XRealIp => single_ip_header(headers, "x-real-ip"),
96 | IpSource::CfConnectingIp => single_ip_header(headers, "cf-connecting-ip"),
97 | IpSource::TrueClientIp => single_ip_header(headers, "true-client-ip"),
98 | IpSource::FlyClientIp => single_ip_header(headers, "fly-client-ip"),
99 | IpSource::CloudFrontViewerAddress => cloudfront_viewer_address(headers),
100 }
101}
102
103fn rightmost_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
104 headers
105 .get_all("x-forwarded-for")
106 .iter()
107 .filter_map(|v| v.to_str().ok())
108 .flat_map(|s| s.split(','))
109 .filter_map(|s| s.trim().parse::<IpAddr>().ok())
110 .next_back()
111}
112
113fn rightmost_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
114 headers
115 .get_all("forwarded")
116 .iter()
117 .filter_map(|v| v.to_str().ok())
118 .flat_map(|s| s.split(','))
119 .filter_map(parse_forwarded_for)
120 .next_back()
121}
122
123fn insecure_fallback(headers: &HeaderMap, extensions: &Extensions) -> Option<IpAddr> {
125 leftmost_x_forwarded_for(headers)
126 .or_else(|| leftmost_forwarded(headers))
127 .or_else(|| single_ip_header(headers, "x-real-ip"))
128 .or_else(|| single_ip_header(headers, "fly-client-ip"))
129 .or_else(|| single_ip_header(headers, "true-client-ip"))
130 .or_else(|| single_ip_header(headers, "cf-connecting-ip"))
131 .or_else(|| cloudfront_viewer_address(headers))
132 .or_else(|| {
133 extensions
134 .get::<ConnectInfo<SocketAddr>>()
135 .map(|ConnectInfo(addr)| addr.ip())
136 })
137}
138
139fn leftmost_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
140 headers
141 .get_all("x-forwarded-for")
142 .iter()
143 .filter_map(|v| v.to_str().ok())
144 .flat_map(|s| s.split(','))
145 .find_map(|s| s.trim().parse::<IpAddr>().ok())
146}
147
148fn leftmost_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
151 headers
152 .get_all("forwarded")
153 .iter()
154 .filter_map(|v| v.to_str().ok())
155 .flat_map(|s| s.split(','))
156 .find_map(parse_forwarded_for)
157}
158
/// Extracts the address from the `for` parameter of one `Forwarded` element (RFC 7239).
///
/// Accepted forms:
/// - `for=192.0.2.60` and `for="192.0.2.60:4711"`;
/// - `for="[2001:db8::1]"` and `for="[2001:db8::1]:4711"`;
/// - leniently, an unbracketed IPv6 address such as `for=2001:db8::1:2`, taken as a whole
///   address rather than as `address:port`.
///
/// Returns `None` when the element has no `for` parameter or its value is not an address,
/// for example `unknown` or an obfuscated `_identifier`.
fn parse_forwarded_for(stanza: &str) -> Option<IpAddr> {
	let for_value = stanza
		.split(';')
		.find_map(|pair| {
			let (key, value) = pair.split_once('=')?;
			key.trim()
				.eq_ignore_ascii_case("for")
				.then_some(value.trim())
		})?
		.trim_matches('"')
		.trim();

	// Try the whole value first. A bare IPv4 or IPv6 address parses directly, and stripping
	// a "port" from an unbracketed IPv6 address would silently yield a different address.
	if let Ok(ip) = for_value.parse::<IpAddr>() {
		return Some(ip);
	}

	// Otherwise the value may carry brackets and/or a trailing `:port`.
	let host = match for_value.strip_prefix('[') {
		| Some(bracketed) => bracketed.split_once(']')?.0,
		| None => for_value.rsplit_once(':')?.0,
	};

	host.trim().parse().ok()
}
181
182fn single_ip_header(headers: &HeaderMap, name: &'static str) -> Option<IpAddr> {
183 headers
184 .get(name)
185 .and_then(|v| v.to_str().ok())
186 .and_then(|s| s.trim().parse::<IpAddr>().ok())
187}
188
189fn cloudfront_viewer_address(headers: &HeaderMap) -> Option<IpAddr> {
190 headers
191 .get("cloudfront-viewer-address")
192 .and_then(|v| v.to_str().ok())
193 .and_then(|s| s.rsplit_once(':').map(|(ip, _port)| ip))
194 .and_then(|s| s.trim().parse::<IpAddr>().ok())
195}
196
#[cfg(test)]
mod tests {
	use std::{
		net::{IpAddr, SocketAddr},
		sync::Arc,
	};

	use axum::{
		extract::{ConnectInfo, FromRequestParts},
		http::{Request, StatusCode, request::Parts},
	};
	use ipnet::IpNet;
	use tuwunel_core::config::IpSource;

	use super::{ClientIp, ConfiguredIpSource, TrustedPeerSubnets};

	/// Rejection text produced when the configured `ip_source` yields no address.
	const CONFIGURED_SOURCE_ERROR: &str = "Can't extract client IP from configured ip_source";

	/// Builds the parts of a `GET /` request carrying `headers`, with no extensions.
	fn parts(headers: &[(&'static str, &'static str)]) -> Parts {
		let builder = headers
			.iter()
			.fold(Request::builder().uri("/"), |builder, &(name, value)| {
				builder.header(name, value)
			});

		builder.body(()).unwrap().into_parts().0
	}

	/// Records `peer` as the connecting socket address.
	fn connect_from(parts: &mut Parts, peer: SocketAddr) {
		parts.extensions.insert(ConnectInfo(peer));
	}

	/// Installs `source` as the operator-configured `ip_source`.
	fn configure(parts: &mut Parts, source: IpSource) {
		parts.extensions.insert(ConfiguredIpSource(source));
	}

	/// Installs the given CIDR strings as the trusted peer subnets.
	fn trust(parts: &mut Parts, cidrs: &[&str]) {
		let nets = cidrs
			.iter()
			.map(|cidr| cidr.parse::<IpNet>().expect("test CIDR"))
			.collect::<Arc<[IpNet]>>();

		parts.extensions.insert(TrustedPeerSubnets(nets));
	}

	/// Runs the extractor against `parts`.
	async fn extract_client_ip(
		parts: &mut Parts,
	) -> Result<ClientIp, (StatusCode, &'static str)> {
		ClientIp::from_request_parts(parts, &()).await
	}

	/// Runs the extractor and unwraps the resolved address.
	async fn client_ip(parts: &mut Parts) -> IpAddr {
		let ClientIp(ip) = extract_client_ip(parts).await.unwrap();
		ip
	}

	/// Asserts extraction fails with the configured-source rejection.
	async fn assert_configured_source_rejects(parts: &mut Parts) {
		let (status, message) = extract_client_ip(parts).await.unwrap_err();
		assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(message, CONFIGURED_SOURCE_ERROR);
	}

	#[tokio::test]
	async fn x_forwarded_for_uses_leftmost_ip() {
		let mut parts = parts(&[("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_takes_priority_over_x_real_ip() {
		let mut parts =
			parts(&[("X-Forwarded-For", "1.1.1.1, 2.2.2.2"), ("X-Real-Ip", "3.3.3.3")]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_accepts_ipv6() {
		let mut parts = parts(&[("X-Forwarded-For", "2001:db8::1, 2001:db8::2")]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "2001:db8::1");
	}

	#[tokio::test]
	async fn x_real_ip_works() {
		let mut parts = parts(&[("X-Real-Ip", "1.2.3.4")]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "1.2.3.4");
	}

	#[tokio::test]
	async fn malformed_headers_fall_through_to_next_valid_source() {
		let mut parts = parts(&[
			("X-Forwarded-For", "foo"),
			("X-Real-Ip", "foo"),
			("Forwarded", "foo"),
			("Forwarded", "for=1.1.1.1;proto=https;by=2.2.2.2"),
		]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn no_headers_or_connect_info_rejects() {
		let mut parts = parts(&[]);
		let (status, message) = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
		assert!(message.contains("ConnectInfo"), "{message:?}");
	}

	#[tokio::test]
	async fn configured_source_uses_secure_extraction() {
		let mut parts = parts(&[("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_eq!(client_ip(&mut parts).await.to_string(), "2.2.2.2");
	}

	#[tokio::test]
	async fn configured_source_without_matching_header_rejects() {
		let mut parts = parts(&[]);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_configured_source_rejects(&mut parts).await;
	}

	#[tokio::test]
	async fn connect_info_fallback_uses_real_socket_addr_without_config() {
		let peer = SocketAddr::from(([203, 0, 113, 9], 4567));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}

	#[tokio::test]
	async fn loopback_peer_bypasses_configured_source_for_locally_connected_bridges() {
		let peer = SocketAddr::from(([127, 0, 0, 1], 38000));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}

	#[tokio::test]
	async fn loopback_peer_with_proxy_header_still_uses_insecure_fallback() {
		let peer = SocketAddr::from(([127, 0, 0, 1], 38000));
		let mut parts = parts(&[("X-Forwarded-For", "9.9.9.9")]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_eq!(client_ip(&mut parts).await.to_string(), "9.9.9.9");
	}

	#[tokio::test]
	async fn ipv6_loopback_peer_also_bypasses_configured_source() {
		let peer = SocketAddr::from(([0_u16, 0, 0, 0, 0, 0, 0, 1], 38000));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}

	#[tokio::test]
	async fn non_loopback_peer_with_configured_source_still_rejects() {
		let mut parts = parts(&[]);
		connect_from(&mut parts, SocketAddr::from(([203, 0, 113, 9], 38000)));
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_configured_source_rejects(&mut parts).await;
	}

	#[tokio::test]
	async fn trusted_subnet_peer_bypasses_configured_source() {
		let peer = SocketAddr::from(([172, 18, 0, 5], 38000));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		trust(&mut parts, &["172.18.0.0/16"]);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}

	#[tokio::test]
	async fn trusted_subnet_peer_with_proxy_header_uses_insecure_fallback() {
		let mut parts = parts(&[("X-Forwarded-For", "9.9.9.9")]);
		connect_from(&mut parts, SocketAddr::from(([172, 18, 0, 5], 38000)));
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		trust(&mut parts, &["172.18.0.0/16"]);
		assert_eq!(client_ip(&mut parts).await.to_string(), "9.9.9.9");
	}

	#[tokio::test]
	async fn non_trusted_peer_with_subnets_configured_still_rejects() {
		let mut parts = parts(&[]);
		connect_from(&mut parts, SocketAddr::from(([203, 0, 113, 9], 38000)));
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		trust(&mut parts, &["172.18.0.0/16"]);
		assert_configured_source_rejects(&mut parts).await;
	}

	#[tokio::test]
	async fn ipv6_trusted_subnet_peer_bypasses_configured_source() {
		let peer = SocketAddr::from(([0xFD00_u16, 0, 0, 0, 0, 0, 0, 1], 38000));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		trust(&mut parts, &["fd00::/8"]);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}

	#[tokio::test]
	async fn trusted_single_host_cidr_matches_only_that_address() {
		let mut listed = parts(&[]);
		connect_from(&mut listed, SocketAddr::from(([10, 0, 0, 5], 38000)));
		configure(&mut listed, IpSource::RightmostXForwardedFor);
		trust(&mut listed, &["10.0.0.5/32"]);
		assert_eq!(client_ip(&mut listed).await.to_string(), "10.0.0.5");

		let mut neighbour = parts(&[]);
		connect_from(&mut neighbour, SocketAddr::from(([10, 0, 0, 6], 38000)));
		configure(&mut neighbour, IpSource::RightmostXForwardedFor);
		trust(&mut neighbour, &["10.0.0.5/32"]);
		let (status, _) = extract_client_ip(&mut neighbour)
			.await
			.unwrap_err();
		assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
	}

	#[tokio::test]
	async fn loopback_still_bypasses_when_trusted_subnets_extension_absent() {
		let peer = SocketAddr::from(([127, 0, 0, 1], 38000));
		let mut parts = parts(&[]);
		connect_from(&mut parts, peer);
		configure(&mut parts, IpSource::RightmostXForwardedFor);
		assert_eq!(client_ip(&mut parts).await, peer.ip());
	}
}
477}