Skip to main content

tuwunel_api/client/keys/
claim_keys.rs

use std::collections::{BTreeMap, BTreeSet};

use axum::extract::State;
use futures::{StreamExt, future::join};
use ruma::{
	OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedUserId, ServerName, UserId,
	api::{client::keys::claim_keys, federation},
	encryption::OneTimeKey,
	serde::Raw,
};
use serde_json::json;
use tuwunel_core::{
	Result, debug_warn,
	utils::{
		IterStream,
		stream::{BroadbandExt, ReadyExt},
	},
};
use tuwunel_service::Services;

use super::FailureMap;
use crate::Ruma;
23
24#[derive(Default)]
25struct Claims {
26	one_time_keys: OneTimeKeyMap,
27	failures: FailureMap,
28}
29
30type RequestClaims = BTreeMap<OwnedUserId, Algorithms>;
31type ServerClaims<'a> = BTreeMap<&'a ServerName, RequestClaims>;
32type LocalClaim<'a> = (&'a UserId, &'a Algorithms);
33type Algorithms = BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>;
34type OneTimeKeys = BTreeMap<OwnedOneTimeKeyId, Raw<OneTimeKey>>;
35type OneTimeKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeys>>;
36
37/// # `POST /_matrix/client/r0/keys/claim`
38///
39/// Claims one-time keys
40pub(crate) async fn claim_keys_route(
41	State(services): State<crate::State>,
42	body: Ruma<claim_keys::v3::Request>,
43) -> Result<claim_keys::v3::Response> {
44	claim_keys_helper(&services, &body.one_time_keys).await
45}
46
47pub(crate) async fn claim_keys_helper(
48	services: &Services,
49	one_time_keys_input: &RequestClaims,
50) -> Result<claim_keys::v3::Response> {
51	let (local_users, remote_users): (Vec<_>, Vec<_>) = one_time_keys_input
52		.iter()
53		.map(|(uid, map)| (uid.as_ref(), map))
54		.partition(|(user_id, _)| services.globals.user_is_local(user_id));
55
56	let server: ServerClaims<'_> =
57		remote_users
58			.into_iter()
59			.fold(BTreeMap::new(), |mut acc, (user_id, map)| {
60				acc.entry(user_id.server_name())
61					.or_default()
62					.insert(user_id.to_owned(), map.clone());
63				acc
64			});
65
66	let local = collect_local_one_time_keys(services, &local_users);
67	let federation = collect_federation_one_time_keys(services, server);
68
69	let (local, federation) = join(local, federation).await;
70	let merged = local.merge(federation);
71
72	Ok(claim_keys::v3::Response {
73		failures: merged.failures,
74		one_time_keys: merged.one_time_keys,
75	})
76}
77
78async fn collect_local_one_time_keys(services: &Services, users: &[LocalClaim<'_>]) -> Claims {
79	let take_one_time_key = async |(user_id, device_id, algorithm)| {
80		let key = services
81			.users
82			.take_one_time_key(user_id, device_id, algorithm)
83			.await
84			.ok();
85
86		// MSC2732: serve the fallback key when the one-time pool is empty.
87		let key = match key {
88			| Some(key) => Some(key),
89			| None => services
90				.users
91				.take_fallback_key(user_id, device_id, algorithm)
92				.await
93				.ok(),
94		};
95
96		key.map(|key| (device_id.to_owned(), [key].into()))
97	};
98
99	let one_time_keys = users
100		.iter()
101		.copied()
102		.stream()
103		.broad_filter_map(async |(user_id, requested)| {
104			let device_keys: BTreeMap<_, _> = requested
105				.iter()
106				.stream()
107				.map(|(device_id, algorithm)| (user_id, device_id.as_ref(), algorithm))
108				.filter_map(take_one_time_key)
109				.collect()
110				.await;
111
112			// Omit a depleted user entirely; Synapse returns no entry, not an empty map.
113			(!device_keys.is_empty()).then(|| (user_id.to_owned(), device_keys))
114		})
115		.collect()
116		.await;
117
118	Claims { one_time_keys, ..Default::default() }
119}
120
121async fn collect_federation_one_time_keys(
122	services: &Services,
123	server: ServerClaims<'_>,
124) -> Claims {
125	server
126		.into_iter()
127		.stream()
128		.broad_then(async |(server, one_time_keys)| {
129			let failed = || Claims {
130				failures: [(server.to_string(), json!({}))].into(),
131				..Default::default()
132			};
133
134			let request = federation::keys::claim_keys::v1::Request { one_time_keys };
135
136			match services
137				.federation
138				.execute_keys(server, request)
139				.await
140				.inspect_err(
141					|e| debug_warn!(%server, "claim_keys federation request failed: {e}"),
142				) {
143				| Err(_e) => failed(),
144				| Ok(keys) => Claims {
145					one_time_keys: keys.one_time_keys,
146					failures: Default::default(),
147				},
148			}
149		})
150		.ready_fold(Claims::default(), Claims::merge)
151		.await
152}
153
154impl Claims {
155	fn merge(mut self, other: Self) -> Self {
156		self.one_time_keys.extend(other.one_time_keys);
157		self.failures.extend(other.failures);
158		self
159	}
160}