tuwunel_service/server_keys/
acquire.rs1use std::{
2 borrow::Borrow,
3 collections::{BTreeMap, BTreeSet},
4 time::Duration,
5};
6
7use futures::{StreamExt, stream::FuturesUnordered};
8use ruma::{
9 CanonicalJsonObject, OwnedServerName, OwnedServerSigningKeyId, ServerName,
10 ServerSigningKeyId, api::federation::discovery::ServerSigningKeys, serde::Raw,
11};
12use serde_json::value::RawValue as RawJsonValue;
13use tokio::time::{Instant, timeout_at};
14use tuwunel_core::{
15 debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn,
16};
17
18use super::key_exists;
19
/// Signing key ids grouped by the server they belong to. The `acquire_*`
/// helpers in this file return it to report the keys that are still missing.
20type Batch = BTreeMap<OwnedServerName, Vec<OwnedServerSigningKeyId>>;
21
22#[implement(super::Service)]
23pub async fn acquire_events_pubkeys<'a, I>(&self, events: I)
24where
25 I: Iterator<Item = &'a Box<RawJsonValue>> + Send,
26{
27 type Batch = BTreeMap<OwnedServerName, BTreeSet<OwnedServerSigningKeyId>>;
28 type Signatures = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, String>>;
29
30 let mut batch = Batch::new();
31 events
32 .cloned()
33 .map(Raw::<CanonicalJsonObject>::from_json)
34 .map(|event| event.get_field::<Signatures>("signatures"))
35 .filter_map(FlatOk::flat_ok)
36 .flat_map(IntoIterator::into_iter)
37 .for_each(|(server, sigs)| {
38 batch
39 .entry(server)
40 .or_default()
41 .extend(sigs.into_keys());
42 });
43
44 let batch = batch
45 .iter()
46 .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
47
48 self.acquire_pubkeys(batch).await;
49}
50
51#[implement(super::Service)]
52pub async fn acquire_pubkeys<'a, S, K>(&self, batch: S)
53where
54 S: Iterator<Item = (&'a ServerName, K)> + Send + Clone,
55 K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone,
56{
57 let notary_only = self
58 .services
59 .config
60 .only_query_trusted_key_servers;
61
62 let notary_first_always = self
63 .services
64 .config
65 .query_trusted_key_servers_first;
66
67 let notary_first_on_join = self
68 .services
69 .config
70 .query_trusted_key_servers_first_on_join;
71
72 let requested_servers = batch.clone().count();
73 let requested_keys = batch
74 .clone()
75 .flat_map(|(_, key_ids)| key_ids)
76 .count();
77
78 debug!("acquire {requested_keys} keys from {requested_servers}");
79
80 let mut missing = self.acquire_locals(batch).await;
81 let mut missing_keys = keys_count(&missing);
82 let mut missing_servers = missing.len();
83 if missing_servers == 0 {
84 return;
85 }
86
87 info!("{missing_keys} keys for {missing_servers} servers will be acquired");
88
89 if notary_first_always || notary_first_on_join {
90 missing = self.acquire_notary(missing.into_iter()).await;
91 missing_keys = keys_count(&missing);
92 missing_servers = missing.len();
93 if missing_keys == 0 {
94 return;
95 }
96
97 warn!(
98 "missing {missing_keys} keys for {missing_servers} servers from all notaries first"
99 );
100 }
101
102 if !notary_only {
103 missing = self.acquire_origins(missing.into_iter()).await;
104 missing_keys = keys_count(&missing);
105 missing_servers = missing.len();
106 if missing_keys == 0 {
107 return;
108 }
109
110 debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable");
111 }
112
113 if !notary_first_always && !notary_first_on_join {
114 missing = self.acquire_notary(missing.into_iter()).await;
115 missing_keys = keys_count(&missing);
116 missing_servers = missing.len();
117 if missing_keys == 0 {
118 return;
119 }
120
121 debug_warn!(
122 "still missing {missing_keys} keys for {missing_servers} servers from all notaries."
123 );
124 }
125
126 if missing_keys > 0 {
127 warn!(
128 "did not obtain {missing_keys} keys for {missing_servers} servers out of \
129 {requested_keys} total keys for {requested_servers} total servers."
130 );
131 }
132
133 for (server, key_ids) in missing {
134 debug_warn!(?server, ?key_ids, "missing");
135 }
136}
137
138#[implement(super::Service)]
139async fn acquire_locals<'a, S, K>(&self, batch: S) -> Batch
140where
141 S: Iterator<Item = (&'a ServerName, K)> + Send,
142 K: Iterator<Item = &'a ServerSigningKeyId> + Send,
143{
144 let mut missing = Batch::new();
145 for (server, key_ids) in batch {
146 for key_id in key_ids {
147 if !self.verify_key_exists(server, key_id).await {
148 missing
149 .entry(server.into())
150 .or_default()
151 .push(key_id.into());
152 }
153 }
154 }
155
156 missing
157}
158
159#[implement(super::Service)]
160async fn acquire_origins<I>(&self, batch: I) -> Batch
161where
162 I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
163{
164 let timeout = Instant::now()
165 .checked_add(Duration::from_secs(45))
166 .expect("timeout overflows");
167
168 let mut requests: FuturesUnordered<_> = batch
169 .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids, timeout))
170 .collect();
171
172 let mut missing = Batch::new();
173 while let Some((origin, key_ids)) = requests.next().await {
174 if !key_ids.is_empty() {
175 missing.insert(origin, key_ids);
176 }
177 }
178
179 missing
180}
181
182#[implement(super::Service)]
183async fn acquire_origin(
184 &self,
185 origin: OwnedServerName,
186 mut key_ids: Vec<OwnedServerSigningKeyId>,
187 timeout: Instant,
188) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) {
189 match timeout_at(timeout, self.server_request(&origin)).await {
190 | Err(e) => debug_warn!(?origin, "timed out: {e}"),
191 | Ok(Err(e)) => debug_error!(?origin, "{e}"),
192 | Ok(Ok(server_keys)) => {
193 trace!(
194 %origin,
195 ?key_ids,
196 ?server_keys,
197 "received server_keys"
198 );
199
200 self.add_signing_keys(server_keys.clone()).await;
201 key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
202 },
203 }
204
205 (origin, key_ids)
206}
207
208#[implement(super::Service)]
209async fn acquire_notary<I>(&self, batch: I) -> Batch
210where
211 I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
212{
213 let mut missing: Batch = batch.collect();
214 for notary in &self.services.config.trusted_servers {
215 let missing_keys = keys_count(&missing);
216 let missing_servers = missing.len();
217 debug!(
218 "Asking notary {notary} for {missing_keys} missing keys from {missing_servers} \
219 servers"
220 );
221
222 let batch = missing
223 .iter()
224 .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
225
226 match self.batch_notary_request(notary, batch).await {
227 | Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
228 | Ok(results) =>
229 for server_keys in results {
230 self.acquire_notary_result(&mut missing, server_keys)
231 .await;
232 },
233 }
234 }
235
236 missing
237}
238
239#[implement(super::Service)]
240async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSigningKeys) {
241 let server = &server_keys.server_name;
242 self.add_signing_keys(server_keys.clone()).await;
243
244 if let Some(key_ids) = missing.get_mut(server) {
245 key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
246 if key_ids.is_empty() {
247 missing.remove(server);
248 }
249 }
250}
251
252fn keys_count(batch: &Batch) -> usize {
253 batch
254 .values()
255 .flat_map(|key_ids| key_ids.iter())
256 .count()
257}