tuwunel_api/client/keys/
claim_keys.rs1use std::collections::BTreeMap;
2
3use axum::extract::State;
4use futures::{StreamExt, future::join};
5use ruma::{
6 OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedUserId, ServerName, UserId,
7 api::{client::keys::claim_keys, federation},
8 encryption::OneTimeKey,
9 serde::Raw,
10};
11use serde_json::json;
12use tuwunel_core::{
13 Result, debug_warn,
14 utils::{
15 IterStream,
16 stream::{BroadbandExt, ReadyExt},
17 },
18};
19use tuwunel_service::Services;
20
21use super::FailureMap;
22use crate::Ruma;
23
24#[derive(Default)]
25struct Claims {
26 one_time_keys: OneTimeKeyMap,
27 failures: FailureMap,
28}
29
30type RequestClaims = BTreeMap<OwnedUserId, Algorithms>;
31type ServerClaims<'a> = BTreeMap<&'a ServerName, RequestClaims>;
32type LocalClaim<'a> = (&'a UserId, &'a Algorithms);
33type Algorithms = BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>;
34type OneTimeKeys = BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>;
35type OneTimeKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeys>>;
36
37pub(crate) async fn claim_keys_route(
41 State(services): State<crate::State>,
42 body: Ruma<claim_keys::v3::Request>,
43) -> Result<claim_keys::v3::Response> {
44 claim_keys_helper(&services, &body.one_time_keys).await
45}
46
47pub(crate) async fn claim_keys_helper(
48 services: &Services,
49 one_time_keys_input: &RequestClaims,
50) -> Result<claim_keys::v3::Response> {
51 let (local_users, remote_users): (Vec<_>, Vec<_>) = one_time_keys_input
52 .iter()
53 .map(|(uid, map)| (uid.as_ref(), map))
54 .partition(|(user_id, _)| services.globals.user_is_local(user_id));
55
56 let server: ServerClaims<'_> =
57 remote_users
58 .into_iter()
59 .fold(BTreeMap::new(), |mut acc, (user_id, map)| {
60 acc.entry(user_id.server_name())
61 .or_default()
62 .insert(user_id.to_owned(), map.clone());
63 acc
64 });
65
66 let local = collect_local_one_time_keys(services, &local_users);
67 let federation = collect_federation_one_time_keys(services, server);
68
69 let (local, federation) = join(local, federation).await;
70 let merged = local.merge(federation);
71
72 Ok(claim_keys::v3::Response {
73 failures: merged.failures,
74 one_time_keys: merged.one_time_keys,
75 })
76}
77
78async fn collect_local_one_time_keys(services: &Services, users: &[LocalClaim<'_>]) -> Claims {
79 let take_one_time_key = async |(user_id, device_id, algorithm)| {
80 let key = services
81 .users
82 .take_one_time_key(user_id, device_id, algorithm)
83 .await
84 .ok();
85
86 let key = match key {
88 | Some(key) => Some(key),
89 | None => services
90 .users
91 .take_fallback_key(user_id, device_id, algorithm)
92 .await
93 .ok(),
94 };
95
96 key.map(|key| (device_id.to_owned(), [key].into()))
97 };
98
99 let one_time_keys = users
100 .iter()
101 .copied()
102 .stream()
103 .broad_filter_map(async |(user_id, requested)| {
104 let device_keys: BTreeMap<_, _> = requested
105 .iter()
106 .stream()
107 .map(|(device_id, algorithm)| (user_id, device_id.as_ref(), algorithm))
108 .filter_map(take_one_time_key)
109 .collect()
110 .await;
111
112 (!device_keys.is_empty()).then(|| (user_id.to_owned(), device_keys))
114 })
115 .collect()
116 .await;
117
118 Claims { one_time_keys, ..Default::default() }
119}
120
121async fn collect_federation_one_time_keys(
122 services: &Services,
123 server: ServerClaims<'_>,
124) -> Claims {
125 server
126 .into_iter()
127 .stream()
128 .broad_then(async |(server, one_time_keys)| {
129 let failed = || Claims {
130 failures: [(server.to_string(), json!({}))].into(),
131 ..Default::default()
132 };
133
134 let request = federation::keys::claim_keys::v1::Request { one_time_keys };
135
136 match services
137 .federation
138 .execute_keys(server, request)
139 .await
140 .inspect_err(
141 |e| debug_warn!(%server, "claim_keys federation request failed: {e}"),
142 ) {
143 | Err(_e) => failed(),
144 | Ok(keys) => Claims {
145 one_time_keys: keys.one_time_keys,
146 failures: Default::default(),
147 },
148 }
149 })
150 .ready_fold(Claims::default(), Claims::merge)
151 .await
152}
153
154impl Claims {
155 fn merge(mut self, other: Self) -> Self {
156 self.one_time_keys.extend(other.one_time_keys);
157 self.failures.extend(other.failures);
158 self
159 }
160}