1use std::{
2 fmt::Debug,
3 net::{IpAddr, SocketAddr},
4};
5
6use futures::{FutureExt, TryFutureExt};
7use hickory_resolver::{
8 net::{DnsError, NetError},
9 proto::rr::RData,
10};
11use ipaddress::IPAddress;
12use ruma::ServerName;
13use tuwunel_core::{Err, Result, debug, debug_info, debug_warn, err, error, trace};
14
15use super::{
16 DestString, FedDest,
17 cache::{CachedDest, CachedOverride, MAX_IPS},
18 fed::{PortString, add_port_to_hostname, get_ip_with_port},
19};
20
21#[derive(Clone, Debug)]
22pub(crate) struct ActualDest {
23 pub(crate) dest: FedDest,
24 pub(crate) host: DestString,
25}
26
27impl ActualDest {
28 #[inline]
29 pub(crate) fn to_string(&self) -> DestString { self.dest.https_string() }
30}
31
32impl super::Service {
33 #[tracing::instrument(skip_all, level = "debug", name = "resolve")]
34 pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> {
35 let (CachedDest { dest, host, .. }, _cached) =
36 self.lookup_actual_dest(server_name).await?;
37
38 Ok(ActualDest { dest, host })
39 }
40
41 pub(crate) async fn lookup_actual_dest(
42 &self,
43 server_name: &ServerName,
44 ) -> Result<(CachedDest, bool)> {
45 if let Ok(result) = self.cache.get_destination(server_name).await {
46 return Ok((result, true));
47 }
48
49 let _dedup = self.resolving.lock(server_name);
50 if let Ok(result) = self.cache.get_destination(server_name).await {
51 return Ok((result, true));
52 }
53
54 self.resolve_actual_dest(server_name, true)
55 .inspect_ok(|result| self.cache.set_destination(server_name, result))
56 .map_ok(|result| (result, false))
57 .boxed()
58 .await
59 }
60
61 #[tracing::instrument(name = "actual", level = "debug", skip(self, cache))]
66 pub async fn resolve_actual_dest(
67 &self,
68 dest: &ServerName,
69 cache: bool,
70 ) -> Result<CachedDest> {
71 self.validate_dest(dest)?;
72 let mut host: DestString = dest.as_str().into();
73 let actual_dest = match get_ip_with_port(dest.as_str()) {
74 | Some(host_port) => Self::actual_dest_1(host_port)?,
75 | None =>
76 if let Some(pos) = dest.as_str().find(':') {
77 self.actual_dest_2(dest, cache, pos).await?
78 } else {
79 self.conditional_query_and_cache(dest.as_str(), 8448, true)
80 .await?;
81 self.services.server.check_running()?;
82 match self.request_well_known(dest.as_str()).await? {
83 | Some(delegated) =>
84 self.actual_dest_3(&mut host, cache, &delegated)
85 .await?,
86 | _ => match self.query_srv_record(dest.as_str()).await? {
87 | Some(overrider) =>
88 self.actual_dest_4(&host, cache, overrider)
89 .await?,
90 | _ => self.actual_dest_5(dest, cache).await?,
91 },
92 }
93 },
94 };
95
96 let host = if let Ok(addr) = host.parse::<SocketAddr>() {
99 FedDest::Literal(addr)
100 } else if let Ok(addr) = host.parse::<IpAddr>() {
101 FedDest::Named(addr.to_string().into(), FedDest::default_port())
102 } else if let Some(pos) = host.find(':') {
103 let (host, port) = host.split_at(pos);
104 FedDest::Named(
105 host.into(),
106 port.try_into()
107 .unwrap_or_else(|_| FedDest::default_port()),
108 )
109 } else {
110 FedDest::Named(host.as_str().into(), FedDest::default_port())
111 };
112
113 debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
114 Ok(CachedDest {
115 dest: actual_dest,
116 host: host.uri_string(),
117 expire: CachedDest::default_expire(),
118 })
119 }
120
121 fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
122 debug!("1: IP literal with provided or default port");
123 Ok(host_port)
124 }
125
126 async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
127 debug!("2: Hostname with included port");
128 let (host, port) = dest.as_str().split_at(pos);
129 let port_num = port
130 .trim_start_matches(':')
131 .parse::<u16>()
132 .unwrap_or(8448);
133
134 self.conditional_query_and_cache(host, port_num, cache)
135 .await?;
136
137 Ok(FedDest::Named(
138 host.into(),
139 port.try_into()
140 .unwrap_or_else(|_| FedDest::default_port()),
141 ))
142 }
143
144 async fn actual_dest_3(
145 &self,
146 host: &mut DestString,
147 cache: bool,
148 delegated: &str,
149 ) -> Result<FedDest> {
150 debug!("3: A .well-known file is available");
151 *host = add_port_to_hostname(delegated).uri_string();
152 match get_ip_with_port(delegated) {
153 | Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
154 | None =>
155 if let Some(pos) = delegated.find(':') {
156 self.actual_dest_3_2(cache, delegated, pos).await
157 } else {
158 trace!("Delegated hostname has no port in this branch");
159 match self.query_srv_record(delegated).await? {
160 | Some(overrider) =>
161 self.actual_dest_3_3(cache, delegated, overrider)
162 .await,
163 | _ => self.actual_dest_3_4(cache, delegated).await,
164 }
165 },
166 }
167 }
168
169 fn actual_dest_3_1(host_and_port: FedDest) -> Result<FedDest> {
170 debug!("3.1: IP literal in .well-known file");
171 Ok(host_and_port)
172 }
173
174 async fn actual_dest_3_2(&self, cache: bool, delegated: &str, pos: usize) -> Result<FedDest> {
175 debug!("3.2: Hostname with port in .well-known file");
176 let (host, port) = delegated.split_at(pos);
177 let port_num = port
178 .trim_start_matches(':')
179 .parse::<u16>()
180 .unwrap_or(8448);
181
182 self.conditional_query_and_cache(host, port_num, cache)
183 .await?;
184
185 Ok(FedDest::Named(
186 host.into(),
187 port.try_into()
188 .unwrap_or_else(|_| FedDest::default_port()),
189 ))
190 }
191
192 async fn actual_dest_3_3(
193 &self,
194 cache: bool,
195 delegated: &str,
196 overrider: FedDest,
197 ) -> Result<FedDest> {
198 debug!("3.3: SRV lookup successful");
199 let force_port = overrider.port();
200 self.conditional_query_and_cache_override(
201 delegated,
202 &overrider.hostname(),
203 force_port.unwrap_or(8448),
204 cache,
205 )
206 .await?;
207
208 if let Some(port) = force_port {
209 return Ok(FedDest::Named(
210 delegated.into(),
211 format!(":{port}")
212 .as_str()
213 .try_into()
214 .unwrap_or_else(|_| FedDest::default_port()),
215 ));
216 }
217
218 Ok(add_port_to_hostname(delegated))
219 }
220
221 async fn actual_dest_3_4(&self, cache: bool, delegated: &str) -> Result<FedDest> {
222 debug!("3.4: No SRV records, just use the hostname from .well-known");
223 self.conditional_query_and_cache(delegated, 8448, cache)
224 .await?;
225
226 Ok(add_port_to_hostname(delegated))
227 }
228
229 async fn actual_dest_4(
230 &self,
231 host: &str,
232 cache: bool,
233 overrider: FedDest,
234 ) -> Result<FedDest> {
235 debug!("4: No .well-known; SRV record found");
236 let force_port = overrider.port();
237 self.conditional_query_and_cache_override(
238 host,
239 &overrider.hostname(),
240 force_port.unwrap_or(8448),
241 cache,
242 )
243 .await?;
244
245 if let Some(port) = force_port {
246 let port = format!(":{port}");
247 return Ok(FedDest::Named(
248 host.into(),
249 PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()),
250 ));
251 }
252
253 Ok(add_port_to_hostname(host))
254 }
255
256 async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> {
257 debug!("5: No SRV record found");
258 self.conditional_query_and_cache(dest.as_str(), 8448, cache)
259 .await?;
260
261 Ok(add_port_to_hostname(dest.as_str()))
262 }
263
264 #[inline]
265 async fn conditional_query_and_cache(
266 &self,
267 hostname: &str,
268 port: u16,
269 cache: bool,
270 ) -> Result {
271 self.conditional_query_and_cache_override(hostname, hostname, port, cache)
272 .await
273 }
274
275 #[inline]
276 async fn conditional_query_and_cache_override(
277 &self,
278 untername: &str,
279 hostname: &str,
280 port: u16,
281 cache: bool,
282 ) -> Result {
283 if !cache {
284 return Ok(());
285 }
286
287 if self.cache.has_override(untername).await {
288 return Ok(());
289 }
290
291 self.query_and_cache_override(untername, hostname, port)
292 .await
293 }
294
295 #[tracing::instrument(name = "ip", level = "debug", skip(self))]
296 async fn query_and_cache_override(
297 &self,
298 untername: &'_ str,
299 hostname: &'_ str,
300 port: u16,
301 ) -> Result {
302 self.services.server.check_running()?;
303
304 debug!("querying IP for {untername:?} ({hostname:?}:{port})");
305 match self
306 .resolver
307 .resolver
308 .lookup_ip(hostname.to_owned())
309 .await
310 {
311 | Err(e) => Self::handle_resolve_error(&e, hostname),
312 | Ok(override_ip) => {
313 self.cache
314 .set_override(untername, &CachedOverride {
315 ips: override_ip.iter().take(MAX_IPS).collect(),
316 port,
317 expire: CachedOverride::default_expire(),
318 overriding: (hostname != untername)
319 .then_some(hostname.into())
320 .inspect(|_| debug_info!("{untername:?} overridden by {hostname:?}")),
321 });
322
323 Ok(())
324 },
325 }
326 }
327
328 #[tracing::instrument(name = "srv", level = "debug", skip(self))]
329 async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> {
330 let hostnames =
331 [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
332
333 for hostname in hostnames {
334 self.services.server.check_running()?;
335
336 debug!("querying SRV for {hostname:?}");
337 let hostname = hostname.trim_end_matches('.');
338 match self.resolver.resolver.srv_lookup(hostname).await {
339 | Err(e) => Self::handle_resolve_error(&e, hostname)?,
340 | Ok(result) => {
341 let srv = result
342 .answers()
343 .iter()
344 .find_map(|r| match &r.data {
345 | RData::SRV(srv) => Some(srv),
346 | _ => None,
347 });
348
349 return Ok(srv.map(|srv| {
350 FedDest::Named(
351 srv.target
352 .to_string()
353 .trim_end_matches('.')
354 .into(),
355 format!(":{}", srv.port)
356 .as_str()
357 .try_into()
358 .unwrap_or_else(|_| FedDest::default_port()),
359 )
360 }));
361 },
362 }
363 }
364
365 Ok(None)
366 }
367
368 fn handle_resolve_error(e: &NetError, host: &'_ str) -> Result {
369 match e {
374 | NetError::Dns(DnsError::NoRecordsFound(_)) => {
375 debug!(%host, "No DNS records found: {e}");
377 Ok(())
378 },
379 | NetError::Dns(_) => {
380 debug_warn!(%host, "DNS response error: {e}");
381 Ok(())
382 },
383 | NetError::Timeout => Err!(warn!(%host, "DNS {e}")),
384 | NetError::NoConnections => {
385 error!(
386 "Your DNS server is overloaded and has ran out of connections. It is \
387 strongly recommended you remediate this issue to ensure proper federation \
388 connectivity."
389 );
390
391 Err!(error!(%host, "DNS error: {e}"))
392 },
393 | _ => Err!(error!(%host, "DNS error: {e}")),
394 }
395 }
396
397 fn validate_dest(&self, dest: &ServerName) -> Result {
398 if dest == self.services.server.name && !self.services.server.config.federation_loopback {
399 return Err!("Won't send federation request to ourselves");
400 }
401
402 if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) {
403 self.validate_dest_ip_literal(dest)?;
404 }
405
406 Ok(())
407 }
408
409 fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result {
410 trace!("Destination is an IP literal, checking against IP range denylist.",);
411 debug_assert!(
412 dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
413 "Destination is not an IP literal."
414 );
415 let ip = IPAddress::parse(dest.host()).map_err(|e| {
416 err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}")))
417 })?;
418
419 self.validate_ip(&ip)?;
420
421 Ok(())
422 }
423
424 pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result {
425 if !self.services.client.valid_cidr_range(ip) {
426 return Err!(BadServerResponse("Not allowed to send requests to this IP"));
427 }
428
429 Ok(())
430 }
431}