Skip to main content

tuwunel_service/users/
keys.rs

1use std::{collections::BTreeMap, mem};
2
3use futures::{Stream, StreamExt, TryFutureExt, pin_mut};
4use ruma::{
5	DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId,
6	OwnedOneTimeKeyId, OwnedRoomId, OwnedServerName, RoomId, UInt, UserId,
7	encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
8	serde::Raw,
9};
10use serde::{Deserialize, Serialize};
11use tuwunel_core::{
12	Err, Result, debug_error, err, implement,
13	smallvec::SmallVec,
14	utils::{
15		BoolExt, IterStream, ReadyExt,
16		stream::{BroadbandExt, TryIgnore},
17	},
18};
19use tuwunel_database::{Deserialized, Ignore, Interfix, Json};
20
/// Set of remote servers that `mark_device_key_update` flushes device-list
/// updates to. Holds one server inline before spilling to the heap.
type Servers = SmallVec<[OwnedServerName; 1]>;
22
/// MSC2732: row stored under `(user, device, algorithm)` in
/// `userdeviceidalgorithm_fallback`. Fallback keys are not deleted on
/// claim; the row is rewritten with `used = true`.
///
/// The row value is this struct serialized through `Json`, so the field names
/// below are part of the on-disk format.
#[derive(Debug, Deserialize, Serialize)]
struct FallbackEntry {
	/// Full `algorithm:name` key id as uploaded by the client.
	key_id: OwnedOneTimeKeyId,
	/// The uploaded key object, kept as raw JSON.
	key: Raw<OneTimeKey>,
	/// Written as `false` by `add_fallback_key`; set to `true` by
	/// `take_fallback_key` when the key is claimed.
	used: bool,
}
32
/// Row-key shape of `onetimekeyid4225_otk`: per-device pool keyed by
/// upload-order count for MSC4225 ordering.
///
/// Fields are `(user_id, device_id, count, key_id)`. `count` is taken from
/// `globals.next_count()` at upload time, so iterating a `(user, device)`
/// prefix visits keys in ascending upload order.
type OtkRowKey<'a> = (&'a UserId, &'a DeviceId, u64, &'a OneTimeKeyId);
36
37#[implement(super::Service)]
38pub async fn add_one_time_keys<'a, Keys>(
39	&self,
40	user_id: &UserId,
41	device_id: &DeviceId,
42	keys: Keys,
43) -> Result
44where
45	Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
46{
47	for (id, key) in keys {
48		self.add_one_time_key(user_id, device_id, id, key)
49			.await
50			.ok();
51	}
52
53	Ok(())
54}
55
56#[implement(super::Service)]
57pub async fn add_one_time_key(
58	&self,
59	user_id: &UserId,
60	device_id: &DeviceId,
61	one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
62	one_time_key_value: &Raw<OneTimeKey>,
63) -> Result {
64	let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
65		return Err!(Database("one-time-key column unavailable"));
66	};
67
68	if !self.device_exists(user_id, device_id).await {
69		return Err!(Database(error!(
70			?user_id,
71			?device_id,
72			"User does not exist or device has no metadata."
73		)));
74	}
75
76	if let Err(e) = one_time_key_value
77		.deserialize()
78		.map_err(Into::into)
79	{
80		debug_error!(
81			?one_time_key_key,
82			?one_time_key_value,
83			"Invalid one time key JSON submitted by client, skipping: {e}"
84		);
85
86		return Err(e);
87	}
88
89	// Racy dedup: two concurrent uploads of the same id can both pass this
90	// check and produce duplicate rows that persist until aged out by prune.
91	let prefix = (user_id, device_id, Interfix);
92	let already_present = otk
93		.keys_prefix(&prefix)
94		.ignore_err()
95		.ready_any(|(.., id): OtkRowKey<'_>| id == one_time_key_key)
96		.await;
97
98	if already_present {
99		return Ok(());
100	}
101
102	let count = self.services.globals.next_count();
103
104	// MSC4225: RocksDB iterates the (user, device) prefix in count_be ascending
105	// order, so /keys/claim issues one-time keys in the order they were uploaded.
106	otk.put(
107		(user_id, device_id, *count, one_time_key_key.as_str()),
108		Json(one_time_key_value),
109	);
110
111	self.db
112		.userid_lastonetimekeyupdate
113		.raw_put(user_id, *count);
114
115	Ok(())
116}
117
118#[implement(super::Service)]
119pub async fn add_fallback_keys<'a, Keys>(
120	&self,
121	user_id: &UserId,
122	device_id: &DeviceId,
123	keys: Keys,
124) -> Result
125where
126	Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
127{
128	for (id, key) in keys {
129		self.add_fallback_key(user_id, device_id, id, key)
130			.await
131			.ok();
132	}
133
134	Ok(())
135}
136
137#[implement(super::Service)]
138pub async fn add_fallback_key(
139	&self,
140	user_id: &UserId,
141	device_id: &DeviceId,
142	one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
143	one_time_key_value: &Raw<OneTimeKey>,
144) -> Result {
145	if !self.device_exists(user_id, device_id).await {
146		return Err!(Database(error!(
147			?user_id,
148			?device_id,
149			"User does not exist or device has no metadata."
150		)));
151	}
152
153	if let Err(e) = one_time_key_value
154		.deserialize()
155		.map_err(Into::into)
156	{
157		debug_error!(
158			?one_time_key_key,
159			?one_time_key_value,
160			"Invalid fallback key JSON submitted by client, skipping: {e}"
161		);
162
163		return Err(e);
164	}
165
166	let entry = FallbackEntry {
167		key_id: one_time_key_key.to_owned(),
168		key: one_time_key_value.clone(),
169		used: false,
170	};
171
172	let key = (user_id, device_id, one_time_key_key.algorithm());
173	self.db
174		.userdeviceidalgorithm_fallback
175		.put(key, Json(&entry));
176
177	let count = self.services.globals.next_count();
178	self.db
179		.userid_lastonetimekeyupdate
180		.raw_put(user_id, *count);
181
182	Ok(())
183}
184
185#[implement(super::Service)]
186pub async fn take_fallback_key(
187	&self,
188	user_id: &UserId,
189	device_id: &DeviceId,
190	algorithm: &OneTimeKeyAlgorithm,
191) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
192	let key = (user_id, device_id, algorithm);
193	let entry: FallbackEntry = self
194		.db
195		.userdeviceidalgorithm_fallback
196		.qry(&key)
197		.await
198		.deserialized::<Json<_>>()
199		.map(|Json(entry)| entry)
200		.map_err(|_| err!(Request(NotFound("No fallback key found"))))?;
201
202	let updated = FallbackEntry { used: true, ..entry };
203	self.db
204		.userdeviceidalgorithm_fallback
205		.put(key, Json(&updated));
206
207	Ok((updated.key_id, updated.key))
208}
209
210#[implement(super::Service)]
211pub fn unused_fallback_key_algorithms<'a>(
212	&'a self,
213	user_id: &'a UserId,
214	device_id: &'a DeviceId,
215) -> impl Stream<Item = OneTimeKeyAlgorithm> + Send + 'a {
216	type KeyVal = ((Ignore, Ignore, OneTimeKeyAlgorithm), Json<FallbackEntry>);
217
218	let prefix = (user_id, device_id);
219	self.db
220		.userdeviceidalgorithm_fallback
221		.stream_prefix(&prefix)
222		.ignore_err()
223		.ready_filter_map(|((_, _, algorithm), Json(entry)): KeyVal| {
224			entry.used.is_false().then_some(algorithm)
225		})
226}
227
228#[implement(super::Service)]
229pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
230	self.db
231		.userid_lastonetimekeyupdate
232		.get(user_id)
233		.await
234		.deserialized()
235		.unwrap_or(0)
236}
237
238#[implement(super::Service)]
239pub async fn take_one_time_key(
240	&self,
241	user_id: &UserId,
242	device_id: &DeviceId,
243	key_algorithm: &OneTimeKeyAlgorithm,
244) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
245	let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
246		return Err!(Request(NotFound("No one-time-key found")));
247	};
248
249	let update_count = self.services.globals.next_count();
250	self.db
251		.userid_lastonetimekeyupdate
252		.insert(user_id, update_count.to_be_bytes());
253
254	let prefix = (user_id, device_id, Interfix);
255	let one_time_keys = otk
256		.stream_prefix(&prefix)
257		.ignore_err()
258		.ready_filter(|(row, _): &(OtkRowKey<'_>, &[u8])| row.3.algorithm() == *key_algorithm);
259
260	pin_mut!(one_time_keys);
261	let ((user_id, device_id, count, id), val) = one_time_keys
262		.next()
263		.await
264		.ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))?;
265
266	otk.del((user_id, device_id, count, id));
267
268	Ok((id.into(), serde_json::from_slice(val)?))
269}
270
271#[implement(super::Service)]
272pub async fn count_one_time_keys(
273	&self,
274	user_id: &UserId,
275	device_id: &DeviceId,
276) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
277	let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
278		return BTreeMap::new();
279	};
280
281	let prefix = (user_id, device_id, Interfix);
282	let algorithm_counts: BTreeMap<OneTimeKeyAlgorithm, UInt> = otk
283		.keys_prefix(&prefix)
284		.ignore_err()
285		.ready_fold(BTreeMap::new(), |mut acc, (.., id): OtkRowKey<'_>| {
286			let count: &mut UInt = acc.entry(id.algorithm()).or_default();
287			*count = count.saturating_add(1_u32.into());
288			acc
289		})
290		.await;
291
292	let total = algorithm_counts
293		.values()
294		.copied()
295		.map(TryInto::try_into)
296		.filter_map(Result::ok)
297		.fold(0_usize, usize::saturating_add);
298
299	let limit = self.services.config.one_time_key_limit;
300	if let Some(excess) = total.checked_sub(limit).filter(|&n| n > 0) {
301		self.prune_one_time_keys(user_id, device_id, excess)
302			.await;
303	}
304
305	algorithm_counts
306}
307
308/// MSC4225: drop the `excess` oldest rows for this `(user, device)`. Forward
309/// iteration over the prefix runs in count_be ascending order, so
310/// `take(excess)` yields the earliest-uploaded rows.
311#[implement(super::Service)]
312pub async fn prune_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId, excess: usize) {
313	let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
314		return;
315	};
316
317	let prefix = (user_id, device_id, Interfix);
318	otk.keys_prefix(&prefix)
319		.ignore_err()
320		.take(excess)
321		.ready_for_each(|row: OtkRowKey<'_>| {
322			otk.del(row);
323		})
324		.await;
325}
326
327#[implement(super::Service)]
328pub async fn add_device_keys(
329	&self,
330	user_id: &UserId,
331	device_id: &DeviceId,
332	device_keys: &Raw<DeviceKeys>,
333) {
334	let key = (user_id, device_id);
335
336	self.db.keyid_key.put(key, Json(device_keys));
337	self.mark_device_key_update(user_id).await;
338}
339
340#[implement(super::Service)]
341pub async fn add_cross_signing_keys(
342	&self,
343	user_id: &UserId,
344	master_key: &Option<Raw<CrossSigningKey>>,
345	self_signing_key: &Option<Raw<CrossSigningKey>>,
346	user_signing_key: &Option<Raw<CrossSigningKey>>,
347	notify: bool,
348) -> Result {
349	// TODO: Check signatures
350	let mut prefix = user_id.as_bytes().to_vec();
351	prefix.push(0xFF);
352
353	if let Some(master_key) = master_key {
354		let (master_key_key, _) = parse_master_key(user_id, master_key)?;
355
356		self.db
357			.keyid_key
358			.insert(&master_key_key, master_key.json().get().as_bytes());
359
360		self.db
361			.userid_masterkeyid
362			.insert(user_id.as_bytes(), &master_key_key);
363	}
364
365	// Self-signing key
366	if let Some(self_signing_key) = self_signing_key {
367		let mut self_signing_key_ids = self_signing_key
368			.deserialize()
369			.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
370			.keys
371			.into_values();
372
373		let self_signing_key_id = self_signing_key_ids
374			.next()
375			.ok_or_else(|| err!(Request(InvalidParam("Self signing key contained no key."))))?;
376
377		if self_signing_key_ids.next().is_some() {
378			return Err!(Request(InvalidParam("Self signing key contained more than one key.")));
379		}
380
381		let mut self_signing_key_key = prefix.clone();
382		self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
383
384		self.db
385			.keyid_key
386			.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
387
388		self.db
389			.userid_selfsigningkeyid
390			.insert(user_id.as_bytes(), &self_signing_key_key);
391	}
392
393	// User-signing key
394	if let Some(user_signing_key) = user_signing_key {
395		let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
396
397		let user_signing_key_key = (user_id, &user_signing_key_id);
398		self.db
399			.keyid_key
400			.put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
401
402		self.db
403			.userid_usersigningkeyid
404			.raw_put(user_id, user_signing_key_key);
405	}
406
407	if notify {
408		self.mark_device_key_update(user_id).await;
409	}
410
411	Ok(())
412}
413
414#[implement(super::Service)]
415pub async fn sign_key(
416	&self,
417	target_id: &UserId,
418	key_id: &str,
419	signature: (String, String),
420	sender_id: &UserId,
421) -> Result {
422	let key = (target_id, key_id);
423
424	let mut cross_signing_key: serde_json::Value = self
425		.db
426		.keyid_key
427		.qry(&key)
428		.await
429		.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
430		.deserialized()
431		.map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
432
433	let signatures = cross_signing_key
434		.get_mut("signatures")
435		.ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))?
436		.as_object_mut()
437		.ok_or_else(|| {
438			err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
439		})?
440		.entry(sender_id.to_string())
441		.or_insert_with(|| serde_json::Map::new().into());
442
443	signatures
444		.as_object_mut()
445		.ok_or_else(|| {
446			err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
447		})?
448		.insert(signature.0, signature.1.into());
449
450	let key = (target_id, key_id);
451	self.db
452		.keyid_key
453		.put(key, Json(cross_signing_key));
454
455	self.mark_device_key_update(target_id).await;
456
457	Ok(())
458}
459
460#[implement(super::Service)]
461#[inline]
462pub fn keys_changed<'a>(
463	&'a self,
464	user_id: &'a UserId,
465	from: u64,
466	to: Option<u64>,
467) -> impl Stream<Item = &UserId> + Send + 'a {
468	self.keys_changed_user_or_room(user_id.as_str(), from, to)
469		.map(|(user_id, ..)| user_id)
470}
471
472#[implement(super::Service)]
473#[inline]
474pub fn room_keys_changed<'a>(
475	&'a self,
476	room_id: &'a RoomId,
477	from: u64,
478	to: Option<u64>,
479) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
480	self.keys_changed_user_or_room(room_id.as_str(), from, to)
481}
482
483#[implement(super::Service)]
484fn keys_changed_user_or_room<'a>(
485	&'a self,
486	user_or_room_id: &'a str,
487	from: u64,
488	to: Option<u64>,
489) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
490	type KeyVal<'a> = ((&'a str, u64), &'a UserId);
491
492	let to = to.unwrap_or(u64::MAX);
493	let start = (user_or_room_id, from.saturating_add(1));
494	self.db
495		.keychangeid_userid
496		.stream_from(&start)
497		.ignore_err()
498		.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
499			*prefix == user_or_room_id && *count <= to
500		})
501		.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
502}
503
504#[implement(super::Service)]
505pub async fn mark_device_key_update(&self, user_id: &UserId) {
506	let update_all_rooms = !self
507		.services
508		.config
509		.device_key_update_encrypted_rooms_only;
510
511	let all_or_is_encrypted = async |room_id: &RoomId| {
512		update_all_rooms
513			|| self
514				.services
515				.state_accessor
516				.is_encrypted_room(room_id)
517				.await
518	};
519
520	let count = self.services.globals.next_count();
521	let user_key = (user_id, *count);
522
523	self.db
524		.keychangeid_userid
525		.put_raw(user_key, user_id);
526	self.services
527		.state_cache
528		.rooms_joined(user_id)
529		.filter(|room_id| all_or_is_encrypted(*room_id))
530		.ready_for_each(|room_id| {
531			let room_key = (room_id, *count);
532			self.db
533				.keychangeid_userid
534				.put_raw(room_key, user_id);
535		})
536		.await;
537
538	if !self.services.globals.user_is_local(user_id) {
539		return;
540	}
541
542	// device_list_update EDUs reach remote servers only on a sender flush.
543	let mut servers: Servers = self
544		.services
545		.state_cache
546		.rooms_joined(user_id)
547		.filter(|room_id| all_or_is_encrypted(*room_id))
548		.map(ToOwned::to_owned)
549		.broad_then(async |room_id: OwnedRoomId| {
550			self.services
551				.state_cache
552				.room_servers(&room_id)
553				.ready_filter(|server| !self.services.globals.server_is_ours(server))
554				.map(ToOwned::to_owned)
555				.collect()
556				.await
557		})
558		.flat_map(|servers: Vec<OwnedServerName>| servers.into_iter().stream())
559		.collect()
560		.await;
561
562	servers.sort_unstable();
563	servers.dedup();
564
565	self.services
566		.sending
567		.flush_servers(servers.iter().map(|server| &**server).stream())
568		.await
569		.expect("device key update flush failed");
570}
571
572#[implement(super::Service)]
573pub async fn get_device_keys<'a>(
574	&'a self,
575	user_id: &'a UserId,
576	device_id: &DeviceId,
577) -> Result<Raw<DeviceKeys>> {
578	let key_id = (user_id, device_id);
579	self.db
580		.keyid_key
581		.qry(&key_id)
582		.await
583		.deserialized()
584}
585
586#[implement(super::Service)]
587pub async fn get_key<F>(
588	&self,
589	key_id: &[u8],
590	sender_user: Option<&UserId>,
591	user_id: &UserId,
592	allowed_signatures: &F,
593) -> Result<Raw<CrossSigningKey>>
594where
595	F: Fn(&UserId) -> bool + Send + Sync,
596{
597	let key: serde_json::Value = self
598		.db
599		.keyid_key
600		.get(key_id)
601		.await
602		.deserialized()?;
603
604	let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
605	let raw_value = serde_json::value::to_raw_value(&cleaned)?;
606
607	Ok(Raw::from_json(raw_value))
608}
609
610#[implement(super::Service)]
611pub async fn get_master_key<F>(
612	&self,
613	sender_user: Option<&UserId>,
614	user_id: &UserId,
615	allowed_signatures: &F,
616) -> Result<Raw<CrossSigningKey>>
617where
618	F: Fn(&UserId) -> bool + Send + Sync,
619{
620	let key_id = self.db.userid_masterkeyid.get(user_id).await?;
621
622	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
623		.await
624}
625
626#[implement(super::Service)]
627pub async fn get_self_signing_key<F>(
628	&self,
629	sender_user: Option<&UserId>,
630	user_id: &UserId,
631	allowed_signatures: &F,
632) -> Result<Raw<CrossSigningKey>>
633where
634	F: Fn(&UserId) -> bool + Send + Sync,
635{
636	let key_id = self
637		.db
638		.userid_selfsigningkeyid
639		.get(user_id)
640		.await?;
641
642	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
643		.await
644}
645
646#[implement(super::Service)]
647pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
648	self.db
649		.userid_usersigningkeyid
650		.get(user_id)
651		.and_then(|key_id| self.db.keyid_key.get(&*key_id))
652		.await
653		.deserialized()
654}
655
656pub fn parse_master_key(
657	user_id: &UserId,
658	master_key: &Raw<CrossSigningKey>,
659) -> Result<(Vec<u8>, CrossSigningKey)> {
660	let mut prefix = user_id.as_bytes().to_vec();
661	prefix.push(0xFF);
662
663	let master_key = master_key
664		.deserialize()
665		.map_err(|_| err!(Request(InvalidParam("Invalid master key"))))?;
666
667	let mut master_key_ids = master_key.keys.values();
668	let master_key_id = master_key_ids
669		.next()
670		.ok_or(err!(Request(InvalidParam("Master key contained no key."))))?;
671
672	if master_key_ids.next().is_some() {
673		return Err!(Request(InvalidParam("Master key contained more than one key.")));
674	}
675
676	let mut master_key_key = prefix.clone();
677	master_key_key.extend_from_slice(master_key_id.as_bytes());
678
679	Ok((master_key_key, master_key))
680}
681
682pub(super) fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
683	let mut user_signing_key_ids = user_signing_key
684		.deserialize()
685		.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
686		.keys
687		.into_values();
688
689	let user_signing_key_id = user_signing_key_ids
690		.next()
691		.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
692
693	if user_signing_key_ids.next().is_some() {
694		return Err!(Request(InvalidParam("User signing key contained more than one key.")));
695	}
696
697	Ok(user_signing_key_id)
698}
699
700/// Ensure that a user only sees signatures from themselves and the target user
701fn clean_signatures<F>(
702	mut cross_signing_key: serde_json::Value,
703	sender_user: Option<&UserId>,
704	user_id: &UserId,
705	allowed_signatures: &F,
706) -> Result<serde_json::Value>
707where
708	F: Fn(&UserId) -> bool + Send + Sync,
709{
710	if let Some(signatures) = cross_signing_key
711		.get_mut("signatures")
712		.and_then(|v| v.as_object_mut())
713	{
714		// Don't allocate for the full size of the current signatures, but require
715		// at most one resize if nothing is dropped
716		let new_capacity = signatures.len() / 2;
717		for (user, signature) in
718			mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
719		{
720			let sid = <&UserId>::try_from(user.as_str())
721				.map_err(|e| err!(Database("Invalid user ID in database: {e}")))?;
722
723			if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
724				signatures.insert(user, signature);
725			}
726		}
727	}
728
729	Ok(cross_signing_key)
730}