Skip to main content

tuwunel_service/federation/
execute.rs

1use std::{fmt::Debug, mem};
2
3use bytes::Bytes;
4use ipaddress::IPAddress;
5use reqwest::{Client, Method, Request, Response, Url};
6use ruma::{
7	ServerName,
8	api::{
9		EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SupportedVersions,
10		error::Error as RumaError,
11	},
12};
13use tuwunel_core::{
14	Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err,
15	error::inspect_debug_log, implement, trace,
16};
17
18use super::scheme::{FedAuth, FedPath};
19use crate::resolver::actual::ActualDest;
20
21/// Sends a request to a federation server
22#[implement(super::Service)]
23#[tracing::instrument(skip_all, name = "request", level = "debug")]
24pub async fn execute<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
25where
26	T: OutgoingRequest + Debug + Send,
27	T::Authentication: FedAuth,
28	T::PathBuilder: FedPath,
29{
30	let client = &self.services.client.federation;
31	self.execute_on(client, dest, request).await
32}
33
34/// Like execute() but with a very large timeout
35#[implement(super::Service)]
36#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
37pub async fn execute_synapse<T>(
38	&self,
39	dest: &ServerName,
40	request: T,
41) -> Result<T::IncomingResponse>
42where
43	T: OutgoingRequest + Debug + Send,
44	T::Authentication: FedAuth,
45	T::PathBuilder: FedPath,
46{
47	let client = &self.services.client.synapse;
48	self.execute_on(client, dest, request).await
49}
50
51#[implement(super::Service)]
52#[tracing::instrument(
53	name = "fed",
54	level = INFO_SPAN_LEVEL,
55	skip(self, client, request),
56)]
57pub async fn execute_on<T>(
58	&self,
59	client: &Client,
60	dest: &ServerName,
61	request: T,
62) -> Result<T::IncomingResponse>
63where
64	T: OutgoingRequest + Send,
65	T::Authentication: FedAuth,
66	T::PathBuilder: FedPath,
67{
68	if !self.services.server.config.allow_federation {
69		return Err!(Config("allow_federation", "Federation is disabled."));
70	}
71
72	if self
73		.services
74		.server
75		.config
76		.is_forbidden_remote_server_name(dest)
77	{
78		return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
79	}
80
81	let actual = self
82		.services
83		.resolver
84		.get_actual_dest(dest)
85		.await?;
86
87	let request = self.prepare(&actual, dest, request)?;
88	self.perform::<T>(&actual, dest, request, client)
89		.await
90}
91
92#[implement(super::Service)]
93async fn perform<T>(
94	&self,
95	actual: &ActualDest,
96	dest: &ServerName,
97	request: Request,
98	client: &Client,
99) -> Result<T::IncomingResponse>
100where
101	T: OutgoingRequest + Send,
102	T::Authentication: FedAuth,
103	T::PathBuilder: FedPath,
104{
105	let url = request.url().clone();
106	let method = request.method().clone();
107
108	debug!(?method, ?url, "Sending request");
109	match client.execute(request).await {
110		| Ok(response) => handle_response::<T>(actual, dest, &method, &url, response).await,
111		| Err(error) => Err(self
112			.handle_error(dest, actual, &method, &url, error)
113			.expect_err("always returns error")),
114	}
115}
116
117#[implement(super::Service)]
118fn prepare<T>(&self, actual: &ActualDest, dest: &ServerName, request: T) -> Result<Request>
119where
120	T: OutgoingRequest + Send,
121	T::Authentication: FedAuth,
122	T::PathBuilder: FedPath,
123{
124	let request = self.to_http_request::<T>(actual, dest, request)?;
125	let request = Request::try_from(request)?;
126	self.validate_url(request.url())?;
127	self.services.server.check_running()?;
128
129	Ok(request)
130}
131
132#[implement(super::Service)]
133fn validate_url(&self, url: &Url) -> Result {
134	if let Some(url_host) = url.host_str()
135		&& let Ok(ip) = IPAddress::parse(url_host)
136	{
137		trace!("Checking request URL IP {ip:?}");
138		self.services.resolver.validate_ip(&ip)?;
139	}
140
141	Ok(())
142}
143
144async fn handle_response<T>(
145	actual: &ActualDest,
146	dest: &ServerName,
147	method: &Method,
148	url: &Url,
149	response: Response,
150) -> Result<T::IncomingResponse>
151where
152	T: OutgoingRequest + Send,
153	T::Authentication: FedAuth,
154	T::PathBuilder: FedPath,
155{
156	let response = into_http_response(dest, actual, method, url, response).await?;
157
158	T::IncomingResponse::try_from_http_response(response)
159		.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
160}
161
162async fn into_http_response(
163	dest: &ServerName,
164	actual: &ActualDest,
165	method: &Method,
166	url: &Url,
167	mut response: Response,
168) -> Result<http::Response<Bytes>> {
169	let status = response.status();
170	trace!(
171		?status, ?method,
172		request_url = ?url,
173		response_url = ?response.url(),
174		"Received response from {}",
175		actual.to_string(),
176	);
177
178	let mut http_response_builder = http::Response::builder()
179		.status(status)
180		.version(response.version());
181
182	mem::swap(
183		response.headers_mut(),
184		http_response_builder
185			.headers_mut()
186			.expect("http::response::Builder is usable"),
187	);
188
189	// TODO: handle timeout
190	trace!("Waiting for response body...");
191	let body = response
192		.bytes()
193		.await
194		.inspect_err(inspect_debug_log)
195		.unwrap_or_else(|_| Vec::new().into());
196
197	let http_response = http_response_builder
198		.body(body)
199		.expect("reqwest body is valid http body");
200
201	debug!("Got {status:?} for {method} {url}");
202	if !status.is_success() {
203		return Err(Error::Federation(
204			dest.to_owned(),
205			RumaError::from_http_response(http_response),
206		));
207	}
208
209	Ok(http_response)
210}
211
212#[implement(super::Service)]
213fn handle_error(
214	&self,
215	dest: &ServerName,
216	actual: &ActualDest,
217	method: &Method,
218	url: &Url,
219	mut e: reqwest::Error,
220) -> Result {
221	if e.is_timeout() || e.is_connect() {
222		e = e.without_url();
223		debug_warn!("{e:?}");
224	} else if e.is_redirect() {
225		debug_error!(
226			method = ?method,
227			url = ?url,
228			final_url = ?e.url(),
229			"Redirect loop {}: {}",
230			actual.host,
231			e,
232		);
233	} else {
234		debug_error!("{e:?}");
235	}
236
237	self.services.resolver.cache.del_destination(dest);
238	self.services.resolver.cache.del_override(dest);
239
240	Err(e.into())
241}
242
243#[implement(super::Service)]
244fn to_http_request<T>(
245	&self,
246	actual: &ActualDest,
247	dest: &ServerName,
248	request: T,
249) -> Result<http::Request<Vec<u8>>>
250where
251	T: OutgoingRequest + Send,
252	T::Authentication: FedAuth,
253	T::PathBuilder: FedPath,
254{
255	const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
256	let supported = SupportedVersions {
257		versions: VERSIONS.into(),
258		features: Default::default(),
259	};
260
261	let auth = T::Authentication::input(
262		self.services.server.name.clone(),
263		dest.to_owned(),
264		self.services.server_keys.keypair(),
265	);
266	let path = T::PathBuilder::input(&supported);
267
268	request
269		.try_into_http_request::<Vec<u8>>(actual.to_string().as_str(), auth, path)
270		.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))
271}