Skip to main content

tuwunel_service/users/
keys.rs

1use std::{collections::BTreeMap, mem};
2
3use futures::{Stream, StreamExt, TryFutureExt, pin_mut};
4use ruma::{
5	DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId,
6	OwnedOneTimeKeyId, RoomId, UInt, UserId,
7	encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
8	serde::Raw,
9};
10use serde::{Deserialize, Serialize};
11use tuwunel_core::{
12	Err, Result, debug_error, err, implement,
13	utils::{BoolExt, ReadyExt, stream::TryIgnore},
14};
15use tuwunel_database::{Deserialized, Ignore, Interfix, Json};
16
/// MSC2732: row stored under `(user, device, algorithm)` in
/// `userdeviceidalgorithm_fallback`. Fallback keys are not deleted on
/// claim; the row is rewritten with `used = true`.
///
/// Rows are written through `Json`, so these serde field names are part of
/// the stored format.
#[derive(Debug, Deserialize, Serialize)]
struct FallbackEntry {
	/// Id (algorithm and name) of the fallback key as uploaded.
	key_id: OwnedOneTimeKeyId,
	/// The uploaded key object, kept verbatim.
	key: Raw<OneTimeKey>,
	/// Set to `true` once `take_fallback_key` has handed this key out.
	used: bool,
}
26
/// Row-key shape of `onetimekeyid4225_otk`: per-device pool keyed by
/// upload-order count for MSC4225 ordering.
///
/// Fields are `(user, device, count, key id)`. `add_one_time_key` takes the
/// count from `globals.next_count()` at upload time, so ascending count
/// corresponds to upload order.
type OtkRowKey<'a> = (&'a UserId, &'a DeviceId, u64, &'a OneTimeKeyId);
30
31#[implement(super::Service)]
32pub async fn add_one_time_keys<'a, Keys>(
33	&self,
34	user_id: &UserId,
35	device_id: &DeviceId,
36	keys: Keys,
37) -> Result
38where
39	Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
40{
41	for (id, key) in keys {
42		self.add_one_time_key(user_id, device_id, id, key)
43			.await
44			.ok();
45	}
46
47	Ok(())
48}
49
50#[implement(super::Service)]
51pub async fn add_one_time_key(
52	&self,
53	user_id: &UserId,
54	device_id: &DeviceId,
55	one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
56	one_time_key_value: &Raw<OneTimeKey>,
57) -> Result {
58	if !self.device_exists(user_id, device_id).await {
59		return Err!(Database(error!(
60			?user_id,
61			?device_id,
62			"User does not exist or device has no metadata."
63		)));
64	}
65
66	if let Err(e) = one_time_key_value
67		.deserialize()
68		.map_err(Into::into)
69	{
70		debug_error!(
71			?one_time_key_key,
72			?one_time_key_value,
73			"Invalid one time key JSON submitted by client, skipping: {e}"
74		);
75
76		return Err(e);
77	}
78
79	// Racy dedup: two concurrent uploads of the same id can both pass this
80	// check and produce duplicate rows that persist until aged out by prune.
81	let prefix = (user_id, device_id, Interfix);
82	let already_present = self
83		.db
84		.onetimekeyid4225_otk
85		.keys_prefix(&prefix)
86		.ignore_err()
87		.ready_any(|(.., id): OtkRowKey<'_>| id == one_time_key_key)
88		.await;
89
90	if already_present {
91		return Ok(());
92	}
93
94	let count = self.services.globals.next_count();
95
96	// MSC4225: RocksDB iterates the (user, device) prefix in count_be ascending
97	// order, so /keys/claim issues one-time keys in the order they were uploaded.
98	self.db.onetimekeyid4225_otk.put(
99		(user_id, device_id, *count, one_time_key_key.as_str()),
100		Json(one_time_key_value),
101	);
102
103	self.db
104		.userid_lastonetimekeyupdate
105		.raw_put(user_id, *count);
106
107	Ok(())
108}
109
110#[implement(super::Service)]
111pub async fn add_fallback_keys<'a, Keys>(
112	&self,
113	user_id: &UserId,
114	device_id: &DeviceId,
115	keys: Keys,
116) -> Result
117where
118	Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
119{
120	for (id, key) in keys {
121		self.add_fallback_key(user_id, device_id, id, key)
122			.await
123			.ok();
124	}
125
126	Ok(())
127}
128
129#[implement(super::Service)]
130pub async fn add_fallback_key(
131	&self,
132	user_id: &UserId,
133	device_id: &DeviceId,
134	one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
135	one_time_key_value: &Raw<OneTimeKey>,
136) -> Result {
137	if !self.device_exists(user_id, device_id).await {
138		return Err!(Database(error!(
139			?user_id,
140			?device_id,
141			"User does not exist or device has no metadata."
142		)));
143	}
144
145	if let Err(e) = one_time_key_value
146		.deserialize()
147		.map_err(Into::into)
148	{
149		debug_error!(
150			?one_time_key_key,
151			?one_time_key_value,
152			"Invalid fallback key JSON submitted by client, skipping: {e}"
153		);
154
155		return Err(e);
156	}
157
158	let entry = FallbackEntry {
159		key_id: one_time_key_key.to_owned(),
160		key: one_time_key_value.clone(),
161		used: false,
162	};
163
164	let key = (user_id, device_id, one_time_key_key.algorithm());
165	self.db
166		.userdeviceidalgorithm_fallback
167		.put(key, Json(&entry));
168
169	let count = self.services.globals.next_count();
170	self.db
171		.userid_lastonetimekeyupdate
172		.raw_put(user_id, *count);
173
174	Ok(())
175}
176
177#[implement(super::Service)]
178pub async fn take_fallback_key(
179	&self,
180	user_id: &UserId,
181	device_id: &DeviceId,
182	algorithm: &OneTimeKeyAlgorithm,
183) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
184	let key = (user_id, device_id, algorithm);
185	let entry: FallbackEntry = self
186		.db
187		.userdeviceidalgorithm_fallback
188		.qry(&key)
189		.await
190		.deserialized::<Json<_>>()
191		.map(|Json(entry)| entry)
192		.map_err(|_| err!(Request(NotFound("No fallback key found"))))?;
193
194	let updated = FallbackEntry { used: true, ..entry };
195	self.db
196		.userdeviceidalgorithm_fallback
197		.put(key, Json(&updated));
198
199	Ok((updated.key_id, updated.key))
200}
201
202#[implement(super::Service)]
203pub fn unused_fallback_key_algorithms<'a>(
204	&'a self,
205	user_id: &'a UserId,
206	device_id: &'a DeviceId,
207) -> impl Stream<Item = OneTimeKeyAlgorithm> + Send + 'a {
208	type KeyVal = ((Ignore, Ignore, OneTimeKeyAlgorithm), Json<FallbackEntry>);
209
210	let prefix = (user_id, device_id);
211	self.db
212		.userdeviceidalgorithm_fallback
213		.stream_prefix(&prefix)
214		.ignore_err()
215		.ready_filter_map(|((_, _, algorithm), Json(entry)): KeyVal| {
216			entry.used.is_false().then_some(algorithm)
217		})
218}
219
220#[implement(super::Service)]
221pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
222	self.db
223		.userid_lastonetimekeyupdate
224		.get(user_id)
225		.await
226		.deserialized()
227		.unwrap_or(0)
228}
229
230#[implement(super::Service)]
231pub async fn take_one_time_key(
232	&self,
233	user_id: &UserId,
234	device_id: &DeviceId,
235	key_algorithm: &OneTimeKeyAlgorithm,
236) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
237	let update_count = self.services.globals.next_count();
238	self.db
239		.userid_lastonetimekeyupdate
240		.insert(user_id, update_count.to_be_bytes());
241
242	let prefix = (user_id, device_id, Interfix);
243	let one_time_keys = self
244		.db
245		.onetimekeyid4225_otk
246		.stream_prefix(&prefix)
247		.ignore_err()
248		.ready_filter(|(row, _): &(OtkRowKey<'_>, &[u8])| row.3.algorithm() == *key_algorithm);
249
250	pin_mut!(one_time_keys);
251	let ((user_id, device_id, count, id), val) = one_time_keys
252		.next()
253		.await
254		.ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))?;
255
256	self.db
257		.onetimekeyid4225_otk
258		.del((user_id, device_id, count, id));
259
260	Ok((id.into(), serde_json::from_slice(val)?))
261}
262
263#[implement(super::Service)]
264pub async fn count_one_time_keys(
265	&self,
266	user_id: &UserId,
267	device_id: &DeviceId,
268) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
269	let prefix = (user_id, device_id, Interfix);
270	let algorithm_counts: BTreeMap<OneTimeKeyAlgorithm, UInt> = self
271		.db
272		.onetimekeyid4225_otk
273		.keys_prefix(&prefix)
274		.ignore_err()
275		.ready_fold(BTreeMap::new(), |mut acc, (.., id): OtkRowKey<'_>| {
276			let count: &mut UInt = acc.entry(id.algorithm()).or_default();
277			*count = count.saturating_add(1_u32.into());
278			acc
279		})
280		.await;
281
282	let total = algorithm_counts
283		.values()
284		.copied()
285		.map(TryInto::try_into)
286		.filter_map(Result::ok)
287		.fold(0_usize, usize::saturating_add);
288
289	let limit = self.services.config.one_time_key_limit;
290	if let Some(excess) = total.checked_sub(limit).filter(|&n| n > 0) {
291		self.prune_one_time_keys(user_id, device_id, excess)
292			.await;
293	}
294
295	algorithm_counts
296}
297
298/// MSC4225: drop the `excess` oldest rows for this `(user, device)`. Forward
299/// iteration over the prefix runs in count_be ascending order, so
300/// `take(excess)` yields the earliest-uploaded rows.
301#[implement(super::Service)]
302pub async fn prune_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId, excess: usize) {
303	let prefix = (user_id, device_id, Interfix);
304	self.db
305		.onetimekeyid4225_otk
306		.keys_prefix(&prefix)
307		.ignore_err()
308		.take(excess)
309		.ready_for_each(|row: OtkRowKey<'_>| {
310			self.db.onetimekeyid4225_otk.del(row);
311		})
312		.await;
313}
314
315#[implement(super::Service)]
316pub async fn add_device_keys(
317	&self,
318	user_id: &UserId,
319	device_id: &DeviceId,
320	device_keys: &Raw<DeviceKeys>,
321) {
322	let key = (user_id, device_id);
323
324	self.db.keyid_key.put(key, Json(device_keys));
325	self.mark_device_key_update(user_id).await;
326}
327
328#[implement(super::Service)]
329pub async fn add_cross_signing_keys(
330	&self,
331	user_id: &UserId,
332	master_key: &Option<Raw<CrossSigningKey>>,
333	self_signing_key: &Option<Raw<CrossSigningKey>>,
334	user_signing_key: &Option<Raw<CrossSigningKey>>,
335	notify: bool,
336) -> Result {
337	// TODO: Check signatures
338	let mut prefix = user_id.as_bytes().to_vec();
339	prefix.push(0xFF);
340
341	if let Some(master_key) = master_key {
342		let (master_key_key, _) = parse_master_key(user_id, master_key)?;
343
344		self.db
345			.keyid_key
346			.insert(&master_key_key, master_key.json().get().as_bytes());
347
348		self.db
349			.userid_masterkeyid
350			.insert(user_id.as_bytes(), &master_key_key);
351	}
352
353	// Self-signing key
354	if let Some(self_signing_key) = self_signing_key {
355		let mut self_signing_key_ids = self_signing_key
356			.deserialize()
357			.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
358			.keys
359			.into_values();
360
361		let self_signing_key_id = self_signing_key_ids
362			.next()
363			.ok_or_else(|| err!(Request(InvalidParam("Self signing key contained no key."))))?;
364
365		if self_signing_key_ids.next().is_some() {
366			return Err!(Request(InvalidParam("Self signing key contained more than one key.")));
367		}
368
369		let mut self_signing_key_key = prefix.clone();
370		self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
371
372		self.db
373			.keyid_key
374			.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
375
376		self.db
377			.userid_selfsigningkeyid
378			.insert(user_id.as_bytes(), &self_signing_key_key);
379	}
380
381	// User-signing key
382	if let Some(user_signing_key) = user_signing_key {
383		let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
384
385		let user_signing_key_key = (user_id, &user_signing_key_id);
386		self.db
387			.keyid_key
388			.put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
389
390		self.db
391			.userid_usersigningkeyid
392			.raw_put(user_id, user_signing_key_key);
393	}
394
395	if notify {
396		self.mark_device_key_update(user_id).await;
397	}
398
399	Ok(())
400}
401
402#[implement(super::Service)]
403pub async fn sign_key(
404	&self,
405	target_id: &UserId,
406	key_id: &str,
407	signature: (String, String),
408	sender_id: &UserId,
409) -> Result {
410	let key = (target_id, key_id);
411
412	let mut cross_signing_key: serde_json::Value = self
413		.db
414		.keyid_key
415		.qry(&key)
416		.await
417		.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
418		.deserialized()
419		.map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
420
421	let signatures = cross_signing_key
422		.get_mut("signatures")
423		.ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))?
424		.as_object_mut()
425		.ok_or_else(|| {
426			err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
427		})?
428		.entry(sender_id.to_string())
429		.or_insert_with(|| serde_json::Map::new().into());
430
431	signatures
432		.as_object_mut()
433		.ok_or_else(|| {
434			err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
435		})?
436		.insert(signature.0, signature.1.into());
437
438	let key = (target_id, key_id);
439	self.db
440		.keyid_key
441		.put(key, Json(cross_signing_key));
442
443	self.mark_device_key_update(target_id).await;
444
445	Ok(())
446}
447
448#[implement(super::Service)]
449#[inline]
450pub fn keys_changed<'a>(
451	&'a self,
452	user_id: &'a UserId,
453	from: u64,
454	to: Option<u64>,
455) -> impl Stream<Item = &UserId> + Send + 'a {
456	self.keys_changed_user_or_room(user_id.as_str(), from, to)
457		.map(|(user_id, ..)| user_id)
458}
459
460#[implement(super::Service)]
461#[inline]
462pub fn room_keys_changed<'a>(
463	&'a self,
464	room_id: &'a RoomId,
465	from: u64,
466	to: Option<u64>,
467) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
468	self.keys_changed_user_or_room(room_id.as_str(), from, to)
469}
470
471#[implement(super::Service)]
472fn keys_changed_user_or_room<'a>(
473	&'a self,
474	user_or_room_id: &'a str,
475	from: u64,
476	to: Option<u64>,
477) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
478	type KeyVal<'a> = ((&'a str, u64), &'a UserId);
479
480	let to = to.unwrap_or(u64::MAX);
481	let start = (user_or_room_id, from.saturating_add(1));
482	self.db
483		.keychangeid_userid
484		.stream_from(&start)
485		.ignore_err()
486		.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
487			*prefix == user_or_room_id && *count <= to
488		})
489		.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
490}
491
492#[implement(super::Service)]
493pub async fn mark_device_key_update(&self, user_id: &UserId) {
494	let update_all_rooms = !self
495		.services
496		.config
497		.device_key_update_encrypted_rooms_only;
498
499	let all_or_is_encrypted = async |room_id: &RoomId| {
500		update_all_rooms
501			|| self
502				.services
503				.state_accessor
504				.is_encrypted_room(room_id)
505				.await
506	};
507
508	let count = self.services.globals.next_count();
509	let user_key = (user_id, *count);
510
511	self.db
512		.keychangeid_userid
513		.put_raw(user_key, user_id);
514	self.services
515		.state_cache
516		.rooms_joined(user_id)
517		.filter(|room_id| all_or_is_encrypted(*room_id))
518		.ready_for_each(|room_id| {
519			let room_key = (room_id, *count);
520			self.db
521				.keychangeid_userid
522				.put_raw(room_key, user_id);
523		})
524		.await;
525}
526
527#[implement(super::Service)]
528pub async fn get_device_keys<'a>(
529	&'a self,
530	user_id: &'a UserId,
531	device_id: &DeviceId,
532) -> Result<Raw<DeviceKeys>> {
533	let key_id = (user_id, device_id);
534	self.db
535		.keyid_key
536		.qry(&key_id)
537		.await
538		.deserialized()
539}
540
541#[implement(super::Service)]
542pub async fn get_key<F>(
543	&self,
544	key_id: &[u8],
545	sender_user: Option<&UserId>,
546	user_id: &UserId,
547	allowed_signatures: &F,
548) -> Result<Raw<CrossSigningKey>>
549where
550	F: Fn(&UserId) -> bool + Send + Sync,
551{
552	let key: serde_json::Value = self
553		.db
554		.keyid_key
555		.get(key_id)
556		.await
557		.deserialized()?;
558
559	let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
560	let raw_value = serde_json::value::to_raw_value(&cleaned)?;
561
562	Ok(Raw::from_json(raw_value))
563}
564
565#[implement(super::Service)]
566pub async fn get_master_key<F>(
567	&self,
568	sender_user: Option<&UserId>,
569	user_id: &UserId,
570	allowed_signatures: &F,
571) -> Result<Raw<CrossSigningKey>>
572where
573	F: Fn(&UserId) -> bool + Send + Sync,
574{
575	let key_id = self.db.userid_masterkeyid.get(user_id).await?;
576
577	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
578		.await
579}
580
581#[implement(super::Service)]
582pub async fn get_self_signing_key<F>(
583	&self,
584	sender_user: Option<&UserId>,
585	user_id: &UserId,
586	allowed_signatures: &F,
587) -> Result<Raw<CrossSigningKey>>
588where
589	F: Fn(&UserId) -> bool + Send + Sync,
590{
591	let key_id = self
592		.db
593		.userid_selfsigningkeyid
594		.get(user_id)
595		.await?;
596
597	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
598		.await
599}
600
601#[implement(super::Service)]
602pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
603	self.db
604		.userid_usersigningkeyid
605		.get(user_id)
606		.and_then(|key_id| self.db.keyid_key.get(&*key_id))
607		.await
608		.deserialized()
609}
610
611pub fn parse_master_key(
612	user_id: &UserId,
613	master_key: &Raw<CrossSigningKey>,
614) -> Result<(Vec<u8>, CrossSigningKey)> {
615	let mut prefix = user_id.as_bytes().to_vec();
616	prefix.push(0xFF);
617
618	let master_key = master_key
619		.deserialize()
620		.map_err(|_| err!(Request(InvalidParam("Invalid master key"))))?;
621
622	let mut master_key_ids = master_key.keys.values();
623	let master_key_id = master_key_ids
624		.next()
625		.ok_or(err!(Request(InvalidParam("Master key contained no key."))))?;
626
627	if master_key_ids.next().is_some() {
628		return Err!(Request(InvalidParam("Master key contained more than one key.")));
629	}
630
631	let mut master_key_key = prefix.clone();
632	master_key_key.extend_from_slice(master_key_id.as_bytes());
633
634	Ok((master_key_key, master_key))
635}
636
637pub(super) fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
638	let mut user_signing_key_ids = user_signing_key
639		.deserialize()
640		.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
641		.keys
642		.into_values();
643
644	let user_signing_key_id = user_signing_key_ids
645		.next()
646		.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
647
648	if user_signing_key_ids.next().is_some() {
649		return Err!(Request(InvalidParam("User signing key contained more than one key.")));
650	}
651
652	Ok(user_signing_key_id)
653}
654
655/// Ensure that a user only sees signatures from themselves and the target user
656fn clean_signatures<F>(
657	mut cross_signing_key: serde_json::Value,
658	sender_user: Option<&UserId>,
659	user_id: &UserId,
660	allowed_signatures: &F,
661) -> Result<serde_json::Value>
662where
663	F: Fn(&UserId) -> bool + Send + Sync,
664{
665	if let Some(signatures) = cross_signing_key
666		.get_mut("signatures")
667		.and_then(|v| v.as_object_mut())
668	{
669		// Don't allocate for the full size of the current signatures, but require
670		// at most one resize if nothing is dropped
671		let new_capacity = signatures.len() / 2;
672		for (user, signature) in
673			mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
674		{
675			let sid = <&UserId>::try_from(user.as_str())
676				.map_err(|e| err!(Database("Invalid user ID in database: {e}")))?;
677
678			if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
679				signatures.insert(user, signature);
680			}
681		}
682	}
683
684	Ok(cross_signing_key)
685}