1use std::{collections::BTreeMap, mem};
2
3use futures::{Stream, StreamExt, TryFutureExt, pin_mut};
4use ruma::{
5 DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId,
6 OwnedOneTimeKeyId, RoomId, UInt, UserId,
7 encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
8 serde::Raw,
9};
10use serde::{Deserialize, Serialize};
11use tuwunel_core::{
12 Err, Result, debug_error, err, implement,
13 utils::{BoolExt, ReadyExt, stream::TryIgnore},
14};
15use tuwunel_database::{Deserialized, Ignore, Interfix, Json};
16
#[derive(Debug, Deserialize, Serialize)]
/// A device's fallback key for one algorithm.
///
/// Stored JSON-encoded in `userdeviceidalgorithm_fallback` under the row key
/// `(user_id, device_id, algorithm)`; there is at most one entry per algorithm.
struct FallbackEntry {
	/// Key ID (algorithm plus key name) of the uploaded fallback key.
	key_id: OwnedOneTimeKeyId,
	/// The key exactly as it was uploaded by the client.
	key: Raw<OneTimeKey>,
	/// Set to `true` by `take_fallback_key` once the key has been handed out;
	/// `add_fallback_key` writes entries with this set to `false`.
	used: bool,
}
26
/// Row key of `onetimekeyid4225_otk`: `(user_id, device_id, count, key_id)`,
/// where `count` is the global counter value taken when the key was stored.
type OtkRowKey<'a> = (&'a UserId, &'a DeviceId, u64, &'a OneTimeKeyId);
30
31#[implement(super::Service)]
32pub async fn add_one_time_keys<'a, Keys>(
33 &self,
34 user_id: &UserId,
35 device_id: &DeviceId,
36 keys: Keys,
37) -> Result
38where
39 Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
40{
41 for (id, key) in keys {
42 self.add_one_time_key(user_id, device_id, id, key)
43 .await
44 .ok();
45 }
46
47 Ok(())
48}
49
50#[implement(super::Service)]
51pub async fn add_one_time_key(
52 &self,
53 user_id: &UserId,
54 device_id: &DeviceId,
55 one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
56 one_time_key_value: &Raw<OneTimeKey>,
57) -> Result {
58 if !self.device_exists(user_id, device_id).await {
59 return Err!(Database(error!(
60 ?user_id,
61 ?device_id,
62 "User does not exist or device has no metadata."
63 )));
64 }
65
66 if let Err(e) = one_time_key_value
67 .deserialize()
68 .map_err(Into::into)
69 {
70 debug_error!(
71 ?one_time_key_key,
72 ?one_time_key_value,
73 "Invalid one time key JSON submitted by client, skipping: {e}"
74 );
75
76 return Err(e);
77 }
78
79 let prefix = (user_id, device_id, Interfix);
82 let already_present = self
83 .db
84 .onetimekeyid4225_otk
85 .keys_prefix(&prefix)
86 .ignore_err()
87 .ready_any(|(.., id): OtkRowKey<'_>| id == one_time_key_key)
88 .await;
89
90 if already_present {
91 return Ok(());
92 }
93
94 let count = self.services.globals.next_count();
95
96 self.db.onetimekeyid4225_otk.put(
99 (user_id, device_id, *count, one_time_key_key.as_str()),
100 Json(one_time_key_value),
101 );
102
103 self.db
104 .userid_lastonetimekeyupdate
105 .raw_put(user_id, *count);
106
107 Ok(())
108}
109
110#[implement(super::Service)]
111pub async fn add_fallback_keys<'a, Keys>(
112 &self,
113 user_id: &UserId,
114 device_id: &DeviceId,
115 keys: Keys,
116) -> Result
117where
118 Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
119{
120 for (id, key) in keys {
121 self.add_fallback_key(user_id, device_id, id, key)
122 .await
123 .ok();
124 }
125
126 Ok(())
127}
128
129#[implement(super::Service)]
130pub async fn add_fallback_key(
131 &self,
132 user_id: &UserId,
133 device_id: &DeviceId,
134 one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
135 one_time_key_value: &Raw<OneTimeKey>,
136) -> Result {
137 if !self.device_exists(user_id, device_id).await {
138 return Err!(Database(error!(
139 ?user_id,
140 ?device_id,
141 "User does not exist or device has no metadata."
142 )));
143 }
144
145 if let Err(e) = one_time_key_value
146 .deserialize()
147 .map_err(Into::into)
148 {
149 debug_error!(
150 ?one_time_key_key,
151 ?one_time_key_value,
152 "Invalid fallback key JSON submitted by client, skipping: {e}"
153 );
154
155 return Err(e);
156 }
157
158 let entry = FallbackEntry {
159 key_id: one_time_key_key.to_owned(),
160 key: one_time_key_value.clone(),
161 used: false,
162 };
163
164 let key = (user_id, device_id, one_time_key_key.algorithm());
165 self.db
166 .userdeviceidalgorithm_fallback
167 .put(key, Json(&entry));
168
169 let count = self.services.globals.next_count();
170 self.db
171 .userid_lastonetimekeyupdate
172 .raw_put(user_id, *count);
173
174 Ok(())
175}
176
177#[implement(super::Service)]
178pub async fn take_fallback_key(
179 &self,
180 user_id: &UserId,
181 device_id: &DeviceId,
182 algorithm: &OneTimeKeyAlgorithm,
183) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
184 let key = (user_id, device_id, algorithm);
185 let entry: FallbackEntry = self
186 .db
187 .userdeviceidalgorithm_fallback
188 .qry(&key)
189 .await
190 .deserialized::<Json<_>>()
191 .map(|Json(entry)| entry)
192 .map_err(|_| err!(Request(NotFound("No fallback key found"))))?;
193
194 let updated = FallbackEntry { used: true, ..entry };
195 self.db
196 .userdeviceidalgorithm_fallback
197 .put(key, Json(&updated));
198
199 Ok((updated.key_id, updated.key))
200}
201
202#[implement(super::Service)]
203pub fn unused_fallback_key_algorithms<'a>(
204 &'a self,
205 user_id: &'a UserId,
206 device_id: &'a DeviceId,
207) -> impl Stream<Item = OneTimeKeyAlgorithm> + Send + 'a {
208 type KeyVal = ((Ignore, Ignore, OneTimeKeyAlgorithm), Json<FallbackEntry>);
209
210 let prefix = (user_id, device_id);
211 self.db
212 .userdeviceidalgorithm_fallback
213 .stream_prefix(&prefix)
214 .ignore_err()
215 .ready_filter_map(|((_, _, algorithm), Json(entry)): KeyVal| {
216 entry.used.is_false().then_some(algorithm)
217 })
218}
219
220#[implement(super::Service)]
221pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
222 self.db
223 .userid_lastonetimekeyupdate
224 .get(user_id)
225 .await
226 .deserialized()
227 .unwrap_or(0)
228}
229
230#[implement(super::Service)]
231pub async fn take_one_time_key(
232 &self,
233 user_id: &UserId,
234 device_id: &DeviceId,
235 key_algorithm: &OneTimeKeyAlgorithm,
236) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
237 let update_count = self.services.globals.next_count();
238 self.db
239 .userid_lastonetimekeyupdate
240 .insert(user_id, update_count.to_be_bytes());
241
242 let prefix = (user_id, device_id, Interfix);
243 let one_time_keys = self
244 .db
245 .onetimekeyid4225_otk
246 .stream_prefix(&prefix)
247 .ignore_err()
248 .ready_filter(|(row, _): &(OtkRowKey<'_>, &[u8])| row.3.algorithm() == *key_algorithm);
249
250 pin_mut!(one_time_keys);
251 let ((user_id, device_id, count, id), val) = one_time_keys
252 .next()
253 .await
254 .ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))?;
255
256 self.db
257 .onetimekeyid4225_otk
258 .del((user_id, device_id, count, id));
259
260 Ok((id.into(), serde_json::from_slice(val)?))
261}
262
263#[implement(super::Service)]
264pub async fn count_one_time_keys(
265 &self,
266 user_id: &UserId,
267 device_id: &DeviceId,
268) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
269 let prefix = (user_id, device_id, Interfix);
270 let algorithm_counts: BTreeMap<OneTimeKeyAlgorithm, UInt> = self
271 .db
272 .onetimekeyid4225_otk
273 .keys_prefix(&prefix)
274 .ignore_err()
275 .ready_fold(BTreeMap::new(), |mut acc, (.., id): OtkRowKey<'_>| {
276 let count: &mut UInt = acc.entry(id.algorithm()).or_default();
277 *count = count.saturating_add(1_u32.into());
278 acc
279 })
280 .await;
281
282 let total = algorithm_counts
283 .values()
284 .copied()
285 .map(TryInto::try_into)
286 .filter_map(Result::ok)
287 .fold(0_usize, usize::saturating_add);
288
289 let limit = self.services.config.one_time_key_limit;
290 if let Some(excess) = total.checked_sub(limit).filter(|&n| n > 0) {
291 self.prune_one_time_keys(user_id, device_id, excess)
292 .await;
293 }
294
295 algorithm_counts
296}
297
298#[implement(super::Service)]
302pub async fn prune_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId, excess: usize) {
303 let prefix = (user_id, device_id, Interfix);
304 self.db
305 .onetimekeyid4225_otk
306 .keys_prefix(&prefix)
307 .ignore_err()
308 .take(excess)
309 .ready_for_each(|row: OtkRowKey<'_>| {
310 self.db.onetimekeyid4225_otk.del(row);
311 })
312 .await;
313}
314
315#[implement(super::Service)]
316pub async fn add_device_keys(
317 &self,
318 user_id: &UserId,
319 device_id: &DeviceId,
320 device_keys: &Raw<DeviceKeys>,
321) {
322 let key = (user_id, device_id);
323
324 self.db.keyid_key.put(key, Json(device_keys));
325 self.mark_device_key_update(user_id).await;
326}
327
328#[implement(super::Service)]
329pub async fn add_cross_signing_keys(
330 &self,
331 user_id: &UserId,
332 master_key: &Option<Raw<CrossSigningKey>>,
333 self_signing_key: &Option<Raw<CrossSigningKey>>,
334 user_signing_key: &Option<Raw<CrossSigningKey>>,
335 notify: bool,
336) -> Result {
337 let mut prefix = user_id.as_bytes().to_vec();
339 prefix.push(0xFF);
340
341 if let Some(master_key) = master_key {
342 let (master_key_key, _) = parse_master_key(user_id, master_key)?;
343
344 self.db
345 .keyid_key
346 .insert(&master_key_key, master_key.json().get().as_bytes());
347
348 self.db
349 .userid_masterkeyid
350 .insert(user_id.as_bytes(), &master_key_key);
351 }
352
353 if let Some(self_signing_key) = self_signing_key {
355 let mut self_signing_key_ids = self_signing_key
356 .deserialize()
357 .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
358 .keys
359 .into_values();
360
361 let self_signing_key_id = self_signing_key_ids
362 .next()
363 .ok_or_else(|| err!(Request(InvalidParam("Self signing key contained no key."))))?;
364
365 if self_signing_key_ids.next().is_some() {
366 return Err!(Request(InvalidParam("Self signing key contained more than one key.")));
367 }
368
369 let mut self_signing_key_key = prefix.clone();
370 self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
371
372 self.db
373 .keyid_key
374 .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
375
376 self.db
377 .userid_selfsigningkeyid
378 .insert(user_id.as_bytes(), &self_signing_key_key);
379 }
380
381 if let Some(user_signing_key) = user_signing_key {
383 let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
384
385 let user_signing_key_key = (user_id, &user_signing_key_id);
386 self.db
387 .keyid_key
388 .put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
389
390 self.db
391 .userid_usersigningkeyid
392 .raw_put(user_id, user_signing_key_key);
393 }
394
395 if notify {
396 self.mark_device_key_update(user_id).await;
397 }
398
399 Ok(())
400}
401
402#[implement(super::Service)]
403pub async fn sign_key(
404 &self,
405 target_id: &UserId,
406 key_id: &str,
407 signature: (String, String),
408 sender_id: &UserId,
409) -> Result {
410 let key = (target_id, key_id);
411
412 let mut cross_signing_key: serde_json::Value = self
413 .db
414 .keyid_key
415 .qry(&key)
416 .await
417 .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
418 .deserialized()
419 .map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
420
421 let signatures = cross_signing_key
422 .get_mut("signatures")
423 .ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))?
424 .as_object_mut()
425 .ok_or_else(|| {
426 err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
427 })?
428 .entry(sender_id.to_string())
429 .or_insert_with(|| serde_json::Map::new().into());
430
431 signatures
432 .as_object_mut()
433 .ok_or_else(|| {
434 err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
435 })?
436 .insert(signature.0, signature.1.into());
437
438 let key = (target_id, key_id);
439 self.db
440 .keyid_key
441 .put(key, Json(cross_signing_key));
442
443 self.mark_device_key_update(target_id).await;
444
445 Ok(())
446}
447
448#[implement(super::Service)]
449#[inline]
450pub fn keys_changed<'a>(
451 &'a self,
452 user_id: &'a UserId,
453 from: u64,
454 to: Option<u64>,
455) -> impl Stream<Item = &UserId> + Send + 'a {
456 self.keys_changed_user_or_room(user_id.as_str(), from, to)
457 .map(|(user_id, ..)| user_id)
458}
459
460#[implement(super::Service)]
461#[inline]
462pub fn room_keys_changed<'a>(
463 &'a self,
464 room_id: &'a RoomId,
465 from: u64,
466 to: Option<u64>,
467) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
468 self.keys_changed_user_or_room(room_id.as_str(), from, to)
469}
470
#[implement(super::Service)]
/// Streams `(user_id, count)` pairs from `keychangeid_userid` whose row key is
/// `(user_or_room_id, count)`.
///
/// Only counts in `(from, to]` are yielded; `to = None` means `u64::MAX`. The
/// first component may be either a user ID or a room ID
/// (see `keys_changed` / `room_keys_changed`).
fn keys_changed_user_or_room<'a>(
	&'a self,
	user_or_room_id: &'a str,
	from: u64,
	to: Option<u64>,
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
	type KeyVal<'a> = ((&'a str, u64), &'a UserId);

	let to = to.unwrap_or(u64::MAX);
	// `from` itself is excluded; saturating so `from == u64::MAX` does not overflow.
	let start = (user_or_room_id, from.saturating_add(1));
	self.db
		.keychangeid_userid
		.stream_from(&start)
		.ignore_err()
		// Stop at the first row belonging to another ID or past the upper bound.
		.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
			*prefix == user_or_room_id && *count <= to
		})
		.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
}
491
#[implement(super::Service)]
/// Records a device-key change for `user_id` in `keychangeid_userid`.
///
/// Rows written, all under one freshly allocated count and all mapping to
/// `user_id`:
/// - one row keyed by the user ID;
/// - one row per joined room, keyed by the room ID.
///
/// When the `device_key_update_encrypted_rooms_only` config option is set,
/// only rooms that `is_encrypted_room` reports as encrypted get a row.
pub async fn mark_device_key_update(&self, user_id: &UserId) {
	let update_all_rooms = !self
		.services
		.config
		.device_key_update_encrypted_rooms_only;

	// With the restriction disabled the `||` short-circuits, so no
	// encryption-state lookup is performed.
	let all_or_is_encrypted = async |room_id: &RoomId| {
		update_all_rooms
			|| self
				.services
				.state_accessor
				.is_encrypted_room(room_id)
				.await
	};

	let count = self.services.globals.next_count();
	let user_key = (user_id, *count);

	self.db
		.keychangeid_userid
		.put_raw(user_key, user_id);
	self.services
		.state_cache
		.rooms_joined(user_id)
		.filter(|room_id| all_or_is_encrypted(*room_id))
		.ready_for_each(|room_id| {
			let room_key = (room_id, *count);
			self.db
				.keychangeid_userid
				.put_raw(room_key, user_id);
		})
		.await;
}
526
527#[implement(super::Service)]
528pub async fn get_device_keys<'a>(
529 &'a self,
530 user_id: &'a UserId,
531 device_id: &DeviceId,
532) -> Result<Raw<DeviceKeys>> {
533 let key_id = (user_id, device_id);
534 self.db
535 .keyid_key
536 .qry(&key_id)
537 .await
538 .deserialized()
539}
540
541#[implement(super::Service)]
542pub async fn get_key<F>(
543 &self,
544 key_id: &[u8],
545 sender_user: Option<&UserId>,
546 user_id: &UserId,
547 allowed_signatures: &F,
548) -> Result<Raw<CrossSigningKey>>
549where
550 F: Fn(&UserId) -> bool + Send + Sync,
551{
552 let key: serde_json::Value = self
553 .db
554 .keyid_key
555 .get(key_id)
556 .await
557 .deserialized()?;
558
559 let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
560 let raw_value = serde_json::value::to_raw_value(&cleaned)?;
561
562 Ok(Raw::from_json(raw_value))
563}
564
565#[implement(super::Service)]
566pub async fn get_master_key<F>(
567 &self,
568 sender_user: Option<&UserId>,
569 user_id: &UserId,
570 allowed_signatures: &F,
571) -> Result<Raw<CrossSigningKey>>
572where
573 F: Fn(&UserId) -> bool + Send + Sync,
574{
575 let key_id = self.db.userid_masterkeyid.get(user_id).await?;
576
577 self.get_key(&key_id, sender_user, user_id, allowed_signatures)
578 .await
579}
580
581#[implement(super::Service)]
582pub async fn get_self_signing_key<F>(
583 &self,
584 sender_user: Option<&UserId>,
585 user_id: &UserId,
586 allowed_signatures: &F,
587) -> Result<Raw<CrossSigningKey>>
588where
589 F: Fn(&UserId) -> bool + Send + Sync,
590{
591 let key_id = self
592 .db
593 .userid_selfsigningkeyid
594 .get(user_id)
595 .await?;
596
597 self.get_key(&key_id, sender_user, user_id, allowed_signatures)
598 .await
599}
600
601#[implement(super::Service)]
602pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
603 self.db
604 .userid_usersigningkeyid
605 .get(user_id)
606 .and_then(|key_id| self.db.keyid_key.get(&*key_id))
607 .await
608 .deserialized()
609}
610
611pub fn parse_master_key(
612 user_id: &UserId,
613 master_key: &Raw<CrossSigningKey>,
614) -> Result<(Vec<u8>, CrossSigningKey)> {
615 let mut prefix = user_id.as_bytes().to_vec();
616 prefix.push(0xFF);
617
618 let master_key = master_key
619 .deserialize()
620 .map_err(|_| err!(Request(InvalidParam("Invalid master key"))))?;
621
622 let mut master_key_ids = master_key.keys.values();
623 let master_key_id = master_key_ids
624 .next()
625 .ok_or(err!(Request(InvalidParam("Master key contained no key."))))?;
626
627 if master_key_ids.next().is_some() {
628 return Err!(Request(InvalidParam("Master key contained more than one key.")));
629 }
630
631 let mut master_key_key = prefix.clone();
632 master_key_key.extend_from_slice(master_key_id.as_bytes());
633
634 Ok((master_key_key, master_key))
635}
636
637pub(super) fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
638 let mut user_signing_key_ids = user_signing_key
639 .deserialize()
640 .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
641 .keys
642 .into_values();
643
644 let user_signing_key_id = user_signing_key_ids
645 .next()
646 .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
647
648 if user_signing_key_ids.next().is_some() {
649 return Err!(Request(InvalidParam("User signing key contained more than one key.")));
650 }
651
652 Ok(user_signing_key_id)
653}
654
655fn clean_signatures<F>(
657 mut cross_signing_key: serde_json::Value,
658 sender_user: Option<&UserId>,
659 user_id: &UserId,
660 allowed_signatures: &F,
661) -> Result<serde_json::Value>
662where
663 F: Fn(&UserId) -> bool + Send + Sync,
664{
665 if let Some(signatures) = cross_signing_key
666 .get_mut("signatures")
667 .and_then(|v| v.as_object_mut())
668 {
669 let new_capacity = signatures.len() / 2;
672 for (user, signature) in
673 mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
674 {
675 let sid = <&UserId>::try_from(user.as_str())
676 .map_err(|e| err!(Database("Invalid user ID in database: {e}")))?;
677
678 if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
679 signatures.insert(user, signature);
680 }
681 }
682 }
683
684 Ok(cross_signing_key)
685}