// tuwunel_api/client/keys/get_keys.rs

1use std::collections::{BTreeMap, HashMap};
2
3use axum::extract::State;
4use futures::{
5	FutureExt, StreamExt,
6	future::{
7		Either::{Left, Right},
8		join, join4,
9	},
10};
11use ruma::{
12	CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, ServerName,
13	UserId,
14	api::{
15		client::{device::Device, keys::get_keys},
16		federation,
17	},
18	encryption::{CrossSigningKey, DeviceKeys},
19	serde::Raw,
20};
21use serde_json::{json, value::to_raw_value};
22use tuwunel_core::{
23	Result, debug_warn, implement,
24	utils::{
25		BoolExt, IterStream,
26		future::TryExtExt,
27		json,
28		stream::{BroadbandExt, ReadyExt},
29	},
30};
31use tuwunel_service::{Services, users::parse_master_key};
32
33use super::FailureMap;
34use crate::Ruma;
35
36#[derive(Default)]
37struct Keys {
38	device_keys: DeviceKeyMap,
39	master_keys: CrossSigningKeys,
40	self_signing_keys: CrossSigningKeys,
41	user_signing_keys: CrossSigningKeys,
42	failures: FailureMap,
43}
44
45type DeviceLists = BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>;
46type DeviceKeyMap = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Raw<DeviceKeys>>>;
47type ServerDevices<'a> = HashMap<&'a ServerName, DeviceLists>;
48type LocalDeviceUser<'a> = (&'a UserId, &'a Vec<OwnedDeviceId>);
49type CrossSigningKeys = BTreeMap<OwnedUserId, Raw<CrossSigningKey>>;
50
51/// # `POST /_matrix/client/r0/keys/query`
52///
53/// Get end-to-end encryption keys for the given users.
54///
55/// - Always fetches users from other servers over federation
56/// - Gets master keys, self-signing keys, user signing keys and device keys.
57/// - The master and self-signing keys contain signatures that the user is
58///   allowed to see
59pub(crate) async fn get_keys_route(
60	State(services): State<crate::State>,
61	body: Ruma<get_keys::v3::Request>,
62) -> Result<get_keys::v3::Response> {
63	let sender_user = body.sender_user();
64
65	get_keys_helper(
66		&services,
67		Some(sender_user),
68		&body.device_keys,
69		|u| u == sender_user,
70		true, // Always allow local users to see device names of other local users
71	)
72	.await
73}
74
75pub(crate) async fn get_keys_helper<F>(
76	services: &Services,
77	sender_user: Option<&UserId>,
78	device_keys_input: &DeviceLists,
79	allowed_signatures: F,
80	include_display_names: bool,
81) -> Result<get_keys::v3::Response>
82where
83	F: Fn(&UserId) -> bool + Send + Sync,
84{
85	let (local_users, remote_users): (Vec<LocalDeviceUser<'_>>, Vec<_>) = device_keys_input
86		.iter()
87		.map(|(uid, dids)| (uid.as_ref(), dids))
88		.partition(|(user_id, _)| services.globals.user_is_local(user_id));
89
90	let server: ServerDevices<'_> =
91		remote_users
92			.into_iter()
93			.fold(HashMap::new(), |mut acc, (user_id, device_ids)| {
94				acc.entry(user_id.server_name())
95					.or_default()
96					.insert(user_id.to_owned(), device_ids.clone());
97				acc
98			});
99
100	let local = collect_local_keys(
101		services,
102		&local_users,
103		sender_user,
104		&allowed_signatures,
105		include_display_names,
106	);
107
108	let federation = collect_federation_keys(services, server, sender_user, &allowed_signatures);
109
110	let (local, federation) = join(local, federation).await;
111	Ok(local.merge(federation).into_response())
112}
113
114async fn collect_local_keys<F>(
115	services: &Services,
116	users: &[LocalDeviceUser<'_>],
117	sender_user: Option<&UserId>,
118	allowed_signatures: &F,
119	include_display_names: bool,
120) -> Keys
121where
122	F: Fn(&UserId) -> bool + Send + Sync,
123{
124	users
125		.iter()
126		.copied()
127		.stream()
128		.broad_then(async |(user_id, device_ids)| {
129			collect_local_user_keys(
130				services,
131				user_id,
132				device_ids,
133				sender_user,
134				allowed_signatures,
135				include_display_names,
136			)
137			.await
138		})
139		.ready_fold(Keys::default(), Keys::merge)
140		.await
141}
142
143async fn collect_local_user_keys<F>(
144	services: &Services,
145	user_id: &UserId,
146	device_ids: &[OwnedDeviceId],
147	sender_user: Option<&UserId>,
148	allowed_signatures: &F,
149	include_display_names: bool,
150) -> Keys
151where
152	F: Fn(&UserId) -> bool + Send + Sync,
153{
154	let device_keys =
155		collect_local_device_keys(services, user_id, device_ids, include_display_names);
156
157	let master_key = services
158		.users
159		.get_master_key(sender_user, user_id, allowed_signatures)
160		.ok();
161
162	let self_signing_key = services
163		.users
164		.get_self_signing_key(sender_user, user_id, allowed_signatures)
165		.ok();
166
167	let user_signing_key = (Some(user_id) == sender_user)
168		.then_async(|| services.users.get_user_signing_key(user_id).ok())
169		.map(Option::flatten);
170
171	let (device_keys, master_key, self_signing_key, user_signing_key) =
172		join4(device_keys, master_key, self_signing_key, user_signing_key).await;
173
174	let owned = || user_id.to_owned();
175	Keys {
176		device_keys: BTreeMap::from([(owned(), device_keys)]),
177		master_keys: master_key
178			.map(|k| (owned(), k))
179			.into_iter()
180			.collect(),
181
182		self_signing_keys: self_signing_key
183			.map(|k| (owned(), k))
184			.into_iter()
185			.collect(),
186
187		user_signing_keys: user_signing_key
188			.map(|k| (owned(), k))
189			.into_iter()
190			.collect(),
191
192		..Default::default()
193	}
194}
195
196async fn collect_local_device_keys(
197	services: &Services,
198	user_id: &UserId,
199	device_ids: &[OwnedDeviceId],
200	include_display_names: bool,
201) -> BTreeMap<OwnedDeviceId, Raw<DeviceKeys>> {
202	let stream = if device_ids.is_empty() {
203		Left(
204			services
205				.users
206				.all_device_ids(user_id)
207				.map(ToOwned::to_owned),
208		)
209	} else {
210		Right(device_ids.iter().cloned().stream())
211	};
212
213	stream
214		.broad_filter_map(async |device_id| {
215			get_local_device_keys(services, user_id, &device_id, include_display_names)
216				.await
217				.map(|keys| (device_id, keys))
218		})
219		.collect()
220		.await
221}
222
223async fn get_local_device_keys(
224	services: &Services,
225	user_id: &UserId,
226	device_id: &DeviceId,
227	include_display_names: bool,
228) -> Option<Raw<DeviceKeys>> {
229	let mut keys = services
230		.users
231		.get_device_keys(user_id, device_id)
232		.await
233		.ok()?;
234
235	let metadata = services
236		.users
237		.get_device_metadata(user_id, device_id)
238		.await
239		.inspect_err(|e| debug_warn!(?user_id, ?device_id, "device metadata missing: {e}"))
240		.ok()?;
241
242	add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
243		.inspect_err(|e| debug_warn!(?user_id, ?device_id, "invalid device keys: {e}"))
244		.ok()?;
245
246	Some(keys)
247}
248
249async fn collect_federation_keys<F>(
250	services: &Services,
251	server: ServerDevices<'_>,
252	sender_user: Option<&UserId>,
253	allowed_signatures: &F,
254) -> Keys
255where
256	F: Fn(&UserId) -> bool + Send + Sync,
257{
258	server
259		.into_iter()
260		.stream()
261		.broad_then(async |(server, device_keys)| {
262			let request = federation::keys::get_keys::v1::Request { device_keys };
263
264			(server, services.federation.execute(server, request).await)
265		})
266		.broad_then(async |(server, response)| match response {
267			| Ok(response) =>
268				process_federation_response(services, sender_user, allowed_signatures, response)
269					.await,
270			| Err(e) => {
271				debug_warn!(%server, "key federation request failed: {e}");
272				Keys {
273					failures: BTreeMap::from([(server.to_string(), json!({}))]),
274					..Default::default()
275				}
276			},
277		})
278		.ready_fold(Keys::default(), Keys::merge)
279		.await
280}
281
282async fn process_federation_response<F>(
283	services: &Services,
284	sender_user: Option<&UserId>,
285	allowed_signatures: &F,
286	response: federation::keys::get_keys::v1::Response,
287) -> Keys
288where
289	F: Fn(&UserId) -> bool + Send + Sync,
290{
291	let federation::keys::get_keys::v1::Response {
292		master_keys,
293		self_signing_keys,
294		device_keys,
295	} = response;
296
297	let master_keys = master_keys
298		.into_iter()
299		.stream()
300		.broad_filter_map(async |(user, master_key)| {
301			merge_remote_master_key(services, sender_user, allowed_signatures, &user, master_key)
302				.await
303				.inspect_err(|e| debug_warn!(?user, "skipping master key from federation: {e}"))
304				.map(|raw| (user, raw))
305				.ok()
306		})
307		.collect()
308		.await;
309
310	Keys {
311		device_keys,
312		master_keys,
313		self_signing_keys,
314		user_signing_keys: BTreeMap::new(),
315		failures: BTreeMap::new(),
316	}
317}
318
319/// Merges signatures from our cached copy of the user's master key (if any)
320/// onto the remote-supplied master key, persists the merged copy to our
321/// database, and returns the merged Raw value for the response.
322async fn merge_remote_master_key<F>(
323	services: &Services,
324	sender_user: Option<&UserId>,
325	allowed_signatures: &F,
326	user: &UserId,
327	master_key_raw: Raw<CrossSigningKey>,
328) -> Result<Raw<CrossSigningKey>>
329where
330	F: Fn(&UserId) -> bool + Send + Sync,
331{
332	let (master_key_id, mut master_key) = parse_master_key(user, &master_key_raw)?;
333	let our_raw = services
334		.users
335		.get_key(&master_key_id, sender_user, user, allowed_signatures)
336		.await;
337
338	if let Ok(our_raw) = our_raw
339		&& let Ok((_, mut ours)) = parse_master_key(user, &our_raw)
340	{
341		master_key.signatures.append(&mut ours.signatures);
342	}
343
344	let raw = json::to_raw(&master_key)?;
345
346	// Don't notify: a notification would trigger another key request resulting
347	// in an endless loop.
348	services
349		.users
350		.add_cross_signing_keys(user, &Some(raw.clone()), &None, &None, false)
351		.await?;
352
353	Ok(raw)
354}
355
356fn add_unsigned_device_display_name(
357	keys: &mut Raw<DeviceKeys>,
358	metadata: Device,
359	include_display_names: bool,
360) -> Result {
361	let Some(display_name) = metadata.display_name else {
362		return Ok(());
363	};
364
365	let mut object = keys.deserialize_as_unchecked::<CanonicalJsonObject>()?;
366
367	if let CanonicalJsonValue::Object(unsigned) = object
368		.entry("unsigned".into())
369		.or_insert_with(|| CanonicalJsonObject::default().into())
370	{
371		let display_name = if include_display_names {
372			CanonicalJsonValue::String(display_name.to_string())
373		} else {
374			CanonicalJsonValue::String(metadata.device_id.into())
375		};
376
377		unsigned.insert("device_display_name".into(), display_name);
378	}
379
380	*keys = Raw::from_json(to_raw_value(&object)?);
381
382	Ok(())
383}
384
385#[implement(Keys)]
386fn merge(mut self, other: Self) -> Self {
387	self.failures.extend(other.failures);
388	self.device_keys.extend(other.device_keys);
389	self.master_keys.extend(other.master_keys);
390	self.self_signing_keys
391		.extend(other.self_signing_keys);
392	self.user_signing_keys
393		.extend(other.user_signing_keys);
394	self
395}
396
397#[implement(Keys)]
398fn into_response(self) -> get_keys::v3::Response {
399	get_keys::v3::Response {
400		failures: self.failures,
401		device_keys: self.device_keys,
402		master_keys: self.master_keys,
403		self_signing_keys: self.self_signing_keys,
404		user_signing_keys: self.user_signing_keys,
405	}
406}