Skip to main content

tuwunel_service/federation/
execute.rs

1use std::{fmt::Debug, mem, time::Duration};
2
3use bytes::Bytes;
4use ipaddress::IPAddress;
5use reqwest::{Client, Method, Request, Response, Url};
6use ruma::{
7	ServerName,
8	api::{
9		EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SupportedVersions,
10		error::Error as RumaError,
11	},
12};
13use tokio::time::timeout;
14use tuwunel_core::{
15	Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement,
16	trace,
17};
18
19use super::{
20	ShouldAttempt,
21	peer::classify_error,
22	scheme::{FedAuth, FedPath},
23};
24use crate::{client::read_response_capped, resolver::actual::ActualDest};
25
26/// Sends a request to a federation server
27#[implement(super::Service)]
28#[tracing::instrument(skip_all, name = "request", level = "debug")]
29pub async fn execute<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
30where
31	T: OutgoingRequest + Debug + Send,
32	T::Authentication: FedAuth,
33	T::PathBuilder: FedPath,
34{
35	let client = &self.services.client.federation;
36	self.execute_on(client, dest, request).await
37}
38
39/// Client-initiated key lookup (`/keys/query`, `/keys/claim`) over federation:
40/// skips servers already in backoff and bounds the request by
41/// `federation_keys_timeout` so a waiting client is not held past its own send
42/// deadline. Honors peer-status but does not record into it; a slow key lookup
43/// must not suppress unrelated outbound traffic to the server.
44#[implement(super::Service)]
45#[tracing::instrument(skip_all, name = "keys", level = "debug")]
46pub async fn execute_keys<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
47where
48	T: OutgoingRequest + Debug + Send,
49	T::Authentication: FedAuth,
50	T::PathBuilder: FedPath,
51{
52	if matches!(self.should_attempt(dest).await, ShouldAttempt::No { .. }) {
53		return Err!("{dest} is in federation backoff; skipping key lookup");
54	}
55
56	let timeout_dur = Duration::from_secs(
57		self.services
58			.server
59			.config
60			.federation_keys_timeout,
61	);
62
63	let client = &self.services.client.federation;
64
65	match timeout(timeout_dur, self.execute_uncounted(client, dest, request)).await {
66		| Ok(result) => result,
67		| Err(_elapsed) => Err!("{dest} key lookup exceeded {}s", timeout_dur.as_secs()),
68	}
69}
70
71/// Like execute() but with a very large timeout
72#[implement(super::Service)]
73#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
74pub async fn execute_synapse<T>(
75	&self,
76	dest: &ServerName,
77	request: T,
78) -> Result<T::IncomingResponse>
79where
80	T: OutgoingRequest + Debug + Send,
81	T::Authentication: FedAuth,
82	T::PathBuilder: FedPath,
83{
84	let client = &self.services.client.synapse;
85	self.execute_on(client, dest, request).await
86}
87
88#[implement(super::Service)]
89pub async fn execute_on<T>(
90	&self,
91	client: &Client,
92	dest: &ServerName,
93	request: T,
94) -> Result<T::IncomingResponse>
95where
96	T: OutgoingRequest + Send,
97	T::Authentication: FedAuth,
98	T::PathBuilder: FedPath,
99{
100	let result = self
101		.execute_uncounted(client, dest, request)
102		.await;
103
104	match &result {
105		| Ok(_) => self.record_success(dest),
106		| Err(error) =>
107			if let Some(class) = classify_error(error) {
108				self.record_failure(dest, class);
109			},
110	}
111
112	result
113}
114
115/// Like [`execute_on`] but leaves peer-status untouched, for callers that
116/// must honor backoff without contributing to it.
117#[implement(super::Service)]
118#[tracing::instrument(
119	name = "fed",
120	level = INFO_SPAN_LEVEL,
121	skip(self, client, request),
122)]
123async fn execute_uncounted<T>(
124	&self,
125	client: &Client,
126	dest: &ServerName,
127	request: T,
128) -> Result<T::IncomingResponse>
129where
130	T: OutgoingRequest + Send,
131	T::Authentication: FedAuth,
132	T::PathBuilder: FedPath,
133{
134	if !self.services.server.config.allow_federation {
135		return Err!(Config("allow_federation", "Federation is disabled."));
136	}
137
138	if self
139		.services
140		.server
141		.config
142		.is_forbidden_remote_server_name(dest)
143	{
144		return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
145	}
146
147	let actual = self
148		.services
149		.resolver
150		.get_actual_dest(dest)
151		.await?;
152
153	let request = self.prepare(&actual, dest, request)?;
154
155	self.perform::<T>(&actual, dest, request, client)
156		.await
157}
158
159#[implement(super::Service)]
160async fn perform<T>(
161	&self,
162	actual: &ActualDest,
163	dest: &ServerName,
164	request: Request,
165	client: &Client,
166) -> Result<T::IncomingResponse>
167where
168	T: OutgoingRequest + Send,
169	T::Authentication: FedAuth,
170	T::PathBuilder: FedPath,
171{
172	let url = request.url().clone();
173	let method = request.method().clone();
174
175	debug!(?method, ?url, "Sending request");
176	let limit = self.services.server.config.max_response_size;
177
178	match client.execute(request).await {
179		| Ok(response) =>
180			handle_response::<T>(actual, dest, &method, &url, response, limit).await,
181		| Err(error) => Err(self
182			.handle_error(dest, actual, &method, &url, error)
183			.expect_err("always returns error")),
184	}
185}
186
187#[implement(super::Service)]
188fn prepare<T>(&self, actual: &ActualDest, dest: &ServerName, request: T) -> Result<Request>
189where
190	T: OutgoingRequest + Send,
191	T::Authentication: FedAuth,
192	T::PathBuilder: FedPath,
193{
194	let request = self.to_http_request::<T>(actual, dest, request)?;
195	let request = Request::try_from(request)?;
196	self.validate_url(request.url())?;
197	self.services.server.check_running()?;
198
199	Ok(request)
200}
201
202#[implement(super::Service)]
203fn validate_url(&self, url: &Url) -> Result {
204	if let Some(url_host) = url.host_str()
205		&& let Ok(ip) = IPAddress::parse(url_host)
206	{
207		trace!("Checking request URL IP {ip:?}");
208		self.services.resolver.validate_ip(&ip)?;
209	}
210
211	Ok(())
212}
213
214async fn handle_response<T>(
215	actual: &ActualDest,
216	dest: &ServerName,
217	method: &Method,
218	url: &Url,
219	response: Response,
220	limit: usize,
221) -> Result<T::IncomingResponse>
222where
223	T: OutgoingRequest + Send,
224	T::Authentication: FedAuth,
225	T::PathBuilder: FedPath,
226{
227	let response = into_http_response(dest, actual, method, url, response, limit).await?;
228
229	T::IncomingResponse::try_from_http_response(response)
230		.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
231}
232
233async fn into_http_response(
234	dest: &ServerName,
235	actual: &ActualDest,
236	method: &Method,
237	url: &Url,
238	mut response: Response,
239	limit: usize,
240) -> Result<http::Response<Bytes>> {
241	let status = response.status();
242	trace!(
243		?status, ?method,
244		request_url = ?url,
245		response_url = ?response.url(),
246		"Received response from {}",
247		actual.to_string(),
248	);
249
250	let mut http_response_builder = http::Response::builder()
251		.status(status)
252		.version(response.version());
253
254	mem::swap(
255		response.headers_mut(),
256		http_response_builder
257			.headers_mut()
258			.expect("http::response::Builder is usable"),
259	);
260
261	// TODO: handle timeout
262	trace!("Waiting for response body...");
263	let body = read_response_capped(response, limit).await?;
264
265	let http_response = http_response_builder
266		.body(body)
267		.expect("reqwest body is valid http body");
268
269	debug!("Got {status:?} for {method} {url}");
270	if !status.is_success() {
271		return Err(Error::Federation(
272			dest.to_owned(),
273			RumaError::from_http_response(http_response),
274		));
275	}
276
277	Ok(http_response)
278}
279
280#[implement(super::Service)]
281fn handle_error(
282	&self,
283	dest: &ServerName,
284	actual: &ActualDest,
285	method: &Method,
286	url: &Url,
287	mut e: reqwest::Error,
288) -> Result {
289	if e.is_timeout() || e.is_connect() {
290		e = e.without_url();
291		debug_warn!("{e:?}");
292	} else if e.is_redirect() {
293		debug_error!(
294			method = ?method,
295			url = ?url,
296			final_url = ?e.url(),
297			"Redirect loop {}: {}",
298			actual.host,
299			e,
300		);
301	} else {
302		debug_error!("{e:?}");
303	}
304
305	self.services.resolver.cache.del_destination(dest);
306	self.services.resolver.cache.del_override(dest);
307
308	Err(e.into())
309}
310
311#[implement(super::Service)]
312fn to_http_request<T>(
313	&self,
314	actual: &ActualDest,
315	dest: &ServerName,
316	request: T,
317) -> Result<http::Request<Vec<u8>>>
318where
319	T: OutgoingRequest + Send,
320	T::Authentication: FedAuth,
321	T::PathBuilder: FedPath,
322{
323	const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
324	let supported = SupportedVersions {
325		versions: VERSIONS.into(),
326		features: Default::default(),
327	};
328
329	let auth = T::Authentication::input(
330		self.services.server.name.clone(),
331		dest.to_owned(),
332		self.services.server_keys.keypair(),
333	);
334	let path = T::PathBuilder::input(&supported);
335
336	request
337		.try_into_http_request::<Vec<u8>>(actual.to_string().as_str(), auth, path)
338		.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))
339}