tuwunel_api/client/keys/
claim_keys.rs1use std::collections::BTreeMap;
2
3use axum::extract::State;
4use futures::{FutureExt, StreamExt, future::join};
5use ruma::{
6 OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedUserId, ServerName, UserId,
7 api::{client::keys::claim_keys, federation},
8 encryption::OneTimeKey,
9 serde::Raw,
10};
11use serde_json::json;
12use tuwunel_core::{
13 Result, debug_warn,
14 utils::{
15 IterStream,
16 stream::{BroadbandExt, ReadyExt},
17 },
18};
19use tuwunel_service::Services;
20
21use super::FailureMap;
22use crate::Ruma;
23
24#[derive(Default)]
25struct Claims {
26 one_time_keys: OneTimeKeyMap,
27 failures: FailureMap,
28}
29
30type RequestClaims = BTreeMap<OwnedUserId, Algorithms>;
31type ServerClaims<'a> = BTreeMap<&'a ServerName, RequestClaims>;
32type LocalClaim<'a> = (&'a UserId, &'a Algorithms);
33type Algorithms = BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>;
34type OneTimeKeys = BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>;
35type OneTimeKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeys>>;
36
37pub(crate) async fn claim_keys_route(
41 State(services): State<crate::State>,
42 body: Ruma<claim_keys::v3::Request>,
43) -> Result<claim_keys::v3::Response> {
44 claim_keys_helper(&services, &body.one_time_keys).await
45}
46
47pub(crate) async fn claim_keys_helper(
48 services: &Services,
49 one_time_keys_input: &RequestClaims,
50) -> Result<claim_keys::v3::Response> {
51 let (local_users, remote_users): (Vec<_>, Vec<_>) = one_time_keys_input
52 .iter()
53 .map(|(uid, map)| (uid.as_ref(), map))
54 .partition(|(user_id, _)| services.globals.user_is_local(user_id));
55
56 let server: ServerClaims<'_> =
57 remote_users
58 .into_iter()
59 .fold(BTreeMap::new(), |mut acc, (user_id, map)| {
60 acc.entry(user_id.server_name())
61 .or_default()
62 .insert(user_id.to_owned(), map.clone());
63 acc
64 });
65
66 let local = collect_local_one_time_keys(services, &local_users);
67 let federation = collect_federation_one_time_keys(services, server);
68
69 let (local, federation) = join(local, federation).await;
70 let merged = local.merge(federation);
71
72 Ok(claim_keys::v3::Response {
73 failures: merged.failures,
74 one_time_keys: merged.one_time_keys,
75 })
76}
77
78async fn collect_local_one_time_keys(services: &Services, users: &[LocalClaim<'_>]) -> Claims {
79 let take_one_time_key = async |(user_id, device_id, algorithm)| {
80 let key = services
81 .users
82 .take_one_time_key(user_id, device_id, algorithm)
83 .await
84 .ok();
85
86 let key = match key {
88 | Some(key) => Some(key),
89 | None => services
90 .users
91 .take_fallback_key(user_id, device_id, algorithm)
92 .await
93 .ok(),
94 };
95
96 key.map(|key| (device_id.to_owned(), [key].into()))
97 };
98
99 let one_time_keys = users
100 .iter()
101 .copied()
102 .stream()
103 .broad_then(async |(user_id, requested)| {
104 requested
105 .iter()
106 .stream()
107 .map(|(device_id, algorithm)| (user_id, device_id.as_ref(), algorithm))
108 .filter_map(take_one_time_key)
109 .collect()
110 .map(|device_keys| (user_id.to_owned(), device_keys))
111 .await
112 })
113 .collect()
114 .await;
115
116 Claims { one_time_keys, ..Default::default() }
117}
118
119async fn collect_federation_one_time_keys(
120 services: &Services,
121 server: ServerClaims<'_>,
122) -> Claims {
123 server
124 .into_iter()
125 .stream()
126 .broad_then(async |(server, one_time_keys)| {
127 let request = federation::keys::claim_keys::v1::Request { one_time_keys };
128
129 match services
130 .federation
131 .execute(server, request)
132 .await
133 .inspect_err(
134 |e| debug_warn!(%server, "claim_keys federation request failed: {e}"),
135 ) {
136 | Ok(keys) => Claims {
137 one_time_keys: keys.one_time_keys,
138 failures: Default::default(),
139 },
140 | Err(_e) => Claims {
141 failures: [(server.to_string(), json!({}))].into(),
142 ..Default::default()
143 },
144 }
145 })
146 .ready_fold(Claims::default(), Claims::merge)
147 .await
148}
149
150impl Claims {
151 fn merge(mut self, other: Self) -> Self {
152 self.one_time_keys.extend(other.one_time_keys);
153 self.failures.extend(other.failures);
154 self
155 }
156}