Skip to main content

tuwunel_service/resolver/
cache.rs

1use std::{net::IpAddr, sync::Arc, time::SystemTime};
2
3use futures::{Stream, StreamExt, future::join};
4use ruma::ServerName;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{
7	Result,
8	arrayvec::ArrayVec,
9	at, err, implement,
10	utils::{math::Expected, rand, stream::TryIgnore},
11};
12use tuwunel_database::{Cbor, Deserialized, Map};
13
14use super::{DestString, FedDest};
15
16pub struct Cache {
17	destinations: Arc<Map>,
18	overrides: Arc<Map>,
19}
20
21#[derive(Clone, Debug, Deserialize, Serialize)]
22pub struct CachedDest {
23	pub dest: FedDest,
24	pub host: DestString,
25	pub expire: SystemTime,
26}
27
28#[derive(Clone, Debug, Deserialize, Serialize)]
29pub struct CachedOverride {
30	pub ips: IpAddrs,
31	pub port: u16,
32	pub expire: SystemTime,
33	pub overriding: Option<DestString>,
34}
35
36pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>;
37pub(crate) const MAX_IPS: usize = 3;
38
39impl Cache {
40	pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> {
41		Arc::new(Self {
42			destinations: args.db["servername_destination"].clone(),
43			overrides: args.db["servername_override"].clone(),
44		})
45	}
46}
47
48#[implement(Cache)]
49pub async fn clear(&self) { join(self.clear_destinations(), self.clear_overrides()).await; }
50
51#[implement(Cache)]
52pub async fn clear_destinations(&self) { self.destinations.clear().await; }
53
54#[implement(Cache)]
55pub async fn clear_overrides(&self) { self.overrides.clear().await; }
56
57#[implement(Cache)]
58pub fn del_destination(&self, name: &ServerName) { self.destinations.remove(name); }
59
60#[implement(Cache)]
61pub fn del_override(&self, name: &ServerName) { self.overrides.remove(name); }
62
63#[implement(Cache)]
64pub fn set_destination(&self, name: &ServerName, dest: &CachedDest) {
65	self.destinations.raw_put(name, Cbor(dest));
66}
67
68#[implement(Cache)]
69pub fn set_override(&self, name: &str, over: &CachedOverride) {
70	self.overrides.raw_put(name, Cbor(over));
71}
72
73#[implement(Cache)]
74#[must_use]
75pub async fn has_destination(&self, destination: &ServerName) -> bool {
76	self.get_destination(destination).await.is_ok()
77}
78
79#[implement(Cache)]
80#[must_use]
81pub async fn has_override(&self, destination: &str) -> bool {
82	self.get_override(destination)
83		.await
84		.iter()
85		.any(CachedOverride::valid)
86}
87
88#[implement(Cache)]
89pub async fn get_destination(&self, name: &ServerName) -> Result<CachedDest> {
90	self.destinations
91		.get(name)
92		.await
93		.deserialized::<Cbor<_>>()
94		.map(at!(0))
95		.into_iter()
96		.find(CachedDest::valid)
97		.ok_or(err!(Request(NotFound("Expired from cache"))))
98}
99
100#[implement(Cache)]
101pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
102	self.overrides
103		.get(name)
104		.await
105		.deserialized::<Cbor<_>>()
106		.map(at!(0))
107}
108
109#[implement(Cache)]
110pub fn destinations(&self) -> impl Stream<Item = (&ServerName, CachedDest)> + Send + '_ {
111	self.destinations
112		.stream()
113		.ignore_err()
114		.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
115}
116
117#[implement(Cache)]
118pub fn overrides(&self) -> impl Stream<Item = (&ServerName, CachedOverride)> + Send + '_ {
119	self.overrides
120		.stream()
121		.ignore_err()
122		.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
123}
124
125impl CachedDest {
126	#[inline]
127	#[must_use]
128	pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
129
130	#[must_use]
131	pub(crate) fn default_expire() -> SystemTime {
132		rand::time_from_now_secs(60 * 60 * 18..60 * 60 * 36)
133	}
134
135	#[inline]
136	#[must_use]
137	pub fn size(&self) -> usize {
138		self.dest
139			.size()
140			.expected_add(self.host.len())
141			.expected_add(size_of_val(&self.expire))
142	}
143}
144
145impl CachedOverride {
146	#[inline]
147	#[must_use]
148	pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
149
150	#[must_use]
151	pub(crate) fn default_expire() -> SystemTime {
152		rand::time_from_now_secs(60 * 60 * 6..60 * 60 * 12)
153	}
154
155	#[inline]
156	#[must_use]
157	pub fn size(&self) -> usize { size_of_val(self) }
158}