Skip to main content

tuwunel_service/users/
device.rs

1use std::{
2	net::IpAddr,
3	sync::Arc,
4	time::{Duration, SystemTime},
5};
6
7use futures::{FutureExt, Stream, StreamExt, future::join};
8use ruma::{
9	DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
10	api::client::device::Device, events::AnyToDeviceEvent, serde::Raw,
11};
12use serde_json::json;
13use tuwunel_core::{
14	Err, Result, at, implement,
15	utils::{
16		self, BoolExt, ReadyExt,
17		stream::{IterStream, TryIgnore},
18		string::to_small_string,
19		time::{
20			duration_since_epoch, timepoint_from_epoch, timepoint_from_now, timepoint_has_passed,
21		},
22	},
23};
24use tuwunel_database::{Cbor, Deserialized, Ignore, Interfix, Json, Map};
25
/// Number of random characters in a server-generated device ID.
const DEVICE_ID_LENGTH: usize = 10;

/// Number of random characters in a server-generated access token; also the
/// minimum length (in bytes) that `set_access_token` accepts.
pub const TOKEN_LENGTH: usize = 32;
31
32/// Adds a new device to a user.
33#[implement(super::Service)]
34#[tracing::instrument(level = "info", skip(self, access_token))]
35pub async fn create_device(
36	&self,
37	user_id: &UserId,
38	device_id: Option<&DeviceId>,
39	(access_token, expires_in): (Option<&str>, Option<Duration>),
40	refresh_token: Option<&str>,
41	initial_device_display_name: Option<&str>,
42	client_ip: Option<IpAddr>,
43) -> Result<OwnedDeviceId> {
44	let device_id = device_id
45		.map(ToOwned::to_owned)
46		.unwrap_or_else(|| OwnedDeviceId::from(utils::random_string(DEVICE_ID_LENGTH)));
47
48	if !self.exists(user_id).await {
49		return Err!(Request(InvalidParam(error!(
50			"Called create_device for non-existent user {user_id}"
51		))));
52	}
53
54	let notify = true;
55	self.put_device_metadata(user_id, notify, &Device {
56		device_id: device_id.clone(),
57		display_name: initial_device_display_name.map(Into::into),
58		last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
59		last_seen_ip: client_ip.map(to_small_string),
60	});
61
62	if let Some(access_token) = access_token {
63		self.set_access_token(user_id, &device_id, access_token, expires_in, refresh_token)
64			.await?;
65	}
66
67	Ok(device_id)
68}
69
70/// Removes a device from a user.
71#[implement(super::Service)]
72#[tracing::instrument(level = "info", skip(self))]
73pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
74	// Remove access tokens
75	self.remove_tokens(user_id, device_id).await;
76
77	// Remove todevice events
78	let prefix = (user_id, device_id, Interfix);
79	self.db
80		.todeviceid_events
81		.keys_prefix_raw(&prefix)
82		.ignore_err()
83		.ready_for_each(|key| self.db.todeviceid_events.remove(key))
84		.await;
85
86	// Remove pushers
87	self.services
88		.pusher
89		.get_device_pushkeys(user_id, device_id)
90		.map(Vec::into_iter)
91		.map(IterStream::stream)
92		.flatten_stream()
93		.for_each(async |pushkey| {
94			self.services
95				.pusher
96				.delete_pusher(user_id, &pushkey)
97				.await;
98		})
99		.await;
100
101	// Removes the dehydrated device if the ID matches, otherwise no-op
102	self.remove_dehydrated_device(user_id, Some(device_id))
103		.await
104		.ok();
105
106	// TODO: Remove onetimekeys
107
108	// MSC2732: drop fallback keys for this device.
109	let prefix = (user_id, device_id, Interfix);
110	self.db
111		.userdeviceidalgorithm_fallback
112		.keys_prefix_raw(&prefix)
113		.ignore_err()
114		.ready_for_each(|key| self.db.userdeviceidalgorithm_fallback.remove(key))
115		.await;
116
117	let userdeviceid = (user_id, device_id);
118	self.db.userdeviceid_metadata.del(userdeviceid);
119	self.db.oidcdevice_userdeviceid.del(userdeviceid);
120
121	self.mark_device_key_update(user_id).await;
122	increment(&self.db.userid_devicelistversion, user_id.as_bytes());
123}
124
125/// Returns an iterator over all device ids of this user.
126#[implement(super::Service)]
127pub fn all_device_ids<'a>(
128	&'a self,
129	user_id: &'a UserId,
130) -> impl Stream<Item = &DeviceId> + Send + 'a {
131	let prefix = (user_id, Interfix);
132	self.db
133		.userdeviceid_metadata
134		.keys_prefix(&prefix)
135		.ignore_err()
136		.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
137}
138
139/// Find out which user an access or refresh token belongs to.
140#[implement(super::Service)]
141#[tracing::instrument(level = "trace", skip(self, token))]
142pub async fn find_from_token(
143	&self,
144	token: &str,
145) -> Result<(OwnedUserId, OwnedDeviceId, Option<SystemTime>)> {
146	self.db
147		.token_userdeviceid
148		.get(token)
149		.await
150		.deserialized()
151		.and_then(|(user_id, device_id, expires_at): (_, _, Option<u64>)| {
152			let expires_at = expires_at
153				.map(Duration::from_secs)
154				.map(timepoint_from_epoch)
155				.transpose()?;
156
157			Ok((user_id, device_id, expires_at))
158		})
159}
160
161#[implement(super::Service)]
162#[tracing::instrument(level = "debug", skip(self))]
163pub async fn remove_tokens(&self, user_id: &UserId, device_id: &DeviceId) {
164	let remove_access = self
165		.remove_access_token(user_id, device_id)
166		.map(Result::ok);
167
168	let remove_refresh = self
169		.remove_refresh_token(user_id, device_id)
170		.map(Result::ok);
171
172	join(remove_access, remove_refresh).await;
173}
174
175/// Replaces the access token of one device.
176#[implement(super::Service)]
177#[tracing::instrument(level = "debug", skip(self))]
178pub async fn set_access_token(
179	&self,
180	user_id: &UserId,
181	device_id: &DeviceId,
182	access_token: &str,
183	expires_in: Option<Duration>,
184	refresh_token: Option<&str>,
185) -> Result {
186	assert!(
187		access_token.len() >= TOKEN_LENGTH,
188		"Caller must supply an access_token >= {TOKEN_LENGTH} chars."
189	);
190
191	if let Some(refresh_token) = refresh_token {
192		self.set_refresh_token(user_id, device_id, refresh_token)
193			.await?;
194	}
195
196	// Remove old token.
197	self.remove_access_token(user_id, device_id)
198		.await
199		.ok();
200
201	let expires_at = expires_in
202		.map(timepoint_from_now)
203		.transpose()?
204		.map(duration_since_epoch)
205		.as_ref()
206		.map(Duration::as_secs);
207
208	let userdeviceid = (user_id, device_id);
209	let value = (user_id, device_id, expires_at);
210	self.db
211		.token_userdeviceid
212		.raw_put(access_token, value);
213	self.db
214		.userdeviceid_token
215		.put_raw(userdeviceid, access_token);
216
217	Ok(())
218}
219
220/// Revoke the access token without deleting the device. Take care to not leave
221/// dangling devices if using this method.
222#[implement(super::Service)]
223pub async fn remove_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
224	let userdeviceid = (user_id, device_id);
225	let access_token = self
226		.db
227		.userdeviceid_token
228		.qry(&userdeviceid)
229		.await?;
230
231	self.db.userdeviceid_token.del(userdeviceid);
232	self.db.token_userdeviceid.remove(&access_token);
233
234	Ok(())
235}
236
237#[implement(super::Service)]
238pub async fn get_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
239	let key = (user_id, device_id);
240	self.db
241		.userdeviceid_token
242		.qry(&key)
243		.await
244		.deserialized()
245}
246
247#[implement(super::Service)]
248pub fn generate_access_token(&self, expires: bool) -> (String, Option<Duration>) {
249	let access_token = utils::random_string(TOKEN_LENGTH);
250	let expires_in = expires
251		.then_some(self.services.server.config.access_token_ttl)
252		.map(Duration::from_secs);
253
254	(access_token, expires_in)
255}
256
257/// Replaces the refresh token of one device.
258#[implement(super::Service)]
259#[tracing::instrument(level = "debug", skip(self))]
260pub async fn set_refresh_token(
261	&self,
262	user_id: &UserId,
263	device_id: &DeviceId,
264	refresh_token: &str,
265) -> Result {
266	debug_assert!(refresh_token.starts_with("refresh_"), "refresh_token missing prefix");
267
268	let config = &self.services.server.config;
269	let ttl = config.refresh_token_ttl;
270	let idle_only = config.refresh_token_idle_only;
271
272	// Absolute mode carries the prior deadline forward instead of sliding it.
273	let prior_expires_at: Option<SystemTime> = (ttl != 0 && !idle_only)
274		.then_async(|| self.find_refresh_token_expires_at(user_id, device_id))
275		.await
276		.flatten();
277
278	// Remove old token
279	self.remove_refresh_token(user_id, device_id)
280		.await
281		.ok();
282
283	let expires_at = match (ttl, prior_expires_at) {
284		| (0, _) => None,
285		| (_, Some(prior)) => Some(prior),
286		| (ttl, None) => Some(timepoint_from_now(Duration::from_secs(ttl))?),
287	};
288
289	let expires_at_secs = expires_at
290		.map(duration_since_epoch)
291		.as_ref()
292		.map(Duration::as_secs);
293
294	let userdeviceid = (user_id, device_id);
295	let value = (user_id, device_id, expires_at_secs);
296	self.db
297		.token_userdeviceid
298		.raw_put(refresh_token, value);
299	self.db
300		.userdeviceid_refresh
301		.put_raw(userdeviceid, refresh_token);
302
303	Ok(())
304}
305
306/// Look up the expiry stored alongside the current refresh token for this
307/// device, if one is recorded. Pre-rotation entries carry no expiry and
308/// return `None`.
309#[implement(super::Service)]
310async fn find_refresh_token_expires_at(
311	&self,
312	user_id: &UserId,
313	device_id: &DeviceId,
314) -> Option<SystemTime> {
315	let userdeviceid = (user_id, device_id);
316	let old_token: String = self
317		.db
318		.userdeviceid_refresh
319		.qry(&userdeviceid)
320		.await
321		.deserialized()
322		.ok()?;
323
324	let (_, _, expires_at_secs): (Ignore, Ignore, Option<u64>) = self
325		.db
326		.token_userdeviceid
327		.get(&old_token)
328		.await
329		.deserialized()
330		.ok()?;
331
332	expires_at_secs
333		.map(Duration::from_secs)
334		.map(timepoint_from_epoch)?
335		.ok()
336}
337
338/// Revoke the refresh token without deleting the device. Take care to not leave
339/// dangling devices if using this method.
340#[implement(super::Service)]
341pub async fn remove_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
342	let userdeviceid = (user_id, device_id);
343	let refresh_token = self
344		.db
345		.userdeviceid_refresh
346		.qry(&userdeviceid)
347		.await?;
348
349	self.db.userdeviceid_refresh.del(userdeviceid);
350	self.db.token_userdeviceid.remove(&refresh_token);
351
352	Ok(())
353}
354
355#[implement(super::Service)]
356pub async fn get_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
357	let key = (user_id, device_id);
358	self.db
359		.userdeviceid_refresh
360		.qry(&key)
361		.await
362		.deserialized()
363}
364
365#[must_use]
366pub fn generate_refresh_token() -> String {
367	format!("refresh_{}", utils::random_string(TOKEN_LENGTH))
368}
369
370#[implement(super::Service)]
371pub fn add_to_device_event(
372	&self,
373	sender: &UserId,
374	target_user_id: &UserId,
375	target_device_id: &DeviceId,
376	event_type: &str,
377	content: &serde_json::Value,
378) {
379	let count = self.services.globals.next_count();
380
381	let key = (target_user_id, target_device_id, *count);
382	self.db.todeviceid_events.put(
383		key,
384		Json(json!({
385			"type": event_type,
386			"sender": sender,
387			"content": content,
388		})),
389	);
390}
391
392#[implement(super::Service)]
393pub fn get_to_device_events<'a>(
394	&'a self,
395	user_id: &'a UserId,
396	device_id: &'a DeviceId,
397	since: Option<u64>,
398	to: Option<u64>,
399) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
400	type Key<'a> = (&'a UserId, &'a DeviceId, u64);
401
402	let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
403
404	self.db
405		.todeviceid_events
406		.stream_from(&from)
407		.ignore_err()
408		.ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
409			user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
410		})
411		.map(|((_, _, count), event)| (count, event))
412}
413
414#[implement(super::Service)]
415pub async fn remove_to_device_events<Until>(
416	&self,
417	user_id: &UserId,
418	device_id: &DeviceId,
419	until: Until,
420) where
421	Until: Into<Option<u64>> + Send,
422{
423	type Key<'a> = (&'a UserId, &'a DeviceId, u64);
424
425	let until = until.into().unwrap_or(u64::MAX);
426	let from = (user_id, device_id, until);
427	self.db
428		.todeviceid_events
429		.rev_keys_from(&from)
430		.ignore_err()
431		.ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
432			user_id == *user_id_ && device_id == *device_id_
433		})
434		.ready_for_each(|key: Key<'_>| {
435			self.db.todeviceid_events.del(key);
436		})
437		.await;
438}
439
440#[implement(super::Service)]
441pub async fn update_device_last_seen(
442	&self,
443	user_id: &UserId,
444	device_id: &DeviceId,
445	last_seen_ip: Option<IpAddr>,
446	last_seen_ts: Option<MilliSecondsSinceUnixEpoch>,
447) -> Result {
448	let mut device = self
449		.get_device_metadata(user_id, device_id)
450		.await?;
451
452	if let Some(last_seen_ip) = last_seen_ip.map(to_small_string) {
453		device.last_seen_ip.replace(last_seen_ip);
454	}
455
456	device
457		.last_seen_ts
458		.replace(last_seen_ts.unwrap_or_else(MilliSecondsSinceUnixEpoch::now));
459
460	self.put_device_metadata(user_id, false, &device);
461
462	Ok(())
463}
464
465#[implement(super::Service)]
466pub fn put_device_metadata(&self, user_id: &UserId, notify: bool, device: &Device) {
467	let key = (user_id, &device.device_id);
468	self.db
469		.userdeviceid_metadata
470		.put(key, Json(device));
471
472	if notify {
473		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
474	}
475}
476
477/// Get device metadata.
478#[implement(super::Service)]
479pub async fn get_device_metadata(
480	&self,
481	user_id: &UserId,
482	device_id: &DeviceId,
483) -> Result<Device> {
484	self.db
485		.userdeviceid_metadata
486		.qry(&(user_id, device_id))
487		.await
488		.deserialized()
489		.inspect(|device: &Device| {
490			debug_assert_eq!(&device.device_id, device_id, "device_id mismatch");
491		})
492}
493
494#[implement(super::Service)]
495pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
496	self.db
497		.userdeviceid_metadata
498		.contains(&(user_id, device_id))
499		.await
500}
501
502#[implement(super::Service)]
503pub async fn is_oidc_device(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
504	self.db
505		.oidcdevice_userdeviceid
506		.contains(&(user_id, device_id))
507		.await
508}
509
510/// Returns the IdP that originally authenticated this device, if known.
511/// Returns `None` for devices predating the idp_id field or non-OIDC devices.
512#[implement(super::Service)]
513pub async fn get_oidc_device_idp(
514	&self,
515	user_id: &UserId,
516	device_id: &DeviceId,
517) -> Option<String> {
518	self.db
519		.oidcdevice_userdeviceid
520		.qry(&(user_id, device_id))
521		.await
522		.deserialized::<Json<String>>()
523		.ok()
524		.map(|Json(idp)| idp)
525}
526
527#[implement(super::Service)]
528pub fn mark_oidc_device(&self, user_id: &UserId, device_id: &DeviceId, idp_id: &str) {
529	self.db
530		.oidcdevice_userdeviceid
531		.put((user_id, device_id), Json(idp_id));
532}
533
534/// Allow cross-signing key replacement without UIAA for the next 10 minutes.
535/// Returns the expiry timestamp in milliseconds.
536#[allow(clippy::must_use_candidate)]
537#[implement(super::Service)]
538pub fn allow_cross_signing_replacement(&self, user_id: &UserId) -> SystemTime {
539	let duration = Duration::from_mins(10);
540	let expires = timepoint_from_now(duration).expect("failed to create timepoint from now");
541
542	self.db
543		.oidccskeybypass_userid
544		.raw_put(user_id, Cbor(expires));
545
546	expires
547}
548
549/// Check if the user is allowed to replace cross-signing keys without UIAA.
550#[implement(super::Service)]
551pub async fn can_replace_cross_signing_keys(&self, user_id: &UserId) -> bool {
552	let Ok(expires): Result<SystemTime, _> = self
553		.db
554		.oidccskeybypass_userid
555		.get(user_id)
556		.await
557		.deserialized::<Cbor<_>>()
558		.map(at!(0))
559	else {
560		return false;
561	};
562
563	if !timepoint_has_passed(expires) {
564		return true;
565	}
566
567	self.db.oidccskeybypass_userid.remove(user_id);
568	false
569}
570
571#[implement(super::Service)]
572pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
573	self.db
574		.userid_devicelistversion
575		.get(user_id)
576		.await
577		.deserialized()
578}
579
580#[implement(super::Service)]
581pub fn all_devices_metadata<'a>(
582	&'a self,
583	user_id: &'a UserId,
584) -> impl Stream<Item = Device> + Send + 'a {
585	let key = (user_id, Interfix);
586	self.db
587		.userdeviceid_metadata
588		.stream_prefix(&key)
589		.ignore_err()
590		.map(|(_, val): (Ignore, Device)| val)
591}
592
593//TODO: this is an ABA
594fn increment(db: &Arc<Map>, key: &[u8]) {
595	let old = db.get_blocking(key);
596	let new = utils::increment(old.ok().as_deref());
597	db.insert(key, new);
598}