Skip to main content

tuwunel_api/client/keys/
get_keys.rs

1use std::collections::{BTreeMap, HashMap};
2
3use axum::extract::State;
4use futures::{
5	FutureExt, StreamExt,
6	future::{
7		Either::{Left, Right},
8		join, join4,
9	},
10};
11use ruma::{
12	CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, ServerName,
13	UserId,
14	api::{
15		client::{device::Device, keys::get_keys},
16		federation,
17	},
18	encryption::{CrossSigningKey, DeviceKeys},
19	serde::Raw,
20};
21use serde_json::{json, value::to_raw_value};
22use tuwunel_core::{
23	Result, debug_warn, implement,
24	utils::{
25		BoolExt, IterStream,
26		future::TryExtExt,
27		json,
28		stream::{BroadbandExt, ReadyExt},
29	},
30};
31use tuwunel_service::{Services, users::parse_master_key};
32
33use super::FailureMap;
34use crate::Ruma;
35
36#[derive(Default)]
37struct Keys {
38	device_keys: DeviceKeyMap,
39	master_keys: CrossSigningKeys,
40	self_signing_keys: CrossSigningKeys,
41	user_signing_keys: CrossSigningKeys,
42	failures: FailureMap,
43}
44
45type DeviceLists = BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>;
46type DeviceKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Raw<DeviceKeys>>>;
47type ServerDevices<'a> = HashMap<&'a ServerName, DeviceLists>;
48type LocalDeviceUser<'a> = (&'a UserId, &'a Vec<OwnedDeviceId>);
49type CrossSigningKeys = BTreeMap<OwnedUserId, Raw<CrossSigningKey>>;
50
51/// # `POST /_matrix/client/r0/keys/query`
52///
53/// Get end-to-end encryption keys for the given users.
54///
55/// - Always fetches users from other servers over federation
56/// - Gets master keys, self-signing keys, user signing keys and device keys.
57/// - The master and self-signing keys contain signatures that the user is
58///   allowed to see
59pub(crate) async fn get_keys_route(
60	State(services): State<crate::State>,
61	body: Ruma<get_keys::v3::Request>,
62) -> Result<get_keys::v3::Response> {
63	let sender_user = body.sender_user();
64
65	get_keys_helper(
66		&services,
67		Some(sender_user),
68		&body.device_keys,
69		|u| u == sender_user,
70		true, // Always allow local users to see device names of other local users
71	)
72	.await
73}
74
75pub(crate) async fn get_keys_helper<F>(
76	services: &Services,
77	sender_user: Option<&UserId>,
78	device_keys_input: &DeviceLists,
79	allowed_signatures: F,
80	include_display_names: bool,
81) -> Result<get_keys::v3::Response>
82where
83	F: Fn(&UserId) -> bool + Send + Sync,
84{
85	let (local_users, remote_users): (Vec<LocalDeviceUser<'_>>, Vec<_>) = device_keys_input
86		.iter()
87		.map(|(uid, dids)| (uid.as_ref(), dids))
88		.partition(|(user_id, _)| services.globals.user_is_local(user_id));
89
90	let server: ServerDevices<'_> =
91		remote_users
92			.into_iter()
93			.fold(HashMap::new(), |mut acc, (user_id, device_ids)| {
94				acc.entry(user_id.server_name())
95					.or_default()
96					.insert(user_id.to_owned(), device_ids.clone());
97				acc
98			});
99
100	let local = collect_local_keys(
101		services,
102		&local_users,
103		sender_user,
104		&allowed_signatures,
105		include_display_names,
106	);
107
108	let federation = collect_federation_keys(services, server, sender_user, &allowed_signatures);
109
110	let (local, federation) = join(local, federation).await;
111	Ok(local.merge(federation).into_response())
112}
113
114async fn collect_local_keys<F>(
115	services: &Services,
116	users: &[LocalDeviceUser<'_>],
117	sender_user: Option<&UserId>,
118	allowed_signatures: &F,
119	include_display_names: bool,
120) -> Keys
121where
122	F: Fn(&UserId) -> bool + Send + Sync,
123{
124	users
125		.iter()
126		.copied()
127		.stream()
128		.broad_then(async |(user_id, device_ids)| {
129			collect_local_user_keys(
130				services,
131				user_id,
132				device_ids,
133				sender_user,
134				allowed_signatures,
135				include_display_names,
136			)
137			.await
138		})
139		.ready_fold(Keys::default(), Keys::merge)
140		.await
141}
142
143async fn collect_local_user_keys<F>(
144	services: &Services,
145	user_id: &UserId,
146	device_ids: &[OwnedDeviceId],
147	sender_user: Option<&UserId>,
148	allowed_signatures: &F,
149	include_display_names: bool,
150) -> Keys
151where
152	F: Fn(&UserId) -> bool + Send + Sync,
153{
154	let device_keys =
155		collect_local_device_keys(services, user_id, device_ids, include_display_names);
156
157	let master_key = services
158		.users
159		.get_master_key(sender_user, user_id, allowed_signatures)
160		.ok();
161
162	let self_signing_key = services
163		.users
164		.get_self_signing_key(sender_user, user_id, allowed_signatures)
165		.ok();
166
167	let user_signing_key = (Some(user_id) == sender_user)
168		.then_async(|| services.users.get_user_signing_key(user_id).ok())
169		.map(Option::flatten);
170
171	let (device_keys, master_key, self_signing_key, user_signing_key) =
172		join4(device_keys, master_key, self_signing_key, user_signing_key).await;
173
174	let owned = || user_id.to_owned();
175	Keys {
176		device_keys: BTreeMap::from([(owned(), device_keys)]),
177		master_keys: master_key
178			.map(|k| (owned(), k))
179			.into_iter()
180			.collect(),
181
182		self_signing_keys: self_signing_key
183			.map(|k| (owned(), k))
184			.into_iter()
185			.collect(),
186
187		user_signing_keys: user_signing_key
188			.map(|k| (owned(), k))
189			.into_iter()
190			.collect(),
191
192		..Default::default()
193	}
194}
195
196async fn collect_local_device_keys(
197	services: &Services,
198	user_id: &UserId,
199	device_ids: &[OwnedDeviceId],
200	include_display_names: bool,
201) -> BTreeMap<OwnedDeviceId, Raw<DeviceKeys>> {
202	let stream = if device_ids.is_empty() {
203		Left(
204			services
205				.users
206				.all_device_ids(user_id)
207				.map(ToOwned::to_owned),
208		)
209	} else {
210		Right(device_ids.iter().cloned().stream())
211	};
212
213	stream
214		.broad_filter_map(async |device_id| {
215			get_local_device_keys(services, user_id, &device_id, include_display_names)
216				.await
217				.map(|keys| (device_id, keys))
218		})
219		.collect()
220		.await
221}
222
223async fn get_local_device_keys(
224	services: &Services,
225	user_id: &UserId,
226	device_id: &DeviceId,
227	include_display_names: bool,
228) -> Option<Raw<DeviceKeys>> {
229	let mut keys = services
230		.users
231		.get_device_keys(user_id, device_id)
232		.await
233		.ok()?;
234
235	let metadata = services
236		.users
237		.get_device_metadata(user_id, device_id)
238		.await
239		.inspect_err(|e| debug_warn!(?user_id, ?device_id, "device metadata missing: {e}"))
240		.ok()?;
241
242	add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
243		.inspect_err(|e| debug_warn!(?user_id, ?device_id, "invalid device keys: {e}"))
244		.ok()?;
245
246	Some(keys)
247}
248
249async fn collect_federation_keys<F>(
250	services: &Services,
251	server: ServerDevices<'_>,
252	sender_user: Option<&UserId>,
253	allowed_signatures: &F,
254) -> Keys
255where
256	F: Fn(&UserId) -> bool + Send + Sync,
257{
258	server
259		.into_iter()
260		.stream()
261		.broad_then(async |(server, device_keys)| {
262			let failed = || Keys {
263				failures: BTreeMap::from([(server.to_string(), json!({}))]),
264				..Default::default()
265			};
266
267			let request = federation::keys::get_keys::v1::Request { device_keys };
268
269			match services
270				.federation
271				.execute_keys(server, request)
272				.await
273			{
274				| Ok(response) =>
275					process_federation_response(
276						services,
277						sender_user,
278						allowed_signatures,
279						response,
280					)
281					.await,
282				| Err(e) => {
283					debug_warn!(%server, "key federation request failed: {e}");
284					failed()
285				},
286			}
287		})
288		.ready_fold(Keys::default(), Keys::merge)
289		.await
290}
291
292async fn process_federation_response<F>(
293	services: &Services,
294	sender_user: Option<&UserId>,
295	allowed_signatures: &F,
296	response: federation::keys::get_keys::v1::Response,
297) -> Keys
298where
299	F: Fn(&UserId) -> bool + Send + Sync,
300{
301	let federation::keys::get_keys::v1::Response {
302		master_keys,
303		self_signing_keys,
304		device_keys,
305	} = response;
306
307	let master_keys = master_keys
308		.into_iter()
309		.stream()
310		.broad_filter_map(async |(user, master_key)| {
311			merge_remote_master_key(services, sender_user, allowed_signatures, &user, master_key)
312				.await
313				.inspect_err(|e| debug_warn!(?user, "skipping master key from federation: {e}"))
314				.map(|raw| (user, raw))
315				.ok()
316		})
317		.collect()
318		.await;
319
320	Keys {
321		device_keys,
322		master_keys,
323		self_signing_keys,
324		user_signing_keys: BTreeMap::new(),
325		failures: BTreeMap::new(),
326	}
327}
328
329/// Merges signatures from our cached copy of the user's master key (if any)
330/// onto the remote-supplied master key, persists the merged copy to our
331/// database, and returns the merged Raw value for the response.
332async fn merge_remote_master_key<F>(
333	services: &Services,
334	sender_user: Option<&UserId>,
335	allowed_signatures: &F,
336	user: &UserId,
337	master_key_raw: Raw<CrossSigningKey>,
338) -> Result<Raw<CrossSigningKey>>
339where
340	F: Fn(&UserId) -> bool + Send + Sync,
341{
342	let (master_key_id, mut master_key) = parse_master_key(user, &master_key_raw)?;
343	let our_raw = services
344		.users
345		.get_key(&master_key_id, sender_user, user, allowed_signatures)
346		.await;
347
348	if let Ok(our_raw) = our_raw
349		&& let Ok((_, mut ours)) = parse_master_key(user, &our_raw)
350	{
351		master_key.signatures.append(&mut ours.signatures);
352	}
353
354	let raw = json::to_raw(&master_key)?;
355
356	// Don't notify: a notification would trigger another key request resulting
357	// in an endless loop.
358	services
359		.users
360		.add_cross_signing_keys(user, &Some(raw.clone()), &None, &None, false)
361		.await?;
362
363	Ok(raw)
364}
365
366fn add_unsigned_device_display_name(
367	keys: &mut Raw<DeviceKeys>,
368	metadata: Device,
369	include_display_names: bool,
370) -> Result {
371	let Some(display_name) = metadata.display_name else {
372		return Ok(());
373	};
374
375	let mut object = keys.deserialize_as_unchecked::<CanonicalJsonObject>()?;
376
377	if let CanonicalJsonValue::Object(unsigned) = object
378		.entry("unsigned".into())
379		.or_insert_with(|| CanonicalJsonObject::default().into())
380	{
381		let display_name = if include_display_names {
382			CanonicalJsonValue::String(display_name.to_string())
383		} else {
384			CanonicalJsonValue::String(metadata.device_id.into())
385		};
386
387		unsigned.insert("device_display_name".into(), display_name);
388	}
389
390	*keys = Raw::from_json(to_raw_value(&object)?);
391
392	Ok(())
393}
394
395#[implement(Keys)]
396fn merge(mut self, other: Self) -> Self {
397	self.failures.extend(other.failures);
398	self.device_keys.extend(other.device_keys);
399	self.master_keys.extend(other.master_keys);
400	self.self_signing_keys
401		.extend(other.self_signing_keys);
402	self.user_signing_keys
403		.extend(other.user_signing_keys);
404	self
405}
406
407#[implement(Keys)]
408fn into_response(self) -> get_keys::v3::Response {
409	get_keys::v3::Response {
410		failures: self.failures,
411		device_keys: self.device_keys,
412		master_keys: self.master_keys,
413		self_signing_keys: self.self_signing_keys,
414		user_signing_keys: self.user_signing_keys,
415	}
416}