tuwunel_api/router/
client_ip.rs1use std::{fmt, marker::Sync, net::IpAddr};
20
21use axum::extract::FromRequestParts;
22use axum_client_ip::{InsecureClientIp, SecureClientIp, SecureClientIpSource};
23use http::{StatusCode, request::Parts};
24
/// Client IP address resolved for the current request by this module's
/// [`FromRequestParts`] implementation.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientIp(pub(crate) IpAddr);
28
29#[derive(Clone, Debug)]
33pub struct ConfiguredIpSource(pub SecureClientIpSource);
34
35impl<S> FromRequestParts<S> for ClientIp
36where
37 S: Sync,
38{
39 type Rejection = (StatusCode, &'static str);
40
41 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
42 const ERROR: StatusCode = StatusCode::INTERNAL_SERVER_ERROR;
43
44 if let Some(ConfiguredIpSource(source)) = parts.extensions.get::<ConfiguredIpSource>() {
45 SecureClientIp::from(source, &parts.headers, &parts.extensions)
46 .map(|SecureClientIp(ip)| Self(ip))
47 .map_err(|_| (ERROR, "Can't extract client IP from configured ip_source"))
48 } else {
49 InsecureClientIp::from(&parts.headers, &parts.extensions)
50 .map(|InsecureClientIp(ip)| Self(ip))
51 }
52 }
53}
54
55impl fmt::Display for ClientIp {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) }
57}
58
#[cfg(test)]
mod tests {
	use std::net::SocketAddr;

	use axum::{
		extract::{ConnectInfo, FromRequestParts},
		http::{Request, StatusCode, request::Parts},
	};
	use axum_client_ip::SecureClientIpSource;

	use super::{ClientIp, ConfiguredIpSource};

	/// Build the parts of a request to `/` with the given headers appended in
	/// order. Repeating a name yields multiple values for that header.
	fn parts(headers: impl IntoIterator<Item = (&'static str, &'static str)>) -> Parts {
		let mut request = Request::builder().uri("/");
		for (name, value) in headers {
			request = request.header(name, value);
		}
		let (parts, ()) = request.body(()).unwrap().into_parts();
		parts
	}

	/// Like [`parts`], but with a [`ConfiguredIpSource`] for `source` already
	/// inserted into the request extensions.
	fn parts_with_source(
		source: SecureClientIpSource,
		headers: impl IntoIterator<Item = (&'static str, &'static str)>,
	) -> Parts {
		let mut parts = parts(headers);
		parts.extensions.insert(ConfiguredIpSource(source));
		parts
	}

	/// Run the [`ClientIp`] extractor against `parts` with a unit state.
	async fn extract_client_ip(
		parts: &mut Parts,
	) -> Result<ClientIp, (StatusCode, &'static str)> {
		ClientIp::from_request_parts(parts, &()).await
	}

	#[tokio::test]
	async fn x_forwarded_for_uses_leftmost_ip() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_takes_priority_over_x_real_ip() {
		let mut parts =
			parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2"), ("X-Real-Ip", "3.3.3.3")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn x_forwarded_for_accepts_ipv6() {
		let mut parts = parts([("X-Forwarded-For", "2001:db8::1, 2001:db8::2")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "2001:db8::1");
	}

	#[tokio::test]
	async fn x_real_ip_works() {
		let mut parts = parts([("X-Real-Ip", "1.2.3.4")]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.2.3.4");
	}

	#[tokio::test]
	async fn malformed_headers_fall_through_to_next_valid_source() {
		let mut parts = parts([
			("X-Forwarded-For", "foo"),
			("X-Real-Ip", "foo"),
			("Forwarded", "foo"),
			("Forwarded", "for=1.1.1.1;proto=https;by=2.2.2.2"),
		]);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn no_headers_or_connect_info_rejects() {
		let mut parts = parts(std::iter::empty());
		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert!(err.1.contains("ConnectInfo"), "{err:?}");
	}

	#[tokio::test]
	async fn configured_source_uses_secure_extraction() {
		let mut parts = parts_with_source(
			SecureClientIpSource::RightmostXForwardedFor,
			[("X-Forwarded-For", "1.1.1.1, 2.2.2.2")],
		);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "2.2.2.2");
	}

	#[tokio::test]
	async fn configured_source_without_matching_header_rejects() {
		let mut parts =
			parts_with_source(SecureClientIpSource::RightmostXForwardedFor, std::iter::empty());
		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(err.1, "Can't extract client IP from configured ip_source");
	}

	/// A configured source is authoritative. An unusable value is rejected
	/// instead of falling through to headers the insecure path would accept.
	#[tokio::test]
	async fn configured_source_does_not_fall_through_to_other_headers() {
		let mut parts = parts_with_source(
			SecureClientIpSource::RightmostXForwardedFor,
			[("X-Forwarded-For", "foo"), ("X-Real-Ip", "1.2.3.4")],
		);
		let err = extract_client_ip(&mut parts).await.unwrap_err();
		assert_eq!(err.0, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(err.1, "Can't extract client IP from configured ip_source");
	}

	/// With `ConnectInfo` configured, spoofable forwarding headers are ignored
	/// in favour of the real peer address.
	#[tokio::test]
	async fn configured_connect_info_ignores_forwarding_headers() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 11], 4567));
		let mut parts = parts_with_source(
			SecureClientIpSource::ConnectInfo,
			[("X-Forwarded-For", "1.1.1.1, 2.2.2.2"), ("X-Real-Ip", "3.3.3.3")],
		);
		parts.extensions.insert(ConnectInfo(socket_addr));

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip, socket_addr.ip());
	}

	#[tokio::test]
	async fn secure_client_ip_source_extension_does_not_hijack() {
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		parts
			.extensions
			.insert(SecureClientIpSource::ConnectInfo);
		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}

	#[tokio::test]
	async fn connect_info_fallback_uses_real_socket_addr_without_config() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 9], 4567));
		let mut parts = parts(std::iter::empty());
		parts.extensions.insert(ConnectInfo(socket_addr));

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip, socket_addr.ip());
	}

	#[tokio::test]
	async fn bare_secure_client_ip_source_connect_info_does_not_hijack() {
		let socket_addr = SocketAddr::from(([203, 0, 113, 10], 4567));
		let mut parts = parts([("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]);
		parts.extensions.insert(ConnectInfo(socket_addr));
		parts
			.extensions
			.insert(SecureClientIpSource::ConnectInfo);

		let ClientIp(ip) = extract_client_ip(&mut parts).await.unwrap();
		assert_eq!(ip.to_string(), "1.1.1.1");
	}
}