Skip to main content

tuwunel_service/server_keys/
acquire.rs

1use std::{
2	borrow::Borrow,
3	collections::{BTreeMap, BTreeSet},
4	time::Duration,
5};
6
7use futures::{StreamExt, stream::FuturesUnordered};
8use ruma::{
9	CanonicalJsonObject, OwnedServerName, OwnedServerSigningKeyId, ServerName,
10	ServerSigningKeyId, api::federation::discovery::ServerSigningKeys, serde::Raw,
11};
12use serde_json::value::RawValue as RawJsonValue;
13use tokio::time::{Instant, timeout_at};
14use tuwunel_core::{
15	debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn,
16};
17
18use super::key_exists;
19
/// Signing key IDs grouped by the server that owns them; used throughout this
/// module to track which keys are still missing for each server.
type Batch = BTreeMap<OwnedServerName, Vec<OwnedServerSigningKeyId>>;
21
22#[implement(super::Service)]
23pub async fn acquire_events_pubkeys<'a, I>(&self, events: I)
24where
25	I: Iterator<Item = &'a Box<RawJsonValue>> + Send,
26{
27	type Batch = BTreeMap<OwnedServerName, BTreeSet<OwnedServerSigningKeyId>>;
28	type Signatures = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, String>>;
29
30	let mut batch = Batch::new();
31	events
32		.cloned()
33		.map(Raw::<CanonicalJsonObject>::from_json)
34		.map(|event| event.get_field::<Signatures>("signatures"))
35		.filter_map(FlatOk::flat_ok)
36		.flat_map(IntoIterator::into_iter)
37		.for_each(|(server, sigs)| {
38			batch
39				.entry(server)
40				.or_default()
41				.extend(sigs.into_keys());
42		});
43
44	let batch = batch
45		.iter()
46		.map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
47
48	self.acquire_pubkeys(batch).await;
49}
50
51#[implement(super::Service)]
52pub async fn acquire_pubkeys<'a, S, K>(&self, batch: S)
53where
54	S: Iterator<Item = (&'a ServerName, K)> + Send + Clone,
55	K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone,
56{
57	let notary_only = self
58		.services
59		.config
60		.only_query_trusted_key_servers;
61
62	let notary_first_always = self
63		.services
64		.config
65		.query_trusted_key_servers_first;
66
67	let notary_first_on_join = self
68		.services
69		.config
70		.query_trusted_key_servers_first_on_join;
71
72	let requested_servers = batch.clone().count();
73	let requested_keys = batch
74		.clone()
75		.flat_map(|(_, key_ids)| key_ids)
76		.count();
77
78	debug!("acquire {requested_keys} keys from {requested_servers}");
79
80	let mut missing = self.acquire_locals(batch).await;
81	let mut missing_keys = keys_count(&missing);
82	let mut missing_servers = missing.len();
83	if missing_servers == 0 {
84		return;
85	}
86
87	info!("{missing_keys} keys for {missing_servers} servers will be acquired");
88
89	if notary_first_always || notary_first_on_join {
90		missing = self.acquire_notary(missing.into_iter()).await;
91		missing_keys = keys_count(&missing);
92		missing_servers = missing.len();
93		if missing_keys == 0 {
94			return;
95		}
96
97		warn!(
98			"missing {missing_keys} keys for {missing_servers} servers from all notaries first"
99		);
100	}
101
102	if !notary_only {
103		missing = self.acquire_origins(missing.into_iter()).await;
104		missing_keys = keys_count(&missing);
105		missing_servers = missing.len();
106		if missing_keys == 0 {
107			return;
108		}
109
110		debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable");
111	}
112
113	if !notary_first_always && !notary_first_on_join {
114		missing = self.acquire_notary(missing.into_iter()).await;
115		missing_keys = keys_count(&missing);
116		missing_servers = missing.len();
117		if missing_keys == 0 {
118			return;
119		}
120
121		debug_warn!(
122			"still missing {missing_keys} keys for {missing_servers} servers from all notaries."
123		);
124	}
125
126	if missing_keys > 0 {
127		warn!(
128			"did not obtain {missing_keys} keys for {missing_servers} servers out of \
129			 {requested_keys} total keys for {requested_servers} total servers."
130		);
131	}
132
133	for (server, key_ids) in missing {
134		debug_warn!(?server, ?key_ids, "missing");
135	}
136}
137
138#[implement(super::Service)]
139async fn acquire_locals<'a, S, K>(&self, batch: S) -> Batch
140where
141	S: Iterator<Item = (&'a ServerName, K)> + Send,
142	K: Iterator<Item = &'a ServerSigningKeyId> + Send,
143{
144	let mut missing = Batch::new();
145	for (server, key_ids) in batch {
146		for key_id in key_ids {
147			if !self.verify_key_exists(server, key_id).await {
148				missing
149					.entry(server.into())
150					.or_default()
151					.push(key_id.into());
152			}
153		}
154	}
155
156	missing
157}
158
159#[implement(super::Service)]
160async fn acquire_origins<I>(&self, batch: I) -> Batch
161where
162	I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
163{
164	let timeout = Instant::now()
165		.checked_add(Duration::from_secs(45))
166		.expect("timeout overflows");
167
168	let mut requests: FuturesUnordered<_> = batch
169		.map(|(origin, key_ids)| self.acquire_origin(origin, key_ids, timeout))
170		.collect();
171
172	let mut missing = Batch::new();
173	while let Some((origin, key_ids)) = requests.next().await {
174		if !key_ids.is_empty() {
175			missing.insert(origin, key_ids);
176		}
177	}
178
179	missing
180}
181
182#[implement(super::Service)]
183async fn acquire_origin(
184	&self,
185	origin: OwnedServerName,
186	mut key_ids: Vec<OwnedServerSigningKeyId>,
187	timeout: Instant,
188) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) {
189	match timeout_at(timeout, self.server_request(&origin)).await {
190		| Err(e) => debug_warn!(?origin, "timed out: {e}"),
191		| Ok(Err(e)) => debug_error!(?origin, "{e}"),
192		| Ok(Ok(server_keys)) => {
193			trace!(
194				%origin,
195				?key_ids,
196				?server_keys,
197				"received server_keys"
198			);
199
200			self.add_signing_keys(server_keys.clone()).await;
201			key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
202		},
203	}
204
205	(origin, key_ids)
206}
207
208#[implement(super::Service)]
209async fn acquire_notary<I>(&self, batch: I) -> Batch
210where
211	I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
212{
213	let mut missing: Batch = batch.collect();
214	for notary in &self.services.config.trusted_servers {
215		let missing_keys = keys_count(&missing);
216		let missing_servers = missing.len();
217		debug!(
218			"Asking notary {notary} for {missing_keys} missing keys from {missing_servers} \
219			 servers"
220		);
221
222		let batch = missing
223			.iter()
224			.map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
225
226		match self.batch_notary_request(notary, batch).await {
227			| Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
228			| Ok(results) =>
229				for server_keys in results {
230					self.acquire_notary_result(&mut missing, server_keys)
231						.await;
232				},
233		}
234	}
235
236	missing
237}
238
239#[implement(super::Service)]
240async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSigningKeys) {
241	let server = &server_keys.server_name;
242	self.add_signing_keys(server_keys.clone()).await;
243
244	if let Some(key_ids) = missing.get_mut(server) {
245		key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
246		if key_ids.is_empty() {
247			missing.remove(server);
248		}
249	}
250}
251
252fn keys_count(batch: &Batch) -> usize {
253	batch
254		.values()
255		.flat_map(|key_ids| key_ids.iter())
256		.count()
257}