1use std::{collections::BTreeMap, mem};
2
3use futures::{Stream, StreamExt, TryFutureExt, pin_mut};
4use ruma::{
5 DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId,
6 OwnedOneTimeKeyId, OwnedRoomId, OwnedServerName, RoomId, UInt, UserId,
7 encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
8 serde::Raw,
9};
10use serde::{Deserialize, Serialize};
11use tuwunel_core::{
12 Err, Result, debug_error, err, implement,
13 smallvec::SmallVec,
14 utils::{
15 BoolExt, IterStream, ReadyExt,
16 stream::{BroadbandExt, TryIgnore},
17 },
18};
19use tuwunel_database::{Deserialized, Ignore, Interfix, Json};
20
/// List of server names gathered when flushing device key updates; the
/// backing `SmallVec` keeps a single entry inline before spilling to the heap.
type Servers = SmallVec<[OwnedServerName; 1]>;
22
/// Value stored (as JSON) in `userdeviceidalgorithm_fallback` for one
/// fallback key of a device.
#[derive(Debug, Deserialize, Serialize)]
struct FallbackEntry {
	/// Identifier the key was uploaded under.
	key_id: OwnedOneTimeKeyId,
	/// The raw key JSON exactly as submitted.
	key: Raw<OneTimeKey>,
	/// Set once the key has been handed out by `take_fallback_key`.
	used: bool,
}
32
/// Row key of the one-time-key column as written by `add_one_time_key`:
/// user, device, counter value at insertion, then the key ID.
type OtkRowKey<'a> = (&'a UserId, &'a DeviceId, u64, &'a OneTimeKeyId);
36
37#[implement(super::Service)]
38pub async fn add_one_time_keys<'a, Keys>(
39 &self,
40 user_id: &UserId,
41 device_id: &DeviceId,
42 keys: Keys,
43) -> Result
44where
45 Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
46{
47 for (id, key) in keys {
48 self.add_one_time_key(user_id, device_id, id, key)
49 .await
50 .ok();
51 }
52
53 Ok(())
54}
55
56#[implement(super::Service)]
57pub async fn add_one_time_key(
58 &self,
59 user_id: &UserId,
60 device_id: &DeviceId,
61 one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
62 one_time_key_value: &Raw<OneTimeKey>,
63) -> Result {
64 let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
65 return Err!(Database("one-time-key column unavailable"));
66 };
67
68 if !self.device_exists(user_id, device_id).await {
69 return Err!(Database(error!(
70 ?user_id,
71 ?device_id,
72 "User does not exist or device has no metadata."
73 )));
74 }
75
76 if let Err(e) = one_time_key_value
77 .deserialize()
78 .map_err(Into::into)
79 {
80 debug_error!(
81 ?one_time_key_key,
82 ?one_time_key_value,
83 "Invalid one time key JSON submitted by client, skipping: {e}"
84 );
85
86 return Err(e);
87 }
88
89 let prefix = (user_id, device_id, Interfix);
92 let already_present = otk
93 .keys_prefix(&prefix)
94 .ignore_err()
95 .ready_any(|(.., id): OtkRowKey<'_>| id == one_time_key_key)
96 .await;
97
98 if already_present {
99 return Ok(());
100 }
101
102 let count = self.services.globals.next_count();
103
104 otk.put(
107 (user_id, device_id, *count, one_time_key_key.as_str()),
108 Json(one_time_key_value),
109 );
110
111 self.db
112 .userid_lastonetimekeyupdate
113 .raw_put(user_id, *count);
114
115 Ok(())
116}
117
118#[implement(super::Service)]
119pub async fn add_fallback_keys<'a, Keys>(
120 &self,
121 user_id: &UserId,
122 device_id: &DeviceId,
123 keys: Keys,
124) -> Result
125where
126 Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
127{
128 for (id, key) in keys {
129 self.add_fallback_key(user_id, device_id, id, key)
130 .await
131 .ok();
132 }
133
134 Ok(())
135}
136
137#[implement(super::Service)]
138pub async fn add_fallback_key(
139 &self,
140 user_id: &UserId,
141 device_id: &DeviceId,
142 one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
143 one_time_key_value: &Raw<OneTimeKey>,
144) -> Result {
145 if !self.device_exists(user_id, device_id).await {
146 return Err!(Database(error!(
147 ?user_id,
148 ?device_id,
149 "User does not exist or device has no metadata."
150 )));
151 }
152
153 if let Err(e) = one_time_key_value
154 .deserialize()
155 .map_err(Into::into)
156 {
157 debug_error!(
158 ?one_time_key_key,
159 ?one_time_key_value,
160 "Invalid fallback key JSON submitted by client, skipping: {e}"
161 );
162
163 return Err(e);
164 }
165
166 let entry = FallbackEntry {
167 key_id: one_time_key_key.to_owned(),
168 key: one_time_key_value.clone(),
169 used: false,
170 };
171
172 let key = (user_id, device_id, one_time_key_key.algorithm());
173 self.db
174 .userdeviceidalgorithm_fallback
175 .put(key, Json(&entry));
176
177 let count = self.services.globals.next_count();
178 self.db
179 .userid_lastonetimekeyupdate
180 .raw_put(user_id, *count);
181
182 Ok(())
183}
184
185#[implement(super::Service)]
186pub async fn take_fallback_key(
187 &self,
188 user_id: &UserId,
189 device_id: &DeviceId,
190 algorithm: &OneTimeKeyAlgorithm,
191) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
192 let key = (user_id, device_id, algorithm);
193 let entry: FallbackEntry = self
194 .db
195 .userdeviceidalgorithm_fallback
196 .qry(&key)
197 .await
198 .deserialized::<Json<_>>()
199 .map(|Json(entry)| entry)
200 .map_err(|_| err!(Request(NotFound("No fallback key found"))))?;
201
202 let updated = FallbackEntry { used: true, ..entry };
203 self.db
204 .userdeviceidalgorithm_fallback
205 .put(key, Json(&updated));
206
207 Ok((updated.key_id, updated.key))
208}
209
210#[implement(super::Service)]
211pub fn unused_fallback_key_algorithms<'a>(
212 &'a self,
213 user_id: &'a UserId,
214 device_id: &'a DeviceId,
215) -> impl Stream<Item = OneTimeKeyAlgorithm> + Send + 'a {
216 type KeyVal = ((Ignore, Ignore, OneTimeKeyAlgorithm), Json<FallbackEntry>);
217
218 let prefix = (user_id, device_id);
219 self.db
220 .userdeviceidalgorithm_fallback
221 .stream_prefix(&prefix)
222 .ignore_err()
223 .ready_filter_map(|((_, _, algorithm), Json(entry)): KeyVal| {
224 entry.used.is_false().then_some(algorithm)
225 })
226}
227
228#[implement(super::Service)]
229pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
230 self.db
231 .userid_lastonetimekeyupdate
232 .get(user_id)
233 .await
234 .deserialized()
235 .unwrap_or(0)
236}
237
238#[implement(super::Service)]
239pub async fn take_one_time_key(
240 &self,
241 user_id: &UserId,
242 device_id: &DeviceId,
243 key_algorithm: &OneTimeKeyAlgorithm,
244) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
245 let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
246 return Err!(Request(NotFound("No one-time-key found")));
247 };
248
249 let update_count = self.services.globals.next_count();
250 self.db
251 .userid_lastonetimekeyupdate
252 .insert(user_id, update_count.to_be_bytes());
253
254 let prefix = (user_id, device_id, Interfix);
255 let one_time_keys = otk
256 .stream_prefix(&prefix)
257 .ignore_err()
258 .ready_filter(|(row, _): &(OtkRowKey<'_>, &[u8])| row.3.algorithm() == *key_algorithm);
259
260 pin_mut!(one_time_keys);
261 let ((user_id, device_id, count, id), val) = one_time_keys
262 .next()
263 .await
264 .ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))?;
265
266 otk.del((user_id, device_id, count, id));
267
268 Ok((id.into(), serde_json::from_slice(val)?))
269}
270
271#[implement(super::Service)]
272pub async fn count_one_time_keys(
273 &self,
274 user_id: &UserId,
275 device_id: &DeviceId,
276) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
277 let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
278 return BTreeMap::new();
279 };
280
281 let prefix = (user_id, device_id, Interfix);
282 let algorithm_counts: BTreeMap<OneTimeKeyAlgorithm, UInt> = otk
283 .keys_prefix(&prefix)
284 .ignore_err()
285 .ready_fold(BTreeMap::new(), |mut acc, (.., id): OtkRowKey<'_>| {
286 let count: &mut UInt = acc.entry(id.algorithm()).or_default();
287 *count = count.saturating_add(1_u32.into());
288 acc
289 })
290 .await;
291
292 let total = algorithm_counts
293 .values()
294 .copied()
295 .map(TryInto::try_into)
296 .filter_map(Result::ok)
297 .fold(0_usize, usize::saturating_add);
298
299 let limit = self.services.config.one_time_key_limit;
300 if let Some(excess) = total.checked_sub(limit).filter(|&n| n > 0) {
301 self.prune_one_time_keys(user_id, device_id, excess)
302 .await;
303 }
304
305 algorithm_counts
306}
307
308#[implement(super::Service)]
312pub async fn prune_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId, excess: usize) {
313 let Some(otk) = self.db.onetimekeyid4225_otk.as_ref() else {
314 return;
315 };
316
317 let prefix = (user_id, device_id, Interfix);
318 otk.keys_prefix(&prefix)
319 .ignore_err()
320 .take(excess)
321 .ready_for_each(|row: OtkRowKey<'_>| {
322 otk.del(row);
323 })
324 .await;
325}
326
327#[implement(super::Service)]
328pub async fn add_device_keys(
329 &self,
330 user_id: &UserId,
331 device_id: &DeviceId,
332 device_keys: &Raw<DeviceKeys>,
333) {
334 let key = (user_id, device_id);
335
336 self.db.keyid_key.put(key, Json(device_keys));
337 self.mark_device_key_update(user_id).await;
338}
339
340#[implement(super::Service)]
341pub async fn add_cross_signing_keys(
342 &self,
343 user_id: &UserId,
344 master_key: &Option<Raw<CrossSigningKey>>,
345 self_signing_key: &Option<Raw<CrossSigningKey>>,
346 user_signing_key: &Option<Raw<CrossSigningKey>>,
347 notify: bool,
348) -> Result {
349 let mut prefix = user_id.as_bytes().to_vec();
351 prefix.push(0xFF);
352
353 if let Some(master_key) = master_key {
354 let (master_key_key, _) = parse_master_key(user_id, master_key)?;
355
356 self.db
357 .keyid_key
358 .insert(&master_key_key, master_key.json().get().as_bytes());
359
360 self.db
361 .userid_masterkeyid
362 .insert(user_id.as_bytes(), &master_key_key);
363 }
364
365 if let Some(self_signing_key) = self_signing_key {
367 let mut self_signing_key_ids = self_signing_key
368 .deserialize()
369 .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
370 .keys
371 .into_values();
372
373 let self_signing_key_id = self_signing_key_ids
374 .next()
375 .ok_or_else(|| err!(Request(InvalidParam("Self signing key contained no key."))))?;
376
377 if self_signing_key_ids.next().is_some() {
378 return Err!(Request(InvalidParam("Self signing key contained more than one key.")));
379 }
380
381 let mut self_signing_key_key = prefix.clone();
382 self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
383
384 self.db
385 .keyid_key
386 .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
387
388 self.db
389 .userid_selfsigningkeyid
390 .insert(user_id.as_bytes(), &self_signing_key_key);
391 }
392
393 if let Some(user_signing_key) = user_signing_key {
395 let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
396
397 let user_signing_key_key = (user_id, &user_signing_key_id);
398 self.db
399 .keyid_key
400 .put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
401
402 self.db
403 .userid_usersigningkeyid
404 .raw_put(user_id, user_signing_key_key);
405 }
406
407 if notify {
408 self.mark_device_key_update(user_id).await;
409 }
410
411 Ok(())
412}
413
414#[implement(super::Service)]
415pub async fn sign_key(
416 &self,
417 target_id: &UserId,
418 key_id: &str,
419 signature: (String, String),
420 sender_id: &UserId,
421) -> Result {
422 let key = (target_id, key_id);
423
424 let mut cross_signing_key: serde_json::Value = self
425 .db
426 .keyid_key
427 .qry(&key)
428 .await
429 .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
430 .deserialized()
431 .map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
432
433 let signatures = cross_signing_key
434 .get_mut("signatures")
435 .ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))?
436 .as_object_mut()
437 .ok_or_else(|| {
438 err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
439 })?
440 .entry(sender_id.to_string())
441 .or_insert_with(|| serde_json::Map::new().into());
442
443 signatures
444 .as_object_mut()
445 .ok_or_else(|| {
446 err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
447 })?
448 .insert(signature.0, signature.1.into());
449
450 let key = (target_id, key_id);
451 self.db
452 .keyid_key
453 .put(key, Json(cross_signing_key));
454
455 self.mark_device_key_update(target_id).await;
456
457 Ok(())
458}
459
460#[implement(super::Service)]
461#[inline]
462pub fn keys_changed<'a>(
463 &'a self,
464 user_id: &'a UserId,
465 from: u64,
466 to: Option<u64>,
467) -> impl Stream<Item = &UserId> + Send + 'a {
468 self.keys_changed_user_or_room(user_id.as_str(), from, to)
469 .map(|(user_id, ..)| user_id)
470}
471
472#[implement(super::Service)]
473#[inline]
474pub fn room_keys_changed<'a>(
475 &'a self,
476 room_id: &'a RoomId,
477 from: u64,
478 to: Option<u64>,
479) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
480 self.keys_changed_user_or_room(room_id.as_str(), from, to)
481}
482
483#[implement(super::Service)]
484fn keys_changed_user_or_room<'a>(
485 &'a self,
486 user_or_room_id: &'a str,
487 from: u64,
488 to: Option<u64>,
489) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
490 type KeyVal<'a> = ((&'a str, u64), &'a UserId);
491
492 let to = to.unwrap_or(u64::MAX);
493 let start = (user_or_room_id, from.saturating_add(1));
494 self.db
495 .keychangeid_userid
496 .stream_from(&start)
497 .ignore_err()
498 .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
499 *prefix == user_or_room_id && *count <= to
500 })
501 .map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
502}
503
504#[implement(super::Service)]
505pub async fn mark_device_key_update(&self, user_id: &UserId) {
506 let update_all_rooms = !self
507 .services
508 .config
509 .device_key_update_encrypted_rooms_only;
510
511 let all_or_is_encrypted = async |room_id: &RoomId| {
512 update_all_rooms
513 || self
514 .services
515 .state_accessor
516 .is_encrypted_room(room_id)
517 .await
518 };
519
520 let count = self.services.globals.next_count();
521 let user_key = (user_id, *count);
522
523 self.db
524 .keychangeid_userid
525 .put_raw(user_key, user_id);
526 self.services
527 .state_cache
528 .rooms_joined(user_id)
529 .filter(|room_id| all_or_is_encrypted(*room_id))
530 .ready_for_each(|room_id| {
531 let room_key = (room_id, *count);
532 self.db
533 .keychangeid_userid
534 .put_raw(room_key, user_id);
535 })
536 .await;
537
538 if !self.services.globals.user_is_local(user_id) {
539 return;
540 }
541
542 let mut servers: Servers = self
544 .services
545 .state_cache
546 .rooms_joined(user_id)
547 .filter(|room_id| all_or_is_encrypted(*room_id))
548 .map(ToOwned::to_owned)
549 .broad_then(async |room_id: OwnedRoomId| {
550 self.services
551 .state_cache
552 .room_servers(&room_id)
553 .ready_filter(|server| !self.services.globals.server_is_ours(server))
554 .map(ToOwned::to_owned)
555 .collect()
556 .await
557 })
558 .flat_map(|servers: Vec<OwnedServerName>| servers.into_iter().stream())
559 .collect()
560 .await;
561
562 servers.sort_unstable();
563 servers.dedup();
564
565 self.services
566 .sending
567 .flush_servers(servers.iter().map(|server| &**server).stream())
568 .await
569 .expect("device key update flush failed");
570}
571
572#[implement(super::Service)]
573pub async fn get_device_keys<'a>(
574 &'a self,
575 user_id: &'a UserId,
576 device_id: &DeviceId,
577) -> Result<Raw<DeviceKeys>> {
578 let key_id = (user_id, device_id);
579 self.db
580 .keyid_key
581 .qry(&key_id)
582 .await
583 .deserialized()
584}
585
586#[implement(super::Service)]
587pub async fn get_key<F>(
588 &self,
589 key_id: &[u8],
590 sender_user: Option<&UserId>,
591 user_id: &UserId,
592 allowed_signatures: &F,
593) -> Result<Raw<CrossSigningKey>>
594where
595 F: Fn(&UserId) -> bool + Send + Sync,
596{
597 let key: serde_json::Value = self
598 .db
599 .keyid_key
600 .get(key_id)
601 .await
602 .deserialized()?;
603
604 let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
605 let raw_value = serde_json::value::to_raw_value(&cleaned)?;
606
607 Ok(Raw::from_json(raw_value))
608}
609
610#[implement(super::Service)]
611pub async fn get_master_key<F>(
612 &self,
613 sender_user: Option<&UserId>,
614 user_id: &UserId,
615 allowed_signatures: &F,
616) -> Result<Raw<CrossSigningKey>>
617where
618 F: Fn(&UserId) -> bool + Send + Sync,
619{
620 let key_id = self.db.userid_masterkeyid.get(user_id).await?;
621
622 self.get_key(&key_id, sender_user, user_id, allowed_signatures)
623 .await
624}
625
626#[implement(super::Service)]
627pub async fn get_self_signing_key<F>(
628 &self,
629 sender_user: Option<&UserId>,
630 user_id: &UserId,
631 allowed_signatures: &F,
632) -> Result<Raw<CrossSigningKey>>
633where
634 F: Fn(&UserId) -> bool + Send + Sync,
635{
636 let key_id = self
637 .db
638 .userid_selfsigningkeyid
639 .get(user_id)
640 .await?;
641
642 self.get_key(&key_id, sender_user, user_id, allowed_signatures)
643 .await
644}
645
646#[implement(super::Service)]
647pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
648 self.db
649 .userid_usersigningkeyid
650 .get(user_id)
651 .and_then(|key_id| self.db.keyid_key.get(&*key_id))
652 .await
653 .deserialized()
654}
655
656pub fn parse_master_key(
657 user_id: &UserId,
658 master_key: &Raw<CrossSigningKey>,
659) -> Result<(Vec<u8>, CrossSigningKey)> {
660 let mut prefix = user_id.as_bytes().to_vec();
661 prefix.push(0xFF);
662
663 let master_key = master_key
664 .deserialize()
665 .map_err(|_| err!(Request(InvalidParam("Invalid master key"))))?;
666
667 let mut master_key_ids = master_key.keys.values();
668 let master_key_id = master_key_ids
669 .next()
670 .ok_or(err!(Request(InvalidParam("Master key contained no key."))))?;
671
672 if master_key_ids.next().is_some() {
673 return Err!(Request(InvalidParam("Master key contained more than one key.")));
674 }
675
676 let mut master_key_key = prefix.clone();
677 master_key_key.extend_from_slice(master_key_id.as_bytes());
678
679 Ok((master_key_key, master_key))
680}
681
682pub(super) fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
683 let mut user_signing_key_ids = user_signing_key
684 .deserialize()
685 .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
686 .keys
687 .into_values();
688
689 let user_signing_key_id = user_signing_key_ids
690 .next()
691 .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
692
693 if user_signing_key_ids.next().is_some() {
694 return Err!(Request(InvalidParam("User signing key contained more than one key.")));
695 }
696
697 Ok(user_signing_key_id)
698}
699
700fn clean_signatures<F>(
702 mut cross_signing_key: serde_json::Value,
703 sender_user: Option<&UserId>,
704 user_id: &UserId,
705 allowed_signatures: &F,
706) -> Result<serde_json::Value>
707where
708 F: Fn(&UserId) -> bool + Send + Sync,
709{
710 if let Some(signatures) = cross_signing_key
711 .get_mut("signatures")
712 .and_then(|v| v.as_object_mut())
713 {
714 let new_capacity = signatures.len() / 2;
717 for (user, signature) in
718 mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
719 {
720 let sid = <&UserId>::try_from(user.as_str())
721 .map_err(|e| err!(Database("Invalid user ID in database: {e}")))?;
722
723 if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
724 signatures.insert(user, signature);
725 }
726 }
727 }
728
729 Ok(cross_signing_key)
730}