tuwunel_service/resolver/
cache.rs1use std::{net::IpAddr, sync::Arc, time::SystemTime};
2
3use futures::{Stream, StreamExt, future::join};
4use ruma::ServerName;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{
7 Result,
8 arrayvec::ArrayVec,
9 at, err, implement,
10 utils::{math::Expected, rand, stream::TryIgnore},
11};
12use tuwunel_database::{Cbor, Deserialized, Map};
13
14use super::{DestString, FedDest};
15
16pub struct Cache {
17 destinations: Arc<Map>,
18 overrides: Arc<Map>,
19}
20
21#[derive(Clone, Debug, Deserialize, Serialize)]
22pub struct CachedDest {
23 pub dest: FedDest,
24 pub host: DestString,
25 pub expire: SystemTime,
26}
27
28#[derive(Clone, Debug, Deserialize, Serialize)]
29pub struct CachedOverride {
30 pub ips: IpAddrs,
31 pub port: u16,
32 pub expire: SystemTime,
33 pub overriding: Option<DestString>,
34}
35
36pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>;
37pub(crate) const MAX_IPS: usize = 3;
38
39impl Cache {
40 pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> {
41 Arc::new(Self {
42 destinations: args.db["servername_destination"].clone(),
43 overrides: args.db["servername_override"].clone(),
44 })
45 }
46}
47
48#[implement(Cache)]
49pub async fn clear(&self) { join(self.clear_destinations(), self.clear_overrides()).await; }
50
51#[implement(Cache)]
52pub async fn clear_destinations(&self) { self.destinations.clear().await; }
53
54#[implement(Cache)]
55pub async fn clear_overrides(&self) { self.overrides.clear().await; }
56
57#[implement(Cache)]
58pub fn del_destination(&self, name: &ServerName) { self.destinations.remove(name); }
59
60#[implement(Cache)]
61pub fn del_override(&self, name: &ServerName) { self.overrides.remove(name); }
62
63#[implement(Cache)]
64pub fn set_destination(&self, name: &ServerName, dest: &CachedDest) {
65 self.destinations.raw_put(name, Cbor(dest));
66}
67
68#[implement(Cache)]
69pub fn set_override(&self, name: &str, over: &CachedOverride) {
70 self.overrides.raw_put(name, Cbor(over));
71}
72
73#[implement(Cache)]
74#[must_use]
75pub async fn has_destination(&self, destination: &ServerName) -> bool {
76 self.get_destination(destination).await.is_ok()
77}
78
79#[implement(Cache)]
80#[must_use]
81pub async fn has_override(&self, destination: &str) -> bool {
82 self.get_override(destination)
83 .await
84 .iter()
85 .any(CachedOverride::valid)
86}
87
88#[implement(Cache)]
89pub async fn get_destination(&self, name: &ServerName) -> Result<CachedDest> {
90 self.destinations
91 .get(name)
92 .await
93 .deserialized::<Cbor<_>>()
94 .map(at!(0))
95 .into_iter()
96 .find(CachedDest::valid)
97 .ok_or(err!(Request(NotFound("Expired from cache"))))
98}
99
100#[implement(Cache)]
101pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
102 self.overrides
103 .get(name)
104 .await
105 .deserialized::<Cbor<_>>()
106 .map(at!(0))
107}
108
109#[implement(Cache)]
110pub fn destinations(&self) -> impl Stream<Item = (&ServerName, CachedDest)> + Send + '_ {
111 self.destinations
112 .stream()
113 .ignore_err()
114 .map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
115}
116
117#[implement(Cache)]
118pub fn overrides(&self) -> impl Stream<Item = (&ServerName, CachedOverride)> + Send + '_ {
119 self.overrides
120 .stream()
121 .ignore_err()
122 .map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
123}
124
125impl CachedDest {
126 #[inline]
127 #[must_use]
128 pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
129
130 #[must_use]
131 pub(crate) fn default_expire() -> SystemTime {
132 rand::time_from_now_secs(60 * 60 * 18..60 * 60 * 36)
133 }
134
135 #[inline]
136 #[must_use]
137 pub fn size(&self) -> usize {
138 self.dest
139 .size()
140 .expected_add(self.host.len())
141 .expected_add(size_of_val(&self.expire))
142 }
143}
144
145impl CachedOverride {
146 #[inline]
147 #[must_use]
148 pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
149
150 #[must_use]
151 pub(crate) fn default_expire() -> SystemTime {
152 rand::time_from_now_secs(60 * 60 * 6..60 * 60 * 12)
153 }
154
155 #[inline]
156 #[must_use]
157 pub fn size(&self) -> usize { size_of_val(self) }
158}