1use std::collections::{BTreeMap, HashMap};
2
3use axum::extract::State;
4use futures::{
5 FutureExt, StreamExt,
6 future::{
7 Either::{Left, Right},
8 join, join4,
9 },
10};
11use ruma::{
12 CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, ServerName,
13 UserId,
14 api::{
15 client::{device::Device, keys::get_keys},
16 federation,
17 },
18 encryption::{CrossSigningKey, DeviceKeys},
19 serde::Raw,
20};
21use serde_json::{json, value::to_raw_value};
22use tuwunel_core::{
23 Result, debug_warn, implement,
24 utils::{
25 BoolExt, IterStream,
26 future::TryExtExt,
27 json,
28 stream::{BroadbandExt, ReadyExt},
29 },
30};
31use tuwunel_service::{Services, users::parse_master_key};
32
33use super::FailureMap;
34use crate::Ruma;
35
36#[derive(Default)]
37struct Keys {
38 device_keys: DeviceKeyMap,
39 master_keys: CrossSigningKeys,
40 self_signing_keys: CrossSigningKeys,
41 user_signing_keys: CrossSigningKeys,
42 failures: FailureMap,
43}
44
45type DeviceLists = BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>;
46type DeviceKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Raw<DeviceKeys>>>;
47type ServerDevices<'a> = HashMap<&'a ServerName, DeviceLists>;
48type LocalDeviceUser<'a> = (&'a UserId, &'a Vec<OwnedDeviceId>);
49type CrossSigningKeys = BTreeMap<OwnedUserId, Raw<CrossSigningKey>>;
50
51pub(crate) async fn get_keys_route(
60 State(services): State<crate::State>,
61 body: Ruma<get_keys::v3::Request>,
62) -> Result<get_keys::v3::Response> {
63 let sender_user = body.sender_user();
64
65 get_keys_helper(
66 &services,
67 Some(sender_user),
68 &body.device_keys,
69 |u| u == sender_user,
70 true, )
72 .await
73}
74
75pub(crate) async fn get_keys_helper<F>(
76 services: &Services,
77 sender_user: Option<&UserId>,
78 device_keys_input: &DeviceLists,
79 allowed_signatures: F,
80 include_display_names: bool,
81) -> Result<get_keys::v3::Response>
82where
83 F: Fn(&UserId) -> bool + Send + Sync,
84{
85 let (local_users, remote_users): (Vec<LocalDeviceUser<'_>>, Vec<_>) = device_keys_input
86 .iter()
87 .map(|(uid, dids)| (uid.as_ref(), dids))
88 .partition(|(user_id, _)| services.globals.user_is_local(user_id));
89
90 let server: ServerDevices<'_> =
91 remote_users
92 .into_iter()
93 .fold(HashMap::new(), |mut acc, (user_id, device_ids)| {
94 acc.entry(user_id.server_name())
95 .or_default()
96 .insert(user_id.to_owned(), device_ids.clone());
97 acc
98 });
99
100 let local = collect_local_keys(
101 services,
102 &local_users,
103 sender_user,
104 &allowed_signatures,
105 include_display_names,
106 );
107
108 let federation = collect_federation_keys(services, server, sender_user, &allowed_signatures);
109
110 let (local, federation) = join(local, federation).await;
111 Ok(local.merge(federation).into_response())
112}
113
114async fn collect_local_keys<F>(
115 services: &Services,
116 users: &[LocalDeviceUser<'_>],
117 sender_user: Option<&UserId>,
118 allowed_signatures: &F,
119 include_display_names: bool,
120) -> Keys
121where
122 F: Fn(&UserId) -> bool + Send + Sync,
123{
124 users
125 .iter()
126 .copied()
127 .stream()
128 .broad_then(async |(user_id, device_ids)| {
129 collect_local_user_keys(
130 services,
131 user_id,
132 device_ids,
133 sender_user,
134 allowed_signatures,
135 include_display_names,
136 )
137 .await
138 })
139 .ready_fold(Keys::default(), Keys::merge)
140 .await
141}
142
143async fn collect_local_user_keys<F>(
144 services: &Services,
145 user_id: &UserId,
146 device_ids: &[OwnedDeviceId],
147 sender_user: Option<&UserId>,
148 allowed_signatures: &F,
149 include_display_names: bool,
150) -> Keys
151where
152 F: Fn(&UserId) -> bool + Send + Sync,
153{
154 let device_keys =
155 collect_local_device_keys(services, user_id, device_ids, include_display_names);
156
157 let master_key = services
158 .users
159 .get_master_key(sender_user, user_id, allowed_signatures)
160 .ok();
161
162 let self_signing_key = services
163 .users
164 .get_self_signing_key(sender_user, user_id, allowed_signatures)
165 .ok();
166
167 let user_signing_key = (Some(user_id) == sender_user)
168 .then_async(|| services.users.get_user_signing_key(user_id).ok())
169 .map(Option::flatten);
170
171 let (device_keys, master_key, self_signing_key, user_signing_key) =
172 join4(device_keys, master_key, self_signing_key, user_signing_key).await;
173
174 let owned = || user_id.to_owned();
175 Keys {
176 device_keys: BTreeMap::from([(owned(), device_keys)]),
177 master_keys: master_key
178 .map(|k| (owned(), k))
179 .into_iter()
180 .collect(),
181
182 self_signing_keys: self_signing_key
183 .map(|k| (owned(), k))
184 .into_iter()
185 .collect(),
186
187 user_signing_keys: user_signing_key
188 .map(|k| (owned(), k))
189 .into_iter()
190 .collect(),
191
192 ..Default::default()
193 }
194}
195
196async fn collect_local_device_keys(
197 services: &Services,
198 user_id: &UserId,
199 device_ids: &[OwnedDeviceId],
200 include_display_names: bool,
201) -> BTreeMap<OwnedDeviceId, Raw<DeviceKeys>> {
202 let stream = if device_ids.is_empty() {
203 Left(
204 services
205 .users
206 .all_device_ids(user_id)
207 .map(ToOwned::to_owned),
208 )
209 } else {
210 Right(device_ids.iter().cloned().stream())
211 };
212
213 stream
214 .broad_filter_map(async |device_id| {
215 get_local_device_keys(services, user_id, &device_id, include_display_names)
216 .await
217 .map(|keys| (device_id, keys))
218 })
219 .collect()
220 .await
221}
222
223async fn get_local_device_keys(
224 services: &Services,
225 user_id: &UserId,
226 device_id: &DeviceId,
227 include_display_names: bool,
228) -> Option<Raw<DeviceKeys>> {
229 let mut keys = services
230 .users
231 .get_device_keys(user_id, device_id)
232 .await
233 .ok()?;
234
235 let metadata = services
236 .users
237 .get_device_metadata(user_id, device_id)
238 .await
239 .inspect_err(|e| debug_warn!(?user_id, ?device_id, "device metadata missing: {e}"))
240 .ok()?;
241
242 add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
243 .inspect_err(|e| debug_warn!(?user_id, ?device_id, "invalid device keys: {e}"))
244 .ok()?;
245
246 Some(keys)
247}
248
249async fn collect_federation_keys<F>(
250 services: &Services,
251 server: ServerDevices<'_>,
252 sender_user: Option<&UserId>,
253 allowed_signatures: &F,
254) -> Keys
255where
256 F: Fn(&UserId) -> bool + Send + Sync,
257{
258 server
259 .into_iter()
260 .stream()
261 .broad_then(async |(server, device_keys)| {
262 let request = federation::keys::get_keys::v1::Request { device_keys };
263
264 (server, services.federation.execute(server, request).await)
265 })
266 .broad_then(async |(server, response)| match response {
267 | Ok(response) =>
268 process_federation_response(services, sender_user, allowed_signatures, response)
269 .await,
270 | Err(e) => {
271 debug_warn!(%server, "key federation request failed: {e}");
272 Keys {
273 failures: BTreeMap::from([(server.to_string(), json!({}))]),
274 ..Default::default()
275 }
276 },
277 })
278 .ready_fold(Keys::default(), Keys::merge)
279 .await
280}
281
282async fn process_federation_response<F>(
283 services: &Services,
284 sender_user: Option<&UserId>,
285 allowed_signatures: &F,
286 response: federation::keys::get_keys::v1::Response,
287) -> Keys
288where
289 F: Fn(&UserId) -> bool + Send + Sync,
290{
291 let federation::keys::get_keys::v1::Response {
292 master_keys,
293 self_signing_keys,
294 device_keys,
295 } = response;
296
297 let master_keys = master_keys
298 .into_iter()
299 .stream()
300 .broad_filter_map(async |(user, master_key)| {
301 merge_remote_master_key(services, sender_user, allowed_signatures, &user, master_key)
302 .await
303 .inspect_err(|e| debug_warn!(?user, "skipping master key from federation: {e}"))
304 .map(|raw| (user, raw))
305 .ok()
306 })
307 .collect()
308 .await;
309
310 Keys {
311 device_keys,
312 master_keys,
313 self_signing_keys,
314 user_signing_keys: BTreeMap::new(),
315 failures: BTreeMap::new(),
316 }
317}
318
319async fn merge_remote_master_key<F>(
323 services: &Services,
324 sender_user: Option<&UserId>,
325 allowed_signatures: &F,
326 user: &UserId,
327 master_key_raw: Raw<CrossSigningKey>,
328) -> Result<Raw<CrossSigningKey>>
329where
330 F: Fn(&UserId) -> bool + Send + Sync,
331{
332 let (master_key_id, mut master_key) = parse_master_key(user, &master_key_raw)?;
333 let our_raw = services
334 .users
335 .get_key(&master_key_id, sender_user, user, allowed_signatures)
336 .await;
337
338 if let Ok(our_raw) = our_raw
339 && let Ok((_, mut ours)) = parse_master_key(user, &our_raw)
340 {
341 master_key.signatures.append(&mut ours.signatures);
342 }
343
344 let raw = json::to_raw(&master_key)?;
345
346 services
349 .users
350 .add_cross_signing_keys(user, &Some(raw.clone()), &None, &None, false)
351 .await?;
352
353 Ok(raw)
354}
355
356fn add_unsigned_device_display_name(
357 keys: &mut Raw<DeviceKeys>,
358 metadata: Device,
359 include_display_names: bool,
360) -> Result {
361 let Some(display_name) = metadata.display_name else {
362 return Ok(());
363 };
364
365 let mut object = keys.deserialize_as_unchecked::<CanonicalJsonObject>()?;
366
367 if let CanonicalJsonValue::Object(unsigned) = object
368 .entry("unsigned".into())
369 .or_insert_with(|| CanonicalJsonObject::default().into())
370 {
371 let display_name = if include_display_names {
372 CanonicalJsonValue::String(display_name.to_string())
373 } else {
374 CanonicalJsonValue::String(metadata.device_id.into())
375 };
376
377 unsigned.insert("device_display_name".into(), display_name);
378 }
379
380 *keys = Raw::from_json(to_raw_value(&object)?);
381
382 Ok(())
383}
384
385#[implement(Keys)]
386fn merge(mut self, other: Self) -> Self {
387 self.failures.extend(other.failures);
388 self.device_keys.extend(other.device_keys);
389 self.master_keys.extend(other.master_keys);
390 self.self_signing_keys
391 .extend(other.self_signing_keys);
392 self.user_signing_keys
393 .extend(other.user_signing_keys);
394 self
395}
396
397#[implement(Keys)]
398fn into_response(self) -> get_keys::v3::Response {
399 get_keys::v3::Response {
400 failures: self.failures,
401 device_keys: self.device_keys,
402 master_keys: self.master_keys,
403 self_signing_keys: self.self_signing_keys,
404 user_signing_keys: self.user_signing_keys,
405 }
406}