Skip to main content

tuwunel_api/client/keys/
claim_keys.rs

1use std::collections::BTreeMap;
2
3use axum::extract::State;
4use futures::{FutureExt, StreamExt, future::join};
5use ruma::{
6	OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedUserId, ServerName, UserId,
7	api::{client::keys::claim_keys, federation},
8	encryption::OneTimeKey,
9	serde::Raw,
10};
11use serde_json::json;
12use tuwunel_core::{
13	Result, debug_warn,
14	utils::{
15		IterStream,
16		stream::{BroadbandExt, ReadyExt},
17	},
18};
19use tuwunel_service::Services;
20
21use super::FailureMap;
22use crate::Ruma;
23
24#[derive(Default)]
25struct Claims {
26	one_time_keys: OneTimeKeyMap,
27	failures: FailureMap,
28}
29
30type RequestClaims = BTreeMap<OwnedUserId, Algorithms>;
31type ServerClaims<'a> = BTreeMap<&'a ServerName, RequestClaims>;
32type LocalClaim<'a> = (&'a UserId, &'a Algorithms);
33type Algorithms = BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>;
34type OneTimeKeys = BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>;
35type OneTimeKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeys>>;
36
37/// # `POST /_matrix/client/r0/keys/claim`
38///
39/// Claims one-time keys
40pub(crate) async fn claim_keys_route(
41	State(services): State<crate::State>,
42	body: Ruma<claim_keys::v3::Request>,
43) -> Result<claim_keys::v3::Response> {
44	claim_keys_helper(&services, &body.one_time_keys).await
45}
46
47pub(crate) async fn claim_keys_helper(
48	services: &Services,
49	one_time_keys_input: &RequestClaims,
50) -> Result<claim_keys::v3::Response> {
51	let (local_users, remote_users): (Vec<_>, Vec<_>) = one_time_keys_input
52		.iter()
53		.map(|(uid, map)| (uid.as_ref(), map))
54		.partition(|(user_id, _)| services.globals.user_is_local(user_id));
55
56	let server: ServerClaims<'_> =
57		remote_users
58			.into_iter()
59			.fold(BTreeMap::new(), |mut acc, (user_id, map)| {
60				acc.entry(user_id.server_name())
61					.or_default()
62					.insert(user_id.to_owned(), map.clone());
63				acc
64			});
65
66	let local = collect_local_one_time_keys(services, &local_users);
67	let federation = collect_federation_one_time_keys(services, server);
68
69	let (local, federation) = join(local, federation).await;
70	let merged = local.merge(federation);
71
72	Ok(claim_keys::v3::Response {
73		failures: merged.failures,
74		one_time_keys: merged.one_time_keys,
75	})
76}
77
78async fn collect_local_one_time_keys(services: &Services, users: &[LocalClaim<'_>]) -> Claims {
79	let take_one_time_key = async |(user_id, device_id, algorithm)| {
80		let key = services
81			.users
82			.take_one_time_key(user_id, device_id, algorithm)
83			.await
84			.ok();
85
86		// MSC2732: serve the fallback key when the one-time pool is empty.
87		let key = match key {
88			| Some(key) => Some(key),
89			| None => services
90				.users
91				.take_fallback_key(user_id, device_id, algorithm)
92				.await
93				.ok(),
94		};
95
96		key.map(|key| (device_id.to_owned(), [key].into()))
97	};
98
99	let one_time_keys = users
100		.iter()
101		.copied()
102		.stream()
103		.broad_then(async |(user_id, requested)| {
104			requested
105				.iter()
106				.stream()
107				.map(|(device_id, algorithm)| (user_id, device_id.as_ref(), algorithm))
108				.filter_map(take_one_time_key)
109				.collect()
110				.map(|device_keys| (user_id.to_owned(), device_keys))
111				.await
112		})
113		.collect()
114		.await;
115
116	Claims { one_time_keys, ..Default::default() }
117}
118
119async fn collect_federation_one_time_keys(
120	services: &Services,
121	server: ServerClaims<'_>,
122) -> Claims {
123	server
124		.into_iter()
125		.stream()
126		.broad_then(async |(server, one_time_keys)| {
127			let request = federation::keys::claim_keys::v1::Request { one_time_keys };
128
129			match services
130				.federation
131				.execute(server, request)
132				.await
133				.inspect_err(
134					|e| debug_warn!(%server, "claim_keys federation request failed: {e}"),
135				) {
136				| Ok(keys) => Claims {
137					one_time_keys: keys.one_time_keys,
138					failures: Default::default(),
139				},
140				| Err(_e) => Claims {
141					failures: [(server.to_string(), json!({}))].into(),
142					..Default::default()
143				},
144			}
145		})
146		.ready_fold(Claims::default(), Claims::merge)
147		.await
148}
149
150impl Claims {
151	fn merge(mut self, other: Self) -> Self {
152		self.one_time_keys.extend(other.one_time_keys);
153		self.failures.extend(other.failures);
154		self
155	}
156}