1use std::{fmt::Debug, mem, time::Duration};
2
3use bytes::Bytes;
4use ipaddress::IPAddress;
5use reqwest::{Client, Method, Request, Response, Url};
6use ruma::{
7 ServerName,
8 api::{
9 EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SupportedVersions,
10 error::Error as RumaError,
11 },
12};
13use tokio::time::timeout;
14use tuwunel_core::{
15 Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement,
16 trace,
17};
18
19use super::{
20 ShouldAttempt,
21 peer::classify_error,
22 scheme::{FedAuth, FedPath},
23};
24use crate::{client::read_response_capped, resolver::actual::ActualDest};
25
26#[implement(super::Service)]
28#[tracing::instrument(skip_all, name = "request", level = "debug")]
29pub async fn execute<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
30where
31 T: OutgoingRequest + Debug + Send,
32 T::Authentication: FedAuth,
33 T::PathBuilder: FedPath,
34{
35 let client = &self.services.client.federation;
36 self.execute_on(client, dest, request).await
37}
38
39#[implement(super::Service)]
45#[tracing::instrument(skip_all, name = "keys", level = "debug")]
46pub async fn execute_keys<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
47where
48 T: OutgoingRequest + Debug + Send,
49 T::Authentication: FedAuth,
50 T::PathBuilder: FedPath,
51{
52 if matches!(self.should_attempt(dest).await, ShouldAttempt::No { .. }) {
53 return Err!("{dest} is in federation backoff; skipping key lookup");
54 }
55
56 let timeout_dur = Duration::from_secs(
57 self.services
58 .server
59 .config
60 .federation_keys_timeout,
61 );
62
63 let client = &self.services.client.federation;
64
65 match timeout(timeout_dur, self.execute_uncounted(client, dest, request)).await {
66 | Ok(result) => result,
67 | Err(_elapsed) => Err!("{dest} key lookup exceeded {}s", timeout_dur.as_secs()),
68 }
69}
70
71#[implement(super::Service)]
73#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
74pub async fn execute_synapse<T>(
75 &self,
76 dest: &ServerName,
77 request: T,
78) -> Result<T::IncomingResponse>
79where
80 T: OutgoingRequest + Debug + Send,
81 T::Authentication: FedAuth,
82 T::PathBuilder: FedPath,
83{
84 let client = &self.services.client.synapse;
85 self.execute_on(client, dest, request).await
86}
87
88#[implement(super::Service)]
89pub async fn execute_on<T>(
90 &self,
91 client: &Client,
92 dest: &ServerName,
93 request: T,
94) -> Result<T::IncomingResponse>
95where
96 T: OutgoingRequest + Send,
97 T::Authentication: FedAuth,
98 T::PathBuilder: FedPath,
99{
100 let result = self
101 .execute_uncounted(client, dest, request)
102 .await;
103
104 match &result {
105 | Ok(_) => self.record_success(dest),
106 | Err(error) =>
107 if let Some(class) = classify_error(error) {
108 self.record_failure(dest, class);
109 },
110 }
111
112 result
113}
114
115#[implement(super::Service)]
118#[tracing::instrument(
119 name = "fed",
120 level = INFO_SPAN_LEVEL,
121 skip(self, client, request),
122)]
123async fn execute_uncounted<T>(
124 &self,
125 client: &Client,
126 dest: &ServerName,
127 request: T,
128) -> Result<T::IncomingResponse>
129where
130 T: OutgoingRequest + Send,
131 T::Authentication: FedAuth,
132 T::PathBuilder: FedPath,
133{
134 if !self.services.server.config.allow_federation {
135 return Err!(Config("allow_federation", "Federation is disabled."));
136 }
137
138 if self
139 .services
140 .server
141 .config
142 .is_forbidden_remote_server_name(dest)
143 {
144 return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
145 }
146
147 let actual = self
148 .services
149 .resolver
150 .get_actual_dest(dest)
151 .await?;
152
153 let request = self.prepare(&actual, dest, request)?;
154
155 self.perform::<T>(&actual, dest, request, client)
156 .await
157}
158
159#[implement(super::Service)]
160async fn perform<T>(
161 &self,
162 actual: &ActualDest,
163 dest: &ServerName,
164 request: Request,
165 client: &Client,
166) -> Result<T::IncomingResponse>
167where
168 T: OutgoingRequest + Send,
169 T::Authentication: FedAuth,
170 T::PathBuilder: FedPath,
171{
172 let url = request.url().clone();
173 let method = request.method().clone();
174
175 debug!(?method, ?url, "Sending request");
176 let limit = self.services.server.config.max_response_size;
177
178 match client.execute(request).await {
179 | Ok(response) =>
180 handle_response::<T>(actual, dest, &method, &url, response, limit).await,
181 | Err(error) => Err(self
182 .handle_error(dest, actual, &method, &url, error)
183 .expect_err("always returns error")),
184 }
185}
186
187#[implement(super::Service)]
188fn prepare<T>(&self, actual: &ActualDest, dest: &ServerName, request: T) -> Result<Request>
189where
190 T: OutgoingRequest + Send,
191 T::Authentication: FedAuth,
192 T::PathBuilder: FedPath,
193{
194 let request = self.to_http_request::<T>(actual, dest, request)?;
195 let request = Request::try_from(request)?;
196 self.validate_url(request.url())?;
197 self.services.server.check_running()?;
198
199 Ok(request)
200}
201
202#[implement(super::Service)]
203fn validate_url(&self, url: &Url) -> Result {
204 if let Some(url_host) = url.host_str()
205 && let Ok(ip) = IPAddress::parse(url_host)
206 {
207 trace!("Checking request URL IP {ip:?}");
208 self.services.resolver.validate_ip(&ip)?;
209 }
210
211 Ok(())
212}
213
214async fn handle_response<T>(
215 actual: &ActualDest,
216 dest: &ServerName,
217 method: &Method,
218 url: &Url,
219 response: Response,
220 limit: usize,
221) -> Result<T::IncomingResponse>
222where
223 T: OutgoingRequest + Send,
224 T::Authentication: FedAuth,
225 T::PathBuilder: FedPath,
226{
227 let response = into_http_response(dest, actual, method, url, response, limit).await?;
228
229 T::IncomingResponse::try_from_http_response(response)
230 .map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
231}
232
233async fn into_http_response(
234 dest: &ServerName,
235 actual: &ActualDest,
236 method: &Method,
237 url: &Url,
238 mut response: Response,
239 limit: usize,
240) -> Result<http::Response<Bytes>> {
241 let status = response.status();
242 trace!(
243 ?status, ?method,
244 request_url = ?url,
245 response_url = ?response.url(),
246 "Received response from {}",
247 actual.to_string(),
248 );
249
250 let mut http_response_builder = http::Response::builder()
251 .status(status)
252 .version(response.version());
253
254 mem::swap(
255 response.headers_mut(),
256 http_response_builder
257 .headers_mut()
258 .expect("http::response::Builder is usable"),
259 );
260
261 trace!("Waiting for response body...");
263 let body = read_response_capped(response, limit).await?;
264
265 let http_response = http_response_builder
266 .body(body)
267 .expect("reqwest body is valid http body");
268
269 debug!("Got {status:?} for {method} {url}");
270 if !status.is_success() {
271 return Err(Error::Federation(
272 dest.to_owned(),
273 RumaError::from_http_response(http_response),
274 ));
275 }
276
277 Ok(http_response)
278}
279
280#[implement(super::Service)]
281fn handle_error(
282 &self,
283 dest: &ServerName,
284 actual: &ActualDest,
285 method: &Method,
286 url: &Url,
287 mut e: reqwest::Error,
288) -> Result {
289 if e.is_timeout() || e.is_connect() {
290 e = e.without_url();
291 debug_warn!("{e:?}");
292 } else if e.is_redirect() {
293 debug_error!(
294 method = ?method,
295 url = ?url,
296 final_url = ?e.url(),
297 "Redirect loop {}: {}",
298 actual.host,
299 e,
300 );
301 } else {
302 debug_error!("{e:?}");
303 }
304
305 self.services.resolver.cache.del_destination(dest);
306 self.services.resolver.cache.del_override(dest);
307
308 Err(e.into())
309}
310
311#[implement(super::Service)]
312fn to_http_request<T>(
313 &self,
314 actual: &ActualDest,
315 dest: &ServerName,
316 request: T,
317) -> Result<http::Request<Vec<u8>>>
318where
319 T: OutgoingRequest + Send,
320 T::Authentication: FedAuth,
321 T::PathBuilder: FedPath,
322{
323 const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
324 let supported = SupportedVersions {
325 versions: VERSIONS.into(),
326 features: Default::default(),
327 };
328
329 let auth = T::Authentication::input(
330 self.services.server.name.clone(),
331 dest.to_owned(),
332 self.services.server_keys.keypair(),
333 );
334 let path = T::PathBuilder::input(&supported);
335
336 request
337 .try_into_http_request::<Vec<u8>>(actual.to_string().as_str(), auth, path)
338 .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))
339}