tuwunel_service/server_keys/
request.rs1use std::{collections::BTreeMap, convert::identity, fmt::Debug};
2
3use futures::{FutureExt, StreamExt, TryFutureExt};
4use ruma::{
5 OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
6 api::federation::discovery::{
7 ServerSigningKeys, get_remote_server_keys,
8 get_remote_server_keys_batch::{self, v2::QueryCriteria},
9 get_server_keys,
10 },
11};
12use tuwunel_core::{
13 Err, Result, error, implement, info, trace,
14 utils::stream::{IterStream, ReadyExt, TryBroadbandExt, TryReadyExt},
15};
16
17#[implement(super::Service)]
18pub(super) async fn batch_notary_request<'a, S, K>(
19 &self,
20 notary: &ServerName,
21 batch: S,
22) -> Result<Vec<ServerSigningKeys>>
23where
24 S: Iterator<Item = (&'a ServerName, K)> + Send,
25 K: Iterator<Item = &'a ServerSigningKeyId> + Send,
26{
27 use get_remote_server_keys_batch::v2::Request;
28 type RumaBatch = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>;
29
30 let criteria = QueryCriteria {
31 minimum_valid_until_ts: Some(self.minimum_valid_ts()),
32 };
33
34 let mut server_keys = batch.fold(RumaBatch::new(), |mut batch, (server, key_ids)| {
35 batch
36 .entry(server.into())
37 .or_default()
38 .extend(key_ids.map(|key_id| (key_id.into(), criteria.clone())));
39
40 batch
41 });
42
43 let total_keys = server_keys
44 .values()
45 .flat_map(|ids| ids.iter())
46 .count();
47
48 debug_assert!(total_keys > 0, "empty batch request to notary");
49
50 let batch_max = self
51 .services
52 .server
53 .config
54 .trusted_server_batch_size;
55
56 let batch_concurrency = self
57 .services
58 .server
59 .config
60 .trusted_server_batch_concurrency;
61
62 let batches: Vec<_> = server_keys
63 .keys()
64 .rev()
65 .step_by(batch_max.saturating_sub(1))
66 .skip(1)
67 .chain(server_keys.keys().next())
68 .cloned()
69 .collect();
70
71 batches
72 .iter()
73 .stream()
74 .enumerate()
75 .map(|(i, batch)| {
76 let request = Request {
77 server_keys: server_keys.split_off(batch),
78 };
79
80 if request.server_keys.is_empty() {
81 return None;
82 }
83
84 trace!(
85 %i, %notary, ?batch,
86 remaining = ?server_keys,
87 requesting = ?request.server_keys.keys(),
88 "Request to notary server."
89 );
90
91 info!(
92 %notary,
93 remaining = %server_keys.len(),
94 requesting = %request.server_keys.len(),
95 "Sending request to notary server..."
96 );
97
98 Some(Ok(request))
99 })
100 .ready_filter_map(identity)
101 .broadn_and_then(batch_concurrency, |request| {
102 self.services
103 .federation
104 .execute_synapse(notary, request)
105 })
106 .ready_try_fold(Vec::new(), |mut results, response| {
107 let response = response
108 .server_keys
109 .into_iter()
110 .map(|key| key.deserialize())
111 .filter_map(Result::ok);
112
113 trace!(
114 %notary, ?response,
115 "Response from notary server."
116 );
117
118 results.extend(response);
119
120 info!(
121 "Received {0} keys out of {1} from notary server so far...",
122 results.len(),
123 total_keys,
124 );
125
126 Ok(results)
127 })
128 .inspect_err(|e| {
129 error!(
130 ?notary, %batch_max, %batch_concurrency, %total_keys,
131 "Requesting keys from notary server failed: {e}",
132 );
133 })
134 .boxed()
135 .await
136}
137
138#[implement(super::Service)]
139pub async fn notary_request(
140 &self,
141 notary: &ServerName,
142 target: &ServerName,
143) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send + use<>> {
144 use get_remote_server_keys::v2::Request;
145
146 let request = Request {
147 server_name: target.into(),
148 minimum_valid_until_ts: self.minimum_valid_ts(),
149 };
150
151 let response = self
152 .services
153 .federation
154 .execute(notary, request)
155 .await?
156 .server_keys
157 .into_iter()
158 .map(|key| key.deserialize())
159 .filter_map(Result::ok);
160
161 Ok(response)
162}
163
164#[implement(super::Service)]
165pub async fn server_request(&self, target: &ServerName) -> Result<ServerSigningKeys> {
166 use get_server_keys::v2::Request;
167
168 let server_signing_key = self
169 .services
170 .federation
171 .execute(target, Request::new())
172 .await
173 .map(|response| response.server_key)
174 .and_then(|key| key.deserialize().map_err(Into::into))?;
175
176 if server_signing_key.server_name != target {
177 return Err!(BadServerResponse(debug_warn!(
178 requested = ?target,
179 response = ?server_signing_key.server_name,
180 "Server responded with bogus server_name"
181 )));
182 }
183
184 Ok(server_signing_key)
185}