Skip to main content

tuwunel_service/resolver/
actual.rs

1use std::{
2	fmt::Debug,
3	net::{IpAddr, SocketAddr},
4};
5
6use futures::{FutureExt, TryFutureExt};
7use hickory_resolver::{
8	net::{DnsError, NetError},
9	proto::rr::RData,
10};
11use ipaddress::IPAddress;
12use ruma::ServerName;
13use tuwunel_core::{Err, Result, debug, debug_info, debug_warn, err, error, trace};
14
15use super::{
16	DestString, FedDest,
17	cache::{CachedDest, CachedOverride, MAX_IPS},
18	fed::{PortString, add_port_to_hostname, get_ip_with_port},
19};
20
21#[derive(Clone, Debug)]
22pub(crate) struct ActualDest {
23	pub(crate) dest: FedDest,
24	pub(crate) host: DestString,
25}
26
27impl ActualDest {
28	#[inline]
29	pub(crate) fn to_string(&self) -> DestString { self.dest.https_string() }
30}
31
32impl super::Service {
33	#[tracing::instrument(skip_all, level = "debug", name = "resolve")]
34	pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> {
35		let (CachedDest { dest, host, .. }, _cached) =
36			self.lookup_actual_dest(server_name).await?;
37
38		Ok(ActualDest { dest, host })
39	}
40
41	pub(crate) async fn lookup_actual_dest(
42		&self,
43		server_name: &ServerName,
44	) -> Result<(CachedDest, bool)> {
45		if let Ok(result) = self.cache.get_destination(server_name).await {
46			return Ok((result, true));
47		}
48
49		let _dedup = self.resolving.lock(server_name);
50		if let Ok(result) = self.cache.get_destination(server_name).await {
51			return Ok((result, true));
52		}
53
54		self.resolve_actual_dest(server_name, true)
55			.inspect_ok(|result| self.cache.set_destination(server_name, result))
56			.map_ok(|result| (result, false))
57			.boxed()
58			.await
59	}
60
61	/// Returns: `actual_destination`, host header
62	/// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names>
63	/// Numbers in comments below refer to bullet points in linked section of
64	/// specification
65	#[tracing::instrument(name = "actual", level = "debug", skip(self, cache))]
66	pub async fn resolve_actual_dest(
67		&self,
68		dest: &ServerName,
69		cache: bool,
70	) -> Result<CachedDest> {
71		self.validate_dest(dest)?;
72		let mut host: DestString = dest.as_str().into();
73		let actual_dest = match get_ip_with_port(dest.as_str()) {
74			| Some(host_port) => Self::actual_dest_1(host_port)?,
75			| None =>
76				if let Some(pos) = dest.as_str().find(':') {
77					self.actual_dest_2(dest, cache, pos).await?
78				} else {
79					self.conditional_query_and_cache(dest.as_str(), 8448, true)
80						.await?;
81					self.services.server.check_running()?;
82					match self.request_well_known(dest.as_str()).await? {
83						| Some(delegated) =>
84							self.actual_dest_3(&mut host, cache, &delegated)
85								.await?,
86						| _ => match self.query_srv_record(dest.as_str()).await? {
87							| Some(overrider) =>
88								self.actual_dest_4(&host, cache, overrider)
89									.await?,
90							| _ => self.actual_dest_5(dest, cache).await?,
91						},
92					}
93				},
94		};
95
96		// Can't use get_ip_with_port here because we don't want to add a port
97		// to an IP address if it wasn't specified
98		let host = if let Ok(addr) = host.parse::<SocketAddr>() {
99			FedDest::Literal(addr)
100		} else if let Ok(addr) = host.parse::<IpAddr>() {
101			FedDest::Named(addr.to_string().into(), FedDest::default_port())
102		} else if let Some(pos) = host.find(':') {
103			let (host, port) = host.split_at(pos);
104			FedDest::Named(
105				host.into(),
106				port.try_into()
107					.unwrap_or_else(|_| FedDest::default_port()),
108			)
109		} else {
110			FedDest::Named(host.as_str().into(), FedDest::default_port())
111		};
112
113		debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
114		Ok(CachedDest {
115			dest: actual_dest,
116			host: host.uri_string(),
117			expire: CachedDest::default_expire(),
118		})
119	}
120
121	fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
122		debug!("1: IP literal with provided or default port");
123		Ok(host_port)
124	}
125
126	async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
127		debug!("2: Hostname with included port");
128		let (host, port) = dest.as_str().split_at(pos);
129		let port_num = port
130			.trim_start_matches(':')
131			.parse::<u16>()
132			.unwrap_or(8448);
133
134		self.conditional_query_and_cache(host, port_num, cache)
135			.await?;
136
137		Ok(FedDest::Named(
138			host.into(),
139			port.try_into()
140				.unwrap_or_else(|_| FedDest::default_port()),
141		))
142	}
143
144	async fn actual_dest_3(
145		&self,
146		host: &mut DestString,
147		cache: bool,
148		delegated: &str,
149	) -> Result<FedDest> {
150		debug!("3: A .well-known file is available");
151		*host = add_port_to_hostname(delegated).uri_string();
152		match get_ip_with_port(delegated) {
153			| Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
154			| None =>
155				if let Some(pos) = delegated.find(':') {
156					self.actual_dest_3_2(cache, delegated, pos).await
157				} else {
158					trace!("Delegated hostname has no port in this branch");
159					match self.query_srv_record(delegated).await? {
160						| Some(overrider) =>
161							self.actual_dest_3_3(cache, delegated, overrider)
162								.await,
163						| _ => self.actual_dest_3_4(cache, delegated).await,
164					}
165				},
166		}
167	}
168
169	fn actual_dest_3_1(host_and_port: FedDest) -> Result<FedDest> {
170		debug!("3.1: IP literal in .well-known file");
171		Ok(host_and_port)
172	}
173
174	async fn actual_dest_3_2(&self, cache: bool, delegated: &str, pos: usize) -> Result<FedDest> {
175		debug!("3.2: Hostname with port in .well-known file");
176		let (host, port) = delegated.split_at(pos);
177		let port_num = port
178			.trim_start_matches(':')
179			.parse::<u16>()
180			.unwrap_or(8448);
181
182		self.conditional_query_and_cache(host, port_num, cache)
183			.await?;
184
185		Ok(FedDest::Named(
186			host.into(),
187			port.try_into()
188				.unwrap_or_else(|_| FedDest::default_port()),
189		))
190	}
191
192	async fn actual_dest_3_3(
193		&self,
194		cache: bool,
195		delegated: &str,
196		overrider: FedDest,
197	) -> Result<FedDest> {
198		debug!("3.3: SRV lookup successful");
199		let force_port = overrider.port();
200		self.conditional_query_and_cache_override(
201			delegated,
202			&overrider.hostname(),
203			force_port.unwrap_or(8448),
204			cache,
205		)
206		.await?;
207
208		if let Some(port) = force_port {
209			return Ok(FedDest::Named(
210				delegated.into(),
211				format!(":{port}")
212					.as_str()
213					.try_into()
214					.unwrap_or_else(|_| FedDest::default_port()),
215			));
216		}
217
218		Ok(add_port_to_hostname(delegated))
219	}
220
221	async fn actual_dest_3_4(&self, cache: bool, delegated: &str) -> Result<FedDest> {
222		debug!("3.4: No SRV records, just use the hostname from .well-known");
223		self.conditional_query_and_cache(delegated, 8448, cache)
224			.await?;
225
226		Ok(add_port_to_hostname(delegated))
227	}
228
229	async fn actual_dest_4(
230		&self,
231		host: &str,
232		cache: bool,
233		overrider: FedDest,
234	) -> Result<FedDest> {
235		debug!("4: No .well-known; SRV record found");
236		let force_port = overrider.port();
237		self.conditional_query_and_cache_override(
238			host,
239			&overrider.hostname(),
240			force_port.unwrap_or(8448),
241			cache,
242		)
243		.await?;
244
245		if let Some(port) = force_port {
246			let port = format!(":{port}");
247			return Ok(FedDest::Named(
248				host.into(),
249				PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()),
250			));
251		}
252
253		Ok(add_port_to_hostname(host))
254	}
255
256	async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> {
257		debug!("5: No SRV record found");
258		self.conditional_query_and_cache(dest.as_str(), 8448, cache)
259			.await?;
260
261		Ok(add_port_to_hostname(dest.as_str()))
262	}
263
264	#[inline]
265	async fn conditional_query_and_cache(
266		&self,
267		hostname: &str,
268		port: u16,
269		cache: bool,
270	) -> Result {
271		self.conditional_query_and_cache_override(hostname, hostname, port, cache)
272			.await
273	}
274
275	#[inline]
276	async fn conditional_query_and_cache_override(
277		&self,
278		untername: &str,
279		hostname: &str,
280		port: u16,
281		cache: bool,
282	) -> Result {
283		if !cache {
284			return Ok(());
285		}
286
287		if self.cache.has_override(untername).await {
288			return Ok(());
289		}
290
291		self.query_and_cache_override(untername, hostname, port)
292			.await
293	}
294
295	#[tracing::instrument(name = "ip", level = "debug", skip(self))]
296	async fn query_and_cache_override(
297		&self,
298		untername: &'_ str,
299		hostname: &'_ str,
300		port: u16,
301	) -> Result {
302		self.services.server.check_running()?;
303
304		debug!("querying IP for {untername:?} ({hostname:?}:{port})");
305		match self
306			.resolver
307			.resolver
308			.lookup_ip(hostname.to_owned())
309			.await
310		{
311			| Err(e) => Self::handle_resolve_error(&e, hostname),
312			| Ok(override_ip) => {
313				self.cache
314					.set_override(untername, &CachedOverride {
315						ips: override_ip.iter().take(MAX_IPS).collect(),
316						port,
317						expire: CachedOverride::default_expire(),
318						overriding: (hostname != untername)
319							.then_some(hostname.into())
320							.inspect(|_| debug_info!("{untername:?} overridden by {hostname:?}")),
321					});
322
323				Ok(())
324			},
325		}
326	}
327
328	#[tracing::instrument(name = "srv", level = "debug", skip(self))]
329	async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> {
330		let hostnames =
331			[format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
332
333		for hostname in hostnames {
334			self.services.server.check_running()?;
335
336			debug!("querying SRV for {hostname:?}");
337			let hostname = hostname.trim_end_matches('.');
338			match self.resolver.resolver.srv_lookup(hostname).await {
339				| Err(e) => Self::handle_resolve_error(&e, hostname)?,
340				| Ok(result) => {
341					let srv = result
342						.answers()
343						.iter()
344						.find_map(|r| match &r.data {
345							| RData::SRV(srv) => Some(srv),
346							| _ => None,
347						});
348
349					return Ok(srv.map(|srv| {
350						FedDest::Named(
351							srv.target
352								.to_string()
353								.trim_end_matches('.')
354								.into(),
355							format!(":{}", srv.port)
356								.as_str()
357								.try_into()
358								.unwrap_or_else(|_| FedDest::default_port()),
359						)
360					}));
361				},
362			}
363		}
364
365		Ok(None)
366	}
367
368	fn handle_resolve_error(e: &NetError, host: &'_ str) -> Result {
369		// `NetError::Dns(_)` covers responses returned by the remote side (NXDOMAIN,
370		// SERVFAIL, REFUSED, ...) only seen with verbose-logging. Local-origin failures
371		// (Timeout, NoConnections, Io, ...) keep their warn/error level so an operator
372		// notices when their own resolver is unhealthy.
373		match e {
374			| NetError::Dns(DnsError::NoRecordsFound(_)) => {
375				// Raise to debug_warn if we can find out the result wasn't from cache
376				debug!(%host, "No DNS records found: {e}");
377				Ok(())
378			},
379			| NetError::Dns(_) => {
380				debug_warn!(%host, "DNS response error: {e}");
381				Ok(())
382			},
383			| NetError::Timeout => Err!(warn!(%host, "DNS {e}")),
384			| NetError::NoConnections => {
385				error!(
386					"Your DNS server is overloaded and has ran out of connections. It is \
387					 strongly recommended you remediate this issue to ensure proper federation \
388					 connectivity."
389				);
390
391				Err!(error!(%host, "DNS error: {e}"))
392			},
393			| _ => Err!(error!(%host, "DNS error: {e}")),
394		}
395	}
396
397	fn validate_dest(&self, dest: &ServerName) -> Result {
398		if dest == self.services.server.name && !self.services.server.config.federation_loopback {
399			return Err!("Won't send federation request to ourselves");
400		}
401
402		if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) {
403			self.validate_dest_ip_literal(dest)?;
404		}
405
406		Ok(())
407	}
408
409	fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result {
410		trace!("Destination is an IP literal, checking against IP range denylist.",);
411		debug_assert!(
412			dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
413			"Destination is not an IP literal."
414		);
415		let ip = IPAddress::parse(dest.host()).map_err(|e| {
416			err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}")))
417		})?;
418
419		self.validate_ip(&ip)?;
420
421		Ok(())
422	}
423
424	pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result {
425		if !self.services.client.valid_cidr_range(ip) {
426			return Err!(BadServerResponse("Not allowed to send requests to this IP"));
427		}
428
429		Ok(())
430	}
431}