1use std::{
2 net::IpAddr,
3 sync::Arc,
4 time::{Duration, SystemTime},
5};
6
7use futures::{FutureExt, Stream, StreamExt, future::join};
8use ruma::{
9 DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
10 api::client::device::Device, events::AnyToDeviceEvent, serde::Raw,
11};
12use serde_json::json;
13use tuwunel_core::{
14 Err, Result, at, implement, trace,
15 utils::{
16 self, BoolExt, ReadyExt,
17 stream::{IterStream, TryIgnore},
18 string::to_small_string,
19 time::{
20 duration_since_epoch, timepoint_from_epoch, timepoint_from_now, timepoint_has_passed,
21 },
22 },
23};
24use tuwunel_database::{Cbor, Deserialized, Ignore, Interfix, Json, Map};
25
/// Number of random characters in a server-generated device ID.
const DEVICE_ID_LENGTH: usize = 10;

/// Minimum accepted access-token length, and the number of random characters
/// used when generating access and refresh tokens.
pub const TOKEN_LENGTH: usize = 32;
31
32#[implement(super::Service)]
34#[tracing::instrument(level = "info", skip(self, access_token))]
35pub async fn create_device(
36 &self,
37 user_id: &UserId,
38 device_id: Option<&DeviceId>,
39 (access_token, expires_in): (Option<&str>, Option<Duration>),
40 refresh_token: Option<&str>,
41 initial_device_display_name: Option<&str>,
42 client_ip: Option<IpAddr>,
43) -> Result<OwnedDeviceId> {
44 let device_id = device_id
45 .map(ToOwned::to_owned)
46 .unwrap_or_else(|| OwnedDeviceId::from(utils::random_string(DEVICE_ID_LENGTH)));
47
48 if !self.exists(user_id).await {
49 return Err!(Request(InvalidParam(error!(
50 "Called create_device for non-existent user {user_id}"
51 ))));
52 }
53
54 let notify = true;
55 self.put_device_metadata(user_id, notify, &Device {
56 device_id: device_id.clone(),
57 display_name: initial_device_display_name.map(Into::into),
58 last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
59 last_seen_ip: client_ip.map(to_small_string),
60 });
61
62 if let Some(access_token) = access_token {
63 self.set_access_token(user_id, &device_id, access_token, expires_in, refresh_token)
64 .await?;
65 }
66
67 Ok(device_id)
68}
69
70#[implement(super::Service)]
72#[tracing::instrument(level = "info", skip(self))]
73pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
74 self.remove_tokens(user_id, device_id).await;
76
77 let prefix = (user_id, device_id, Interfix);
79 self.db
80 .todeviceid_events
81 .keys_prefix_raw(&prefix)
82 .ignore_err()
83 .ready_for_each(|key| self.db.todeviceid_events.remove(key))
84 .await;
85
86 self.services
88 .pusher
89 .get_device_pushkeys(user_id, device_id)
90 .map(Vec::into_iter)
91 .map(IterStream::stream)
92 .flatten_stream()
93 .for_each(async |pushkey| {
94 self.services
95 .pusher
96 .delete_pusher(user_id, &pushkey)
97 .await;
98 })
99 .await;
100
101 self.remove_dehydrated_device(user_id, Some(device_id))
103 .await
104 .ok();
105
106 let prefix = (user_id, device_id, Interfix);
110 self.db
111 .userdeviceidalgorithm_fallback
112 .keys_prefix_raw(&prefix)
113 .ignore_err()
114 .ready_for_each(|key| self.db.userdeviceidalgorithm_fallback.remove(key))
115 .await;
116
117 let event_type = format!("org.matrix.msc3890.local_notification_settings.{device_id}").into();
119 self.services
120 .account_data
121 .delete(None, user_id, event_type)
122 .await
123 .ok();
124
125 let userdeviceid = (user_id, device_id);
126 self.db.userdeviceid_metadata.del(userdeviceid);
127 self.db.oidcdevice_userdeviceid.del(userdeviceid);
128
129 self.mark_device_key_update(user_id).await;
130 increment(&self.db.userid_devicelistversion, user_id.as_bytes());
131}
132
133#[implement(super::Service)]
135pub fn all_device_ids<'a>(
136 &'a self,
137 user_id: &'a UserId,
138) -> impl Stream<Item = &DeviceId> + Send + 'a {
139 let prefix = (user_id, Interfix);
140 self.db
141 .userdeviceid_metadata
142 .keys_prefix(&prefix)
143 .ignore_err()
144 .map(|(_, device_id): (Ignore, &DeviceId)| device_id)
145}
146
147#[implement(super::Service)]
149#[tracing::instrument(level = "trace", skip(self, token))]
150pub async fn find_from_token(
151 &self,
152 token: &str,
153) -> Result<(OwnedUserId, OwnedDeviceId, Option<SystemTime>)> {
154 self.db
155 .token_userdeviceid
156 .get(token)
157 .await
158 .deserialized()
159 .and_then(|(user_id, device_id, expires_at): (_, _, Option<u64>)| {
160 let expires_at = expires_at
161 .map(Duration::from_secs)
162 .map(timepoint_from_epoch)
163 .transpose()?;
164
165 Ok((user_id, device_id, expires_at))
166 })
167}
168
169#[implement(super::Service)]
170#[tracing::instrument(level = "debug", skip(self))]
171pub async fn remove_tokens(&self, user_id: &UserId, device_id: &DeviceId) {
172 let remove_access = self
173 .remove_access_token(user_id, device_id)
174 .map(Result::ok);
175
176 let remove_refresh = self
177 .remove_refresh_token(user_id, device_id)
178 .map(Result::ok);
179
180 join(remove_access, remove_refresh).await;
181}
182
183#[implement(super::Service)]
185#[tracing::instrument(level = "debug", skip(self))]
186pub async fn set_access_token(
187 &self,
188 user_id: &UserId,
189 device_id: &DeviceId,
190 access_token: &str,
191 expires_in: Option<Duration>,
192 refresh_token: Option<&str>,
193) -> Result {
194 assert!(
195 access_token.len() >= TOKEN_LENGTH,
196 "Caller must supply an access_token >= {TOKEN_LENGTH} chars."
197 );
198
199 if let Some(refresh_token) = refresh_token {
200 self.set_refresh_token(user_id, device_id, refresh_token)
201 .await?;
202 }
203
204 let expires_at = expires_in
205 .map(timepoint_from_now)
206 .transpose()?
207 .map(duration_since_epoch)
208 .as_ref()
209 .map(Duration::as_secs);
210
211 let userdeviceid = (user_id, device_id);
212
213 if let Ok(prev) = self
215 .db
216 .userdeviceid_token
217 .qry(&userdeviceid)
218 .await
219 .deserialized::<String>()
220 {
221 self.db
222 .userdeviceidtoken_index
223 .put_raw((user_id, device_id, prev.as_str()), []);
224 }
225
226 let value = (user_id, device_id, expires_at);
227 self.db
228 .token_userdeviceid
229 .raw_put(access_token, value);
230 self.db
231 .userdeviceidtoken_index
232 .put_raw((user_id, device_id, access_token), []);
233 self.db
234 .userdeviceid_token
235 .put_raw(userdeviceid, access_token);
236
237 Ok(())
238}
239
240#[implement(super::Service)]
243pub async fn remove_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
244 let prefix = (user_id, device_id, Interfix);
245 self.db
246 .userdeviceidtoken_index
247 .keys_prefix(&prefix)
248 .ignore_err()
249 .ready_for_each(|(_, _, token): (Ignore, Ignore, &str)| {
250 self.db.token_userdeviceid.remove(token);
251 self.db
252 .userdeviceidtoken_index
253 .del((user_id, device_id, token));
254 })
255 .await;
256
257 if let Ok(token) = self
259 .db
260 .userdeviceid_token
261 .qry(&(user_id, device_id))
262 .await
263 .deserialized::<String>()
264 {
265 self.db.token_userdeviceid.remove(token.as_str());
266 }
267
268 self.db
269 .userdeviceid_token
270 .del((user_id, device_id));
271
272 Ok(())
273}
274
275#[implement(super::Service)]
278pub async fn remove_access_token_value(&self, access_token: &str) {
279 if let Ok((user_id, device_id, _)) = self
280 .db
281 .token_userdeviceid
282 .get(access_token)
283 .await
284 .deserialized::<(OwnedUserId, OwnedDeviceId, Option<u64>)>()
285 {
286 self.db
287 .userdeviceidtoken_index
288 .del((&*user_id, &*device_id, access_token));
289 }
290
291 self.db.token_userdeviceid.remove(access_token);
292}
293
294#[implement(super::Service)]
295pub fn generate_access_token(&self, expires: bool) -> (String, Option<Duration>) {
296 let access_token = utils::random_string(TOKEN_LENGTH);
297 let expires_in = expires
298 .then_some(self.services.server.config.access_token_ttl)
299 .map(Duration::from_secs);
300
301 (access_token, expires_in)
302}
303
304#[implement(super::Service)]
306#[tracing::instrument(level = "debug", skip(self))]
307pub async fn set_refresh_token(
308 &self,
309 user_id: &UserId,
310 device_id: &DeviceId,
311 refresh_token: &str,
312) -> Result {
313 debug_assert!(refresh_token.starts_with("refresh_"), "refresh_token missing prefix");
314
315 let config = &self.services.server.config;
316 let ttl = config.refresh_token_ttl;
317 let idle_only = config.refresh_token_idle_only;
318
319 let prior_expires_at: Option<SystemTime> = (ttl != 0 && !idle_only)
321 .then_async(|| self.find_refresh_token_expires_at(user_id, device_id))
322 .await
323 .flatten();
324
325 let spent: Option<String> = self
328 .db
329 .userdeviceid_refresh
330 .qry(&(user_id, device_id))
331 .await
332 .deserialized()
333 .ok();
334
335 self.remove_refresh_token(user_id, device_id)
337 .await
338 .ok();
339
340 let expires_at = match (ttl, prior_expires_at) {
341 | (0, _) => None,
342 | (_, Some(prior)) => Some(prior),
343 | (ttl, None) => Some(timepoint_from_now(Duration::from_secs(ttl))?),
344 };
345
346 let expires_at_secs = expires_at
347 .map(duration_since_epoch)
348 .as_ref()
349 .map(Duration::as_secs);
350
351 let userdeviceid = (user_id, device_id);
352 let value = (user_id, device_id, expires_at_secs);
353 self.db
354 .token_userdeviceid
355 .raw_put(refresh_token, value);
356 self.db
357 .userdeviceid_refresh
358 .put_raw(userdeviceid, refresh_token);
359
360 if let Some(spent) = spent {
363 let spent_at = duration_since_epoch(SystemTime::now()).as_secs();
364 let value = (user_id, device_id, refresh_token, spent_at);
365 self.db
366 .spentrefresh_userdeviceid
367 .raw_put(&*spent, value);
368 self.db
369 .userdeviceid_spentrefresh
370 .put_raw(userdeviceid, &*spent);
371 }
372
373 Ok(())
374}
375
376#[implement(super::Service)]
380async fn find_refresh_token_expires_at(
381 &self,
382 user_id: &UserId,
383 device_id: &DeviceId,
384) -> Option<SystemTime> {
385 let userdeviceid = (user_id, device_id);
386 let old_token: String = self
387 .db
388 .userdeviceid_refresh
389 .qry(&userdeviceid)
390 .await
391 .deserialized()
392 .ok()?;
393
394 let (_, _, expires_at_secs): (Ignore, Ignore, Option<u64>) = self
395 .db
396 .token_userdeviceid
397 .get(&old_token)
398 .await
399 .deserialized()
400 .ok()?;
401
402 expires_at_secs
403 .map(Duration::from_secs)
404 .map(timepoint_from_epoch)?
405 .ok()
406}
407
408#[implement(super::Service)]
411pub async fn remove_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
412 let userdeviceid = (user_id, device_id);
413
414 if let Ok(refresh_token) = self
415 .db
416 .userdeviceid_refresh
417 .qry(&userdeviceid)
418 .await
419 {
420 self.db.token_userdeviceid.remove(&refresh_token);
421 }
422
423 self.db.userdeviceid_refresh.del(userdeviceid);
424
425 self.forget_spent_refresh_token(user_id, device_id)
426 .await;
427
428 Ok(())
429}
430
431#[implement(super::Service)]
434async fn forget_spent_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) {
435 let userdeviceid = (user_id, device_id);
436
437 if let Ok(spent) = self
438 .db
439 .userdeviceid_spentrefresh
440 .qry(&userdeviceid)
441 .await
442 {
443 self.db.spentrefresh_userdeviceid.remove(&spent);
444 }
445
446 self.db
447 .userdeviceid_spentrefresh
448 .del(userdeviceid);
449}
450
/// Outcome of checking a refresh token presented by a client, as produced by
/// `classify_refresh_token`.
pub enum RefreshToken {
	/// The token is the device's current refresh token.
	Current {
		user_id: OwnedUserId,
		device_id: OwnedDeviceId,
		/// Expiry stored with the token, if any. `classify_refresh_token` does
		/// not check it. NOTE(review): presumably the caller compares it to the
		/// clock — confirm.
		expires_at: Option<SystemTime>,
	},

	/// The token was rotated away earlier and is recorded as spent.
	Replayed {
		user_id: OwnedUserId,
		device_id: OwnedDeviceId,
		/// The token that replaced the presented one when it was spent. This
		/// need not still be the device's current token (see `grace`).
		current: String,
		/// True when `refresh_token_reuse_grace` is non-zero, the token was
		/// spent no more than that many seconds ago, and its successor is still
		/// the device's current refresh token.
		grace: bool,
	},

	/// The token is neither current nor recorded as spent.
	Unknown,
}
475
476#[implement(super::Service)]
478pub async fn classify_refresh_token(&self, presented: &str) -> RefreshToken {
479 if let Ok((user_id, device_id, expires_at)) = self.find_from_token(presented).await {
482 let current: Option<String> = self
483 .db
484 .userdeviceid_refresh
485 .qry(&(&user_id, &device_id))
486 .await
487 .deserialized()
488 .ok();
489
490 if current.as_deref() == Some(presented) {
491 return RefreshToken::Current { user_id, device_id, expires_at };
492 }
493 }
494
495 let Ok((user_id, device_id, successor, spent_at)) = self
498 .db
499 .spentrefresh_userdeviceid
500 .get(presented)
501 .await
502 .deserialized::<(OwnedUserId, OwnedDeviceId, String, u64)>()
503 else {
504 return RefreshToken::Unknown;
505 };
506
507 let current: Option<String> = self
508 .db
509 .userdeviceid_refresh
510 .qry(&(&user_id, &device_id))
511 .await
512 .deserialized()
513 .ok();
514
515 let grace_window = self
516 .services
517 .server
518 .config
519 .refresh_token_reuse_grace;
520 let elapsed = duration_since_epoch(SystemTime::now())
521 .as_secs()
522 .saturating_sub(spent_at);
523
524 let grace = grace_window != 0
525 && elapsed <= grace_window
526 && current.as_deref() == Some(successor.as_str());
527
528 RefreshToken::Replayed {
529 user_id,
530 device_id,
531 current: successor,
532 grace,
533 }
534}
535
536#[must_use]
537pub fn generate_refresh_token() -> String {
538 format!("refresh_{}", utils::random_string(TOKEN_LENGTH))
539}
540
541#[implement(super::Service)]
542pub fn add_to_device_event(
543 &self,
544 sender: &UserId,
545 target_user_id: &UserId,
546 target_device_id: &DeviceId,
547 event_type: &str,
548 content: &serde_json::Value,
549) {
550 let count = self.services.globals.next_count();
551
552 let key = (target_user_id, target_device_id, *count);
553 self.db.todeviceid_events.put(
554 key,
555 Json(json!({
556 "type": event_type,
557 "sender": sender,
558 "content": content,
559 })),
560 );
561
562 trace!(
563 %target_user_id,
564 %target_device_id,
565 count = *count,
566 %event_type,
567 %sender,
568 "to_device write",
569 );
570}
571
572#[implement(super::Service)]
573pub fn get_to_device_events<'a>(
574 &'a self,
575 user_id: &'a UserId,
576 device_id: &'a DeviceId,
577 since: Option<u64>,
578 to: Option<u64>,
579) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
580 type Key<'a> = (&'a UserId, &'a DeviceId, u64);
581
582 let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
583
584 self.db
585 .todeviceid_events
586 .stream_from(&from)
587 .ignore_err()
588 .ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
589 user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
590 })
591 .map(|((_, _, count), event)| (count, event))
592}
593
594#[implement(super::Service)]
595pub async fn remove_to_device_events<Until>(
596 &self,
597 user_id: &UserId,
598 device_id: &DeviceId,
599 until: Until,
600) where
601 Until: Into<Option<u64>> + Send,
602{
603 type Key<'a> = (&'a UserId, &'a DeviceId, u64);
604
605 let until = until.into().unwrap_or(u64::MAX);
606 let from = (user_id, device_id, until);
607 self.db
608 .todeviceid_events
609 .rev_keys_from(&from)
610 .ignore_err()
611 .ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
612 user_id == *user_id_ && device_id == *device_id_
613 })
614 .ready_for_each(|key: Key<'_>| {
615 self.db.todeviceid_events.del(key);
616 })
617 .await;
618}
619
620#[implement(super::Service)]
621pub async fn update_device_last_seen(
622 &self,
623 user_id: &UserId,
624 device_id: &DeviceId,
625 last_seen_ip: Option<IpAddr>,
626 last_seen_ts: Option<MilliSecondsSinceUnixEpoch>,
627) -> Result {
628 let mut device = self
629 .get_device_metadata(user_id, device_id)
630 .await?;
631
632 if let Some(last_seen_ip) = last_seen_ip.map(to_small_string) {
633 device.last_seen_ip.replace(last_seen_ip);
634 }
635
636 device
637 .last_seen_ts
638 .replace(last_seen_ts.unwrap_or_else(MilliSecondsSinceUnixEpoch::now));
639
640 self.put_device_metadata(user_id, false, &device);
641
642 Ok(())
643}
644
645#[implement(super::Service)]
646pub fn put_device_metadata(&self, user_id: &UserId, notify: bool, device: &Device) {
647 let key = (user_id, &device.device_id);
648 self.db
649 .userdeviceid_metadata
650 .put(key, Json(device));
651
652 if notify {
653 increment(&self.db.userid_devicelistversion, user_id.as_bytes());
654 }
655}
656
657#[implement(super::Service)]
659pub async fn get_device_metadata(
660 &self,
661 user_id: &UserId,
662 device_id: &DeviceId,
663) -> Result<Device> {
664 self.db
665 .userdeviceid_metadata
666 .qry(&(user_id, device_id))
667 .await
668 .deserialized()
669 .inspect(|device: &Device| {
670 debug_assert_eq!(&device.device_id, device_id, "device_id mismatch");
671 })
672}
673
674#[implement(super::Service)]
675pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
676 self.db
677 .userdeviceid_metadata
678 .contains(&(user_id, device_id))
679 .await
680}
681
682#[implement(super::Service)]
683pub async fn is_oidc_device(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
684 self.db
685 .oidcdevice_userdeviceid
686 .contains(&(user_id, device_id))
687 .await
688}
689
690#[implement(super::Service)]
693pub async fn get_oidc_device_idp(
694 &self,
695 user_id: &UserId,
696 device_id: &DeviceId,
697) -> Option<String> {
698 self.db
699 .oidcdevice_userdeviceid
700 .qry(&(user_id, device_id))
701 .await
702 .deserialized::<Json<String>>()
703 .ok()
704 .map(|Json(idp)| idp)
705 .filter(|idp| !idp.is_empty())
706}
707
708#[implement(super::Service)]
709pub fn mark_oidc_device(&self, user_id: &UserId, device_id: &DeviceId, idp_id: &str) {
710 self.db
711 .oidcdevice_userdeviceid
712 .put((user_id, device_id), Json(idp_id));
713}
714
715#[allow(clippy::must_use_candidate)]
718#[implement(super::Service)]
719pub fn allow_cross_signing_replacement(&self, user_id: &UserId) -> SystemTime {
720 let duration = Duration::from_mins(10);
721 let expires = timepoint_from_now(duration).expect("failed to create timepoint from now");
722
723 self.db
724 .oidccskeybypass_userid
725 .raw_put(user_id, Cbor(expires));
726
727 expires
728}
729
730#[implement(super::Service)]
732pub async fn can_replace_cross_signing_keys(&self, user_id: &UserId) -> bool {
733 let Ok(expires): Result<SystemTime, _> = self
734 .db
735 .oidccskeybypass_userid
736 .get(user_id)
737 .await
738 .deserialized::<Cbor<_>>()
739 .map(at!(0))
740 else {
741 return false;
742 };
743
744 if !timepoint_has_passed(expires) {
745 return true;
746 }
747
748 self.db.oidccskeybypass_userid.remove(user_id);
749 false
750}
751
752#[implement(super::Service)]
753pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
754 self.db
755 .userid_devicelistversion
756 .get(user_id)
757 .await
758 .deserialized()
759}
760
761#[implement(super::Service)]
762pub fn all_devices_metadata<'a>(
763 &'a self,
764 user_id: &'a UserId,
765) -> impl Stream<Item = Device> + Send + 'a {
766 let key = (user_id, Interfix);
767 self.db
768 .userdeviceid_metadata
769 .stream_prefix(&key)
770 .ignore_err()
771 .map(|(_, val): (Ignore, Device)| val)
772}
773
774fn increment(db: &Arc<Map>, key: &[u8]) {
776 let old = db.get_blocking(key);
777 let new = utils::increment(old.ok().as_deref());
778 db.insert(key, new);
779}