1use std::{
2 net::IpAddr,
3 sync::Arc,
4 time::{Duration, SystemTime},
5};
6
7use futures::{FutureExt, Stream, StreamExt, future::join};
8use ruma::{
9 DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
10 api::client::device::Device, events::AnyToDeviceEvent, serde::Raw,
11};
12use serde_json::json;
13use tuwunel_core::{
14 Err, Result, at, implement,
15 utils::{
16 self, BoolExt, ReadyExt,
17 stream::{IterStream, TryIgnore},
18 string::to_small_string,
19 time::{
20 duration_since_epoch, timepoint_from_epoch, timepoint_from_now, timepoint_has_passed,
21 },
22 },
23};
24use tuwunel_database::{Cbor, Deserialized, Ignore, Interfix, Json, Map};
25
/// Length passed to `utils::random_string` when `create_device` has to
/// generate a device ID because the caller did not supply one.
const DEVICE_ID_LENGTH: usize = 10;

/// Length of the random part of generated access/refresh tokens; also the
/// minimum length `set_access_token` asserts for a supplied access token.
pub const TOKEN_LENGTH: usize = 32;
31
32#[implement(super::Service)]
34#[tracing::instrument(level = "info", skip(self, access_token))]
35pub async fn create_device(
36 &self,
37 user_id: &UserId,
38 device_id: Option<&DeviceId>,
39 (access_token, expires_in): (Option<&str>, Option<Duration>),
40 refresh_token: Option<&str>,
41 initial_device_display_name: Option<&str>,
42 client_ip: Option<IpAddr>,
43) -> Result<OwnedDeviceId> {
44 let device_id = device_id
45 .map(ToOwned::to_owned)
46 .unwrap_or_else(|| OwnedDeviceId::from(utils::random_string(DEVICE_ID_LENGTH)));
47
48 if !self.exists(user_id).await {
49 return Err!(Request(InvalidParam(error!(
50 "Called create_device for non-existent user {user_id}"
51 ))));
52 }
53
54 let notify = true;
55 self.put_device_metadata(user_id, notify, &Device {
56 device_id: device_id.clone(),
57 display_name: initial_device_display_name.map(Into::into),
58 last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
59 last_seen_ip: client_ip.map(to_small_string),
60 });
61
62 if let Some(access_token) = access_token {
63 self.set_access_token(user_id, &device_id, access_token, expires_in, refresh_token)
64 .await?;
65 }
66
67 Ok(device_id)
68}
69
70#[implement(super::Service)]
72#[tracing::instrument(level = "info", skip(self))]
73pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
74 self.remove_tokens(user_id, device_id).await;
76
77 let prefix = (user_id, device_id, Interfix);
79 self.db
80 .todeviceid_events
81 .keys_prefix_raw(&prefix)
82 .ignore_err()
83 .ready_for_each(|key| self.db.todeviceid_events.remove(key))
84 .await;
85
86 self.services
88 .pusher
89 .get_device_pushkeys(user_id, device_id)
90 .map(Vec::into_iter)
91 .map(IterStream::stream)
92 .flatten_stream()
93 .for_each(async |pushkey| {
94 self.services
95 .pusher
96 .delete_pusher(user_id, &pushkey)
97 .await;
98 })
99 .await;
100
101 self.remove_dehydrated_device(user_id, Some(device_id))
103 .await
104 .ok();
105
106 let prefix = (user_id, device_id, Interfix);
110 self.db
111 .userdeviceidalgorithm_fallback
112 .keys_prefix_raw(&prefix)
113 .ignore_err()
114 .ready_for_each(|key| self.db.userdeviceidalgorithm_fallback.remove(key))
115 .await;
116
117 let userdeviceid = (user_id, device_id);
118 self.db.userdeviceid_metadata.del(userdeviceid);
119 self.db.oidcdevice_userdeviceid.del(userdeviceid);
120
121 self.mark_device_key_update(user_id).await;
122 increment(&self.db.userid_devicelistversion, user_id.as_bytes());
123}
124
125#[implement(super::Service)]
127pub fn all_device_ids<'a>(
128 &'a self,
129 user_id: &'a UserId,
130) -> impl Stream<Item = &DeviceId> + Send + 'a {
131 let prefix = (user_id, Interfix);
132 self.db
133 .userdeviceid_metadata
134 .keys_prefix(&prefix)
135 .ignore_err()
136 .map(|(_, device_id): (Ignore, &DeviceId)| device_id)
137}
138
139#[implement(super::Service)]
141#[tracing::instrument(level = "trace", skip(self, token))]
142pub async fn find_from_token(
143 &self,
144 token: &str,
145) -> Result<(OwnedUserId, OwnedDeviceId, Option<SystemTime>)> {
146 self.db
147 .token_userdeviceid
148 .get(token)
149 .await
150 .deserialized()
151 .and_then(|(user_id, device_id, expires_at): (_, _, Option<u64>)| {
152 let expires_at = expires_at
153 .map(Duration::from_secs)
154 .map(timepoint_from_epoch)
155 .transpose()?;
156
157 Ok((user_id, device_id, expires_at))
158 })
159}
160
161#[implement(super::Service)]
162#[tracing::instrument(level = "debug", skip(self))]
163pub async fn remove_tokens(&self, user_id: &UserId, device_id: &DeviceId) {
164 let remove_access = self
165 .remove_access_token(user_id, device_id)
166 .map(Result::ok);
167
168 let remove_refresh = self
169 .remove_refresh_token(user_id, device_id)
170 .map(Result::ok);
171
172 join(remove_access, remove_refresh).await;
173}
174
175#[implement(super::Service)]
177#[tracing::instrument(level = "debug", skip(self))]
178pub async fn set_access_token(
179 &self,
180 user_id: &UserId,
181 device_id: &DeviceId,
182 access_token: &str,
183 expires_in: Option<Duration>,
184 refresh_token: Option<&str>,
185) -> Result {
186 assert!(
187 access_token.len() >= TOKEN_LENGTH,
188 "Caller must supply an access_token >= {TOKEN_LENGTH} chars."
189 );
190
191 if let Some(refresh_token) = refresh_token {
192 self.set_refresh_token(user_id, device_id, refresh_token)
193 .await?;
194 }
195
196 self.remove_access_token(user_id, device_id)
198 .await
199 .ok();
200
201 let expires_at = expires_in
202 .map(timepoint_from_now)
203 .transpose()?
204 .map(duration_since_epoch)
205 .as_ref()
206 .map(Duration::as_secs);
207
208 let userdeviceid = (user_id, device_id);
209 let value = (user_id, device_id, expires_at);
210 self.db
211 .token_userdeviceid
212 .raw_put(access_token, value);
213 self.db
214 .userdeviceid_token
215 .put_raw(userdeviceid, access_token);
216
217 Ok(())
218}
219
220#[implement(super::Service)]
223pub async fn remove_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
224 let userdeviceid = (user_id, device_id);
225 let access_token = self
226 .db
227 .userdeviceid_token
228 .qry(&userdeviceid)
229 .await?;
230
231 self.db.userdeviceid_token.del(userdeviceid);
232 self.db.token_userdeviceid.remove(&access_token);
233
234 Ok(())
235}
236
237#[implement(super::Service)]
238pub async fn get_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
239 let key = (user_id, device_id);
240 self.db
241 .userdeviceid_token
242 .qry(&key)
243 .await
244 .deserialized()
245}
246
247#[implement(super::Service)]
248pub fn generate_access_token(&self, expires: bool) -> (String, Option<Duration>) {
249 let access_token = utils::random_string(TOKEN_LENGTH);
250 let expires_in = expires
251 .then_some(self.services.server.config.access_token_ttl)
252 .map(Duration::from_secs);
253
254 (access_token, expires_in)
255}
256
257#[implement(super::Service)]
259#[tracing::instrument(level = "debug", skip(self))]
260pub async fn set_refresh_token(
261 &self,
262 user_id: &UserId,
263 device_id: &DeviceId,
264 refresh_token: &str,
265) -> Result {
266 debug_assert!(refresh_token.starts_with("refresh_"), "refresh_token missing prefix");
267
268 let config = &self.services.server.config;
269 let ttl = config.refresh_token_ttl;
270 let idle_only = config.refresh_token_idle_only;
271
272 let prior_expires_at: Option<SystemTime> = (ttl != 0 && !idle_only)
274 .then_async(|| self.find_refresh_token_expires_at(user_id, device_id))
275 .await
276 .flatten();
277
278 self.remove_refresh_token(user_id, device_id)
280 .await
281 .ok();
282
283 let expires_at = match (ttl, prior_expires_at) {
284 | (0, _) => None,
285 | (_, Some(prior)) => Some(prior),
286 | (ttl, None) => Some(timepoint_from_now(Duration::from_secs(ttl))?),
287 };
288
289 let expires_at_secs = expires_at
290 .map(duration_since_epoch)
291 .as_ref()
292 .map(Duration::as_secs);
293
294 let userdeviceid = (user_id, device_id);
295 let value = (user_id, device_id, expires_at_secs);
296 self.db
297 .token_userdeviceid
298 .raw_put(refresh_token, value);
299 self.db
300 .userdeviceid_refresh
301 .put_raw(userdeviceid, refresh_token);
302
303 Ok(())
304}
305
306#[implement(super::Service)]
310async fn find_refresh_token_expires_at(
311 &self,
312 user_id: &UserId,
313 device_id: &DeviceId,
314) -> Option<SystemTime> {
315 let userdeviceid = (user_id, device_id);
316 let old_token: String = self
317 .db
318 .userdeviceid_refresh
319 .qry(&userdeviceid)
320 .await
321 .deserialized()
322 .ok()?;
323
324 let (_, _, expires_at_secs): (Ignore, Ignore, Option<u64>) = self
325 .db
326 .token_userdeviceid
327 .get(&old_token)
328 .await
329 .deserialized()
330 .ok()?;
331
332 expires_at_secs
333 .map(Duration::from_secs)
334 .map(timepoint_from_epoch)?
335 .ok()
336}
337
338#[implement(super::Service)]
341pub async fn remove_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
342 let userdeviceid = (user_id, device_id);
343 let refresh_token = self
344 .db
345 .userdeviceid_refresh
346 .qry(&userdeviceid)
347 .await?;
348
349 self.db.userdeviceid_refresh.del(userdeviceid);
350 self.db.token_userdeviceid.remove(&refresh_token);
351
352 Ok(())
353}
354
355#[implement(super::Service)]
356pub async fn get_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
357 let key = (user_id, device_id);
358 self.db
359 .userdeviceid_refresh
360 .qry(&key)
361 .await
362 .deserialized()
363}
364
365#[must_use]
366pub fn generate_refresh_token() -> String {
367 format!("refresh_{}", utils::random_string(TOKEN_LENGTH))
368}
369
370#[implement(super::Service)]
371pub fn add_to_device_event(
372 &self,
373 sender: &UserId,
374 target_user_id: &UserId,
375 target_device_id: &DeviceId,
376 event_type: &str,
377 content: &serde_json::Value,
378) {
379 let count = self.services.globals.next_count();
380
381 let key = (target_user_id, target_device_id, *count);
382 self.db.todeviceid_events.put(
383 key,
384 Json(json!({
385 "type": event_type,
386 "sender": sender,
387 "content": content,
388 })),
389 );
390}
391
392#[implement(super::Service)]
393pub fn get_to_device_events<'a>(
394 &'a self,
395 user_id: &'a UserId,
396 device_id: &'a DeviceId,
397 since: Option<u64>,
398 to: Option<u64>,
399) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
400 type Key<'a> = (&'a UserId, &'a DeviceId, u64);
401
402 let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
403
404 self.db
405 .todeviceid_events
406 .stream_from(&from)
407 .ignore_err()
408 .ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
409 user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
410 })
411 .map(|((_, _, count), event)| (count, event))
412}
413
414#[implement(super::Service)]
415pub async fn remove_to_device_events<Until>(
416 &self,
417 user_id: &UserId,
418 device_id: &DeviceId,
419 until: Until,
420) where
421 Until: Into<Option<u64>> + Send,
422{
423 type Key<'a> = (&'a UserId, &'a DeviceId, u64);
424
425 let until = until.into().unwrap_or(u64::MAX);
426 let from = (user_id, device_id, until);
427 self.db
428 .todeviceid_events
429 .rev_keys_from(&from)
430 .ignore_err()
431 .ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
432 user_id == *user_id_ && device_id == *device_id_
433 })
434 .ready_for_each(|key: Key<'_>| {
435 self.db.todeviceid_events.del(key);
436 })
437 .await;
438}
439
440#[implement(super::Service)]
441pub async fn update_device_last_seen(
442 &self,
443 user_id: &UserId,
444 device_id: &DeviceId,
445 last_seen_ip: Option<IpAddr>,
446 last_seen_ts: Option<MilliSecondsSinceUnixEpoch>,
447) -> Result {
448 let mut device = self
449 .get_device_metadata(user_id, device_id)
450 .await?;
451
452 if let Some(last_seen_ip) = last_seen_ip.map(to_small_string) {
453 device.last_seen_ip.replace(last_seen_ip);
454 }
455
456 device
457 .last_seen_ts
458 .replace(last_seen_ts.unwrap_or_else(MilliSecondsSinceUnixEpoch::now));
459
460 self.put_device_metadata(user_id, false, &device);
461
462 Ok(())
463}
464
465#[implement(super::Service)]
466pub fn put_device_metadata(&self, user_id: &UserId, notify: bool, device: &Device) {
467 let key = (user_id, &device.device_id);
468 self.db
469 .userdeviceid_metadata
470 .put(key, Json(device));
471
472 if notify {
473 increment(&self.db.userid_devicelistversion, user_id.as_bytes());
474 }
475}
476
477#[implement(super::Service)]
479pub async fn get_device_metadata(
480 &self,
481 user_id: &UserId,
482 device_id: &DeviceId,
483) -> Result<Device> {
484 self.db
485 .userdeviceid_metadata
486 .qry(&(user_id, device_id))
487 .await
488 .deserialized()
489 .inspect(|device: &Device| {
490 debug_assert_eq!(&device.device_id, device_id, "device_id mismatch");
491 })
492}
493
494#[implement(super::Service)]
495pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
496 self.db
497 .userdeviceid_metadata
498 .contains(&(user_id, device_id))
499 .await
500}
501
502#[implement(super::Service)]
503pub async fn is_oidc_device(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
504 self.db
505 .oidcdevice_userdeviceid
506 .contains(&(user_id, device_id))
507 .await
508}
509
510#[implement(super::Service)]
513pub async fn get_oidc_device_idp(
514 &self,
515 user_id: &UserId,
516 device_id: &DeviceId,
517) -> Option<String> {
518 self.db
519 .oidcdevice_userdeviceid
520 .qry(&(user_id, device_id))
521 .await
522 .deserialized::<Json<String>>()
523 .ok()
524 .map(|Json(idp)| idp)
525}
526
527#[implement(super::Service)]
528pub fn mark_oidc_device(&self, user_id: &UserId, device_id: &DeviceId, idp_id: &str) {
529 self.db
530 .oidcdevice_userdeviceid
531 .put((user_id, device_id), Json(idp_id));
532}
533
534#[allow(clippy::must_use_candidate)]
537#[implement(super::Service)]
538pub fn allow_cross_signing_replacement(&self, user_id: &UserId) -> SystemTime {
539 let duration = Duration::from_mins(10);
540 let expires = timepoint_from_now(duration).expect("failed to create timepoint from now");
541
542 self.db
543 .oidccskeybypass_userid
544 .raw_put(user_id, Cbor(expires));
545
546 expires
547}
548
549#[implement(super::Service)]
551pub async fn can_replace_cross_signing_keys(&self, user_id: &UserId) -> bool {
552 let Ok(expires): Result<SystemTime, _> = self
553 .db
554 .oidccskeybypass_userid
555 .get(user_id)
556 .await
557 .deserialized::<Cbor<_>>()
558 .map(at!(0))
559 else {
560 return false;
561 };
562
563 if !timepoint_has_passed(expires) {
564 return true;
565 }
566
567 self.db.oidccskeybypass_userid.remove(user_id);
568 false
569}
570
571#[implement(super::Service)]
572pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
573 self.db
574 .userid_devicelistversion
575 .get(user_id)
576 .await
577 .deserialized()
578}
579
580#[implement(super::Service)]
581pub fn all_devices_metadata<'a>(
582 &'a self,
583 user_id: &'a UserId,
584) -> impl Stream<Item = Device> + Send + 'a {
585 let key = (user_id, Interfix);
586 self.db
587 .userdeviceid_metadata
588 .stream_prefix(&key)
589 .ignore_err()
590 .map(|(_, val): (Ignore, Device)| val)
591}
592
593fn increment(db: &Arc<Map>, key: &[u8]) {
595 let old = db.get_blocking(key);
596 let new = utils::increment(old.ok().as_deref());
597 db.insert(key, new);
598}