Skip to main content

tuwunel_service/server_keys/
request.rs

1use std::{collections::BTreeMap, convert::identity, fmt::Debug};
2
3use futures::{FutureExt, StreamExt, TryFutureExt};
4use ruma::{
5	OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
6	api::federation::discovery::{
7		ServerSigningKeys, get_remote_server_keys,
8		get_remote_server_keys_batch::{self, v2::QueryCriteria},
9		get_server_keys,
10	},
11};
12use tuwunel_core::{
13	Err, Result, error, implement, info, trace,
14	utils::stream::{IterStream, ReadyExt, TryBroadbandExt, TryReadyExt},
15};
16
17#[implement(super::Service)]
18pub(super) async fn batch_notary_request<'a, S, K>(
19	&self,
20	notary: &ServerName,
21	batch: S,
22) -> Result<Vec<ServerSigningKeys>>
23where
24	S: Iterator<Item = (&'a ServerName, K)> + Send,
25	K: Iterator<Item = &'a ServerSigningKeyId> + Send,
26{
27	use get_remote_server_keys_batch::v2::Request;
28	type RumaBatch = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>;
29
30	let criteria = QueryCriteria {
31		minimum_valid_until_ts: Some(self.minimum_valid_ts()),
32	};
33
34	let mut server_keys = batch.fold(RumaBatch::new(), |mut batch, (server, key_ids)| {
35		batch
36			.entry(server.into())
37			.or_default()
38			.extend(key_ids.map(|key_id| (key_id.into(), criteria.clone())));
39
40		batch
41	});
42
43	let total_keys = server_keys
44		.values()
45		.flat_map(|ids| ids.iter())
46		.count();
47
48	debug_assert!(total_keys > 0, "empty batch request to notary");
49
50	let batch_max = self
51		.services
52		.server
53		.config
54		.trusted_server_batch_size;
55
56	let batch_concurrency = self
57		.services
58		.server
59		.config
60		.trusted_server_batch_concurrency;
61
62	let batches: Vec<_> = server_keys
63		.keys()
64		.rev()
65		.step_by(batch_max.saturating_sub(1))
66		.skip(1)
67		.chain(server_keys.keys().next())
68		.cloned()
69		.collect();
70
71	batches
72		.iter()
73		.stream()
74		.enumerate()
75		.map(|(i, batch)| {
76			let request = Request {
77				server_keys: server_keys.split_off(batch),
78			};
79
80			if request.server_keys.is_empty() {
81				return None;
82			}
83
84			trace!(
85				%i, %notary, ?batch,
86				remaining = ?server_keys,
87				requesting = ?request.server_keys.keys(),
88				"Request to notary server."
89			);
90
91			info!(
92				%notary,
93				remaining = %server_keys.len(),
94				requesting = %request.server_keys.len(),
95				"Sending request to notary server..."
96			);
97
98			Some(Ok(request))
99		})
100		.ready_filter_map(identity)
101		.broadn_and_then(batch_concurrency, |request| {
102			self.services
103				.federation
104				.execute_synapse(notary, request)
105		})
106		.ready_try_fold(Vec::new(), |mut results, response| {
107			let response = response
108				.server_keys
109				.into_iter()
110				.map(|key| key.deserialize())
111				.filter_map(Result::ok);
112
113			trace!(
114				%notary, ?response,
115				"Response from notary server."
116			);
117
118			results.extend(response);
119
120			info!(
121				"Received {0} keys out of {1} from notary server so far...",
122				results.len(),
123				total_keys,
124			);
125
126			Ok(results)
127		})
128		.inspect_err(|e| {
129			error!(
130				?notary, %batch_max, %batch_concurrency, %total_keys,
131				"Requesting keys from notary server failed: {e}",
132			);
133		})
134		.boxed()
135		.await
136}
137
138#[implement(super::Service)]
139pub async fn notary_request(
140	&self,
141	notary: &ServerName,
142	target: &ServerName,
143) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send + use<>> {
144	use get_remote_server_keys::v2::Request;
145
146	let request = Request {
147		server_name: target.into(),
148		minimum_valid_until_ts: self.minimum_valid_ts(),
149	};
150
151	let response = self
152		.services
153		.federation
154		.execute(notary, request)
155		.await?
156		.server_keys
157		.into_iter()
158		.map(|key| key.deserialize())
159		.filter_map(Result::ok);
160
161	Ok(response)
162}
163
164#[implement(super::Service)]
165pub async fn server_request(&self, target: &ServerName) -> Result<ServerSigningKeys> {
166	use get_server_keys::v2::Request;
167
168	let server_signing_key = self
169		.services
170		.federation
171		.execute(target, Request::new())
172		.await
173		.map(|response| response.server_key)
174		.and_then(|key| key.deserialize().map_err(Into::into))?;
175
176	if server_signing_key.server_name != target {
177		return Err!(BadServerResponse(debug_warn!(
178			requested = ?target,
179			response = ?server_signing_key.server_name,
180			"Server responded with bogus server_name"
181		)));
182	}
183
184	Ok(server_signing_key)
185}