Skip to main content

tuwunel_service/users/
device.rs

1use std::{
2	net::IpAddr,
3	sync::Arc,
4	time::{Duration, SystemTime},
5};
6
7use futures::{FutureExt, Stream, StreamExt, future::join};
8use ruma::{
9	DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
10	api::client::device::Device, events::AnyToDeviceEvent, serde::Raw,
11};
12use serde_json::json;
13use tuwunel_core::{
14	Err, Result, at, implement, trace,
15	utils::{
16		self, BoolExt, ReadyExt,
17		stream::{IterStream, TryIgnore},
18		string::to_small_string,
19		time::{
20			duration_since_epoch, timepoint_from_epoch, timepoint_from_now, timepoint_has_passed,
21		},
22	},
23};
24use tuwunel_database::{Cbor, Deserialized, Ignore, Interfix, Json, Map};
25
/// Number of random characters in a server-generated device ID.
const DEVICE_ID_LENGTH: usize = 10;

/// Length of a generated access token, and the minimum length
/// `set_access_token` accepts for a caller-supplied one.
pub const TOKEN_LENGTH: usize = 32;
31
32/// Adds a new device to a user.
33#[implement(super::Service)]
34#[tracing::instrument(level = "info", skip(self, access_token))]
35pub async fn create_device(
36	&self,
37	user_id: &UserId,
38	device_id: Option<&DeviceId>,
39	(access_token, expires_in): (Option<&str>, Option<Duration>),
40	refresh_token: Option<&str>,
41	initial_device_display_name: Option<&str>,
42	client_ip: Option<IpAddr>,
43) -> Result<OwnedDeviceId> {
44	let device_id = device_id
45		.map(ToOwned::to_owned)
46		.unwrap_or_else(|| OwnedDeviceId::from(utils::random_string(DEVICE_ID_LENGTH)));
47
48	if !self.exists(user_id).await {
49		return Err!(Request(InvalidParam(error!(
50			"Called create_device for non-existent user {user_id}"
51		))));
52	}
53
54	let notify = true;
55	self.put_device_metadata(user_id, notify, &Device {
56		device_id: device_id.clone(),
57		display_name: initial_device_display_name.map(Into::into),
58		last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
59		last_seen_ip: client_ip.map(to_small_string),
60	});
61
62	if let Some(access_token) = access_token {
63		self.set_access_token(user_id, &device_id, access_token, expires_in, refresh_token)
64			.await?;
65	}
66
67	Ok(device_id)
68}
69
70/// Removes a device from a user.
71#[implement(super::Service)]
72#[tracing::instrument(level = "info", skip(self))]
73pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
74	// Remove access tokens
75	self.remove_tokens(user_id, device_id).await;
76
77	// Remove todevice events
78	let prefix = (user_id, device_id, Interfix);
79	self.db
80		.todeviceid_events
81		.keys_prefix_raw(&prefix)
82		.ignore_err()
83		.ready_for_each(|key| self.db.todeviceid_events.remove(key))
84		.await;
85
86	// Remove pushers
87	self.services
88		.pusher
89		.get_device_pushkeys(user_id, device_id)
90		.map(Vec::into_iter)
91		.map(IterStream::stream)
92		.flatten_stream()
93		.for_each(async |pushkey| {
94			self.services
95				.pusher
96				.delete_pusher(user_id, &pushkey)
97				.await;
98		})
99		.await;
100
101	// Removes the dehydrated device if the ID matches, otherwise no-op
102	self.remove_dehydrated_device(user_id, Some(device_id))
103		.await
104		.ok();
105
106	// TODO: Remove onetimekeys
107
108	// MSC2732: drop fallback keys for this device.
109	let prefix = (user_id, device_id, Interfix);
110	self.db
111		.userdeviceidalgorithm_fallback
112		.keys_prefix_raw(&prefix)
113		.ignore_err()
114		.ready_for_each(|key| self.db.userdeviceidalgorithm_fallback.remove(key))
115		.await;
116
117	// MSC3890: drop this device's local notification settings.
118	let event_type = format!("org.matrix.msc3890.local_notification_settings.{device_id}").into();
119	self.services
120		.account_data
121		.delete(None, user_id, event_type)
122		.await
123		.ok();
124
125	let userdeviceid = (user_id, device_id);
126	self.db.userdeviceid_metadata.del(userdeviceid);
127	self.db.oidcdevice_userdeviceid.del(userdeviceid);
128
129	self.mark_device_key_update(user_id).await;
130	increment(&self.db.userid_devicelistversion, user_id.as_bytes());
131}
132
133/// Returns an iterator over all device ids of this user.
134#[implement(super::Service)]
135pub fn all_device_ids<'a>(
136	&'a self,
137	user_id: &'a UserId,
138) -> impl Stream<Item = &DeviceId> + Send + 'a {
139	let prefix = (user_id, Interfix);
140	self.db
141		.userdeviceid_metadata
142		.keys_prefix(&prefix)
143		.ignore_err()
144		.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
145}
146
147/// Find out which user an access or refresh token belongs to.
148#[implement(super::Service)]
149#[tracing::instrument(level = "trace", skip(self, token))]
150pub async fn find_from_token(
151	&self,
152	token: &str,
153) -> Result<(OwnedUserId, OwnedDeviceId, Option<SystemTime>)> {
154	self.db
155		.token_userdeviceid
156		.get(token)
157		.await
158		.deserialized()
159		.and_then(|(user_id, device_id, expires_at): (_, _, Option<u64>)| {
160			let expires_at = expires_at
161				.map(Duration::from_secs)
162				.map(timepoint_from_epoch)
163				.transpose()?;
164
165			Ok((user_id, device_id, expires_at))
166		})
167}
168
169#[implement(super::Service)]
170#[tracing::instrument(level = "debug", skip(self))]
171pub async fn remove_tokens(&self, user_id: &UserId, device_id: &DeviceId) {
172	let remove_access = self
173		.remove_access_token(user_id, device_id)
174		.map(Result::ok);
175
176	let remove_refresh = self
177		.remove_refresh_token(user_id, device_id)
178		.map(Result::ok);
179
180	join(remove_access, remove_refresh).await;
181}
182
183/// Replaces the access token of one device.
184#[implement(super::Service)]
185#[tracing::instrument(level = "debug", skip(self))]
186pub async fn set_access_token(
187	&self,
188	user_id: &UserId,
189	device_id: &DeviceId,
190	access_token: &str,
191	expires_in: Option<Duration>,
192	refresh_token: Option<&str>,
193) -> Result {
194	assert!(
195		access_token.len() >= TOKEN_LENGTH,
196		"Caller must supply an access_token >= {TOKEN_LENGTH} chars."
197	);
198
199	if let Some(refresh_token) = refresh_token {
200		self.set_refresh_token(user_id, device_id, refresh_token)
201			.await?;
202	}
203
204	let expires_at = expires_in
205		.map(timepoint_from_now)
206		.transpose()?
207		.map(duration_since_epoch)
208		.as_ref()
209		.map(Duration::as_secs);
210
211	let userdeviceid = (user_id, device_id);
212
213	// Fold the prior pointer token into the index for pre-index upgrades.
214	if let Ok(prev) = self
215		.db
216		.userdeviceid_token
217		.qry(&userdeviceid)
218		.await
219		.deserialized::<String>()
220	{
221		self.db
222			.userdeviceidtoken_index
223			.put_raw((user_id, device_id, prev.as_str()), []);
224	}
225
226	let value = (user_id, device_id, expires_at);
227	self.db
228		.token_userdeviceid
229		.raw_put(access_token, value);
230	self.db
231		.userdeviceidtoken_index
232		.put_raw((user_id, device_id, access_token), []);
233	self.db
234		.userdeviceid_token
235		.put_raw(userdeviceid, access_token);
236
237	Ok(())
238}
239
240/// Revoke every access token of one device, without deleting the device. Take
241/// care to not leave dangling devices if using this method.
242#[implement(super::Service)]
243pub async fn remove_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
244	let prefix = (user_id, device_id, Interfix);
245	self.db
246		.userdeviceidtoken_index
247		.keys_prefix(&prefix)
248		.ignore_err()
249		.ready_for_each(|(_, _, token): (Ignore, Ignore, &str)| {
250			self.db.token_userdeviceid.remove(token);
251			self.db
252				.userdeviceidtoken_index
253				.del((user_id, device_id, token));
254		})
255		.await;
256
257	// Cover any pre-index token still recorded only in the legacy pointer.
258	if let Ok(token) = self
259		.db
260		.userdeviceid_token
261		.qry(&(user_id, device_id))
262		.await
263		.deserialized::<String>()
264	{
265		self.db.token_userdeviceid.remove(token.as_str());
266	}
267
268	self.db
269		.userdeviceid_token
270		.del((user_id, device_id));
271
272	Ok(())
273}
274
275/// Revoke a single access token by value, leaving the device and any other
276/// tokens it holds intact.
277#[implement(super::Service)]
278pub async fn remove_access_token_value(&self, access_token: &str) {
279	if let Ok((user_id, device_id, _)) = self
280		.db
281		.token_userdeviceid
282		.get(access_token)
283		.await
284		.deserialized::<(OwnedUserId, OwnedDeviceId, Option<u64>)>()
285	{
286		self.db
287			.userdeviceidtoken_index
288			.del((&*user_id, &*device_id, access_token));
289	}
290
291	self.db.token_userdeviceid.remove(access_token);
292}
293
294#[implement(super::Service)]
295pub fn generate_access_token(&self, expires: bool) -> (String, Option<Duration>) {
296	let access_token = utils::random_string(TOKEN_LENGTH);
297	let expires_in = expires
298		.then_some(self.services.server.config.access_token_ttl)
299		.map(Duration::from_secs);
300
301	(access_token, expires_in)
302}
303
304/// Replaces the refresh token of one device.
305#[implement(super::Service)]
306#[tracing::instrument(level = "debug", skip(self))]
307pub async fn set_refresh_token(
308	&self,
309	user_id: &UserId,
310	device_id: &DeviceId,
311	refresh_token: &str,
312) -> Result {
313	debug_assert!(refresh_token.starts_with("refresh_"), "refresh_token missing prefix");
314
315	let config = &self.services.server.config;
316	let ttl = config.refresh_token_ttl;
317	let idle_only = config.refresh_token_idle_only;
318
319	// Absolute mode carries the prior deadline forward instead of sliding it.
320	let prior_expires_at: Option<SystemTime> = (ttl != 0 && !idle_only)
321		.then_async(|| self.find_refresh_token_expires_at(user_id, device_id))
322		.await
323		.flatten();
324
325	// Capture the outgoing token before removal so it can be retained for one
326	// generation, making a later replay detectable.
327	let spent: Option<String> = self
328		.db
329		.userdeviceid_refresh
330		.qry(&(user_id, device_id))
331		.await
332		.deserialized()
333		.ok();
334
335	// Also drops the prior spent entry.
336	self.remove_refresh_token(user_id, device_id)
337		.await
338		.ok();
339
340	let expires_at = match (ttl, prior_expires_at) {
341		| (0, _) => None,
342		| (_, Some(prior)) => Some(prior),
343		| (ttl, None) => Some(timepoint_from_now(Duration::from_secs(ttl))?),
344	};
345
346	let expires_at_secs = expires_at
347		.map(duration_since_epoch)
348		.as_ref()
349		.map(Duration::as_secs);
350
351	let userdeviceid = (user_id, device_id);
352	let value = (user_id, device_id, expires_at_secs);
353	self.db
354		.token_userdeviceid
355		.raw_put(refresh_token, value);
356	self.db
357		.userdeviceid_refresh
358		.put_raw(userdeviceid, refresh_token);
359
360	// Retain the outgoing token as the device's spent token, pointing at its
361	// successor so a double-submit can be distinguished from a replay.
362	if let Some(spent) = spent {
363		let spent_at = duration_since_epoch(SystemTime::now()).as_secs();
364		let value = (user_id, device_id, refresh_token, spent_at);
365		self.db
366			.spentrefresh_userdeviceid
367			.raw_put(&*spent, value);
368		self.db
369			.userdeviceid_spentrefresh
370			.put_raw(userdeviceid, &*spent);
371	}
372
373	Ok(())
374}
375
376/// Look up the expiry stored alongside the current refresh token for this
377/// device, if one is recorded. Pre-rotation entries carry no expiry and
378/// return `None`.
379#[implement(super::Service)]
380async fn find_refresh_token_expires_at(
381	&self,
382	user_id: &UserId,
383	device_id: &DeviceId,
384) -> Option<SystemTime> {
385	let userdeviceid = (user_id, device_id);
386	let old_token: String = self
387		.db
388		.userdeviceid_refresh
389		.qry(&userdeviceid)
390		.await
391		.deserialized()
392		.ok()?;
393
394	let (_, _, expires_at_secs): (Ignore, Ignore, Option<u64>) = self
395		.db
396		.token_userdeviceid
397		.get(&old_token)
398		.await
399		.deserialized()
400		.ok()?;
401
402	expires_at_secs
403		.map(Duration::from_secs)
404		.map(timepoint_from_epoch)?
405		.ok()
406}
407
408/// Revoke the refresh token without deleting the device. Take care to not leave
409/// dangling devices if using this method.
410#[implement(super::Service)]
411pub async fn remove_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
412	let userdeviceid = (user_id, device_id);
413
414	if let Ok(refresh_token) = self
415		.db
416		.userdeviceid_refresh
417		.qry(&userdeviceid)
418		.await
419	{
420		self.db.token_userdeviceid.remove(&refresh_token);
421	}
422
423	self.db.userdeviceid_refresh.del(userdeviceid);
424
425	self.forget_spent_refresh_token(user_id, device_id)
426		.await;
427
428	Ok(())
429}
430
431/// Drop the spent (previous-generation) refresh token retained for reuse
432/// detection, if any.
433#[implement(super::Service)]
434async fn forget_spent_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) {
435	let userdeviceid = (user_id, device_id);
436
437	if let Ok(spent) = self
438		.db
439		.userdeviceid_spentrefresh
440		.qry(&userdeviceid)
441		.await
442	{
443		self.db.spentrefresh_userdeviceid.remove(&spent);
444	}
445
446	self.db
447		.userdeviceid_spentrefresh
448		.del(userdeviceid);
449}
450
451/// Classification of a refresh token presented for rotation at a token
452/// endpoint.
453pub enum RefreshToken {
454	/// The device's current refresh token; rotate it.
455	Current {
456		user_id: OwnedUserId,
457		device_id: OwnedDeviceId,
458		expires_at: Option<SystemTime>,
459	},
460
461	/// A spent (already-rotated) token retained for one generation. `grace` is
462	/// set when its successor is still current and it was spent within the
463	/// configured window, marking a benign double-submit rather than a replay;
464	/// `current` is the successor for which to re-issue an access token.
465	Replayed {
466		user_id: OwnedUserId,
467		device_id: OwnedDeviceId,
468		current: String,
469		grace: bool,
470	},
471
472	/// Not a recognised refresh token.
473	Unknown,
474}
475
476/// Classify a presented refresh token for the token-endpoint rotation path.
477#[implement(super::Service)]
478pub async fn classify_refresh_token(&self, presented: &str) -> RefreshToken {
479	// The current refresh token resolves and matches the device's active
480	// pointer (an access token resolves but will not match).
481	if let Ok((user_id, device_id, expires_at)) = self.find_from_token(presented).await {
482		let current: Option<String> = self
483			.db
484			.userdeviceid_refresh
485			.qry(&(&user_id, &device_id))
486			.await
487			.deserialized()
488			.ok();
489
490		if current.as_deref() == Some(presented) {
491			return RefreshToken::Current { user_id, device_id, expires_at };
492		}
493	}
494
495	// Otherwise it may be the one retained spent token: a benign double-submit
496	// inside the grace window, or a replay to be treated as a compromise.
497	let Ok((user_id, device_id, successor, spent_at)) = self
498		.db
499		.spentrefresh_userdeviceid
500		.get(presented)
501		.await
502		.deserialized::<(OwnedUserId, OwnedDeviceId, String, u64)>()
503	else {
504		return RefreshToken::Unknown;
505	};
506
507	let current: Option<String> = self
508		.db
509		.userdeviceid_refresh
510		.qry(&(&user_id, &device_id))
511		.await
512		.deserialized()
513		.ok();
514
515	let grace_window = self
516		.services
517		.server
518		.config
519		.refresh_token_reuse_grace;
520	let elapsed = duration_since_epoch(SystemTime::now())
521		.as_secs()
522		.saturating_sub(spent_at);
523
524	let grace = grace_window != 0
525		&& elapsed <= grace_window
526		&& current.as_deref() == Some(successor.as_str());
527
528	RefreshToken::Replayed {
529		user_id,
530		device_id,
531		current: successor,
532		grace,
533	}
534}
535
536#[must_use]
537pub fn generate_refresh_token() -> String {
538	format!("refresh_{}", utils::random_string(TOKEN_LENGTH))
539}
540
541#[implement(super::Service)]
542pub fn add_to_device_event(
543	&self,
544	sender: &UserId,
545	target_user_id: &UserId,
546	target_device_id: &DeviceId,
547	event_type: &str,
548	content: &serde_json::Value,
549) {
550	let count = self.services.globals.next_count();
551
552	let key = (target_user_id, target_device_id, *count);
553	self.db.todeviceid_events.put(
554		key,
555		Json(json!({
556			"type": event_type,
557			"sender": sender,
558			"content": content,
559		})),
560	);
561
562	trace!(
563		%target_user_id,
564		%target_device_id,
565		count = *count,
566		%event_type,
567		%sender,
568		"to_device write",
569	);
570}
571
572#[implement(super::Service)]
573pub fn get_to_device_events<'a>(
574	&'a self,
575	user_id: &'a UserId,
576	device_id: &'a DeviceId,
577	since: Option<u64>,
578	to: Option<u64>,
579) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
580	type Key<'a> = (&'a UserId, &'a DeviceId, u64);
581
582	let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
583
584	self.db
585		.todeviceid_events
586		.stream_from(&from)
587		.ignore_err()
588		.ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
589			user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
590		})
591		.map(|((_, _, count), event)| (count, event))
592}
593
594#[implement(super::Service)]
595pub async fn remove_to_device_events<Until>(
596	&self,
597	user_id: &UserId,
598	device_id: &DeviceId,
599	until: Until,
600) where
601	Until: Into<Option<u64>> + Send,
602{
603	type Key<'a> = (&'a UserId, &'a DeviceId, u64);
604
605	let until = until.into().unwrap_or(u64::MAX);
606	let from = (user_id, device_id, until);
607	self.db
608		.todeviceid_events
609		.rev_keys_from(&from)
610		.ignore_err()
611		.ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
612			user_id == *user_id_ && device_id == *device_id_
613		})
614		.ready_for_each(|key: Key<'_>| {
615			self.db.todeviceid_events.del(key);
616		})
617		.await;
618}
619
620#[implement(super::Service)]
621pub async fn update_device_last_seen(
622	&self,
623	user_id: &UserId,
624	device_id: &DeviceId,
625	last_seen_ip: Option<IpAddr>,
626	last_seen_ts: Option<MilliSecondsSinceUnixEpoch>,
627) -> Result {
628	let mut device = self
629		.get_device_metadata(user_id, device_id)
630		.await?;
631
632	if let Some(last_seen_ip) = last_seen_ip.map(to_small_string) {
633		device.last_seen_ip.replace(last_seen_ip);
634	}
635
636	device
637		.last_seen_ts
638		.replace(last_seen_ts.unwrap_or_else(MilliSecondsSinceUnixEpoch::now));
639
640	self.put_device_metadata(user_id, false, &device);
641
642	Ok(())
643}
644
645#[implement(super::Service)]
646pub fn put_device_metadata(&self, user_id: &UserId, notify: bool, device: &Device) {
647	let key = (user_id, &device.device_id);
648	self.db
649		.userdeviceid_metadata
650		.put(key, Json(device));
651
652	if notify {
653		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
654	}
655}
656
657/// Get device metadata.
658#[implement(super::Service)]
659pub async fn get_device_metadata(
660	&self,
661	user_id: &UserId,
662	device_id: &DeviceId,
663) -> Result<Device> {
664	self.db
665		.userdeviceid_metadata
666		.qry(&(user_id, device_id))
667		.await
668		.deserialized()
669		.inspect(|device: &Device| {
670			debug_assert_eq!(&device.device_id, device_id, "device_id mismatch");
671		})
672}
673
674#[implement(super::Service)]
675pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
676	self.db
677		.userdeviceid_metadata
678		.contains(&(user_id, device_id))
679		.await
680}
681
682#[implement(super::Service)]
683pub async fn is_oidc_device(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
684	self.db
685		.oidcdevice_userdeviceid
686		.contains(&(user_id, device_id))
687		.await
688}
689
690/// Returns the IdP that originally authenticated this device, if known.
691/// Returns `None` for devices predating the idp_id field or non-OIDC devices.
692#[implement(super::Service)]
693pub async fn get_oidc_device_idp(
694	&self,
695	user_id: &UserId,
696	device_id: &DeviceId,
697) -> Option<String> {
698	self.db
699		.oidcdevice_userdeviceid
700		.qry(&(user_id, device_id))
701		.await
702		.deserialized::<Json<String>>()
703		.ok()
704		.map(|Json(idp)| idp)
705		.filter(|idp| !idp.is_empty())
706}
707
708#[implement(super::Service)]
709pub fn mark_oidc_device(&self, user_id: &UserId, device_id: &DeviceId, idp_id: &str) {
710	self.db
711		.oidcdevice_userdeviceid
712		.put((user_id, device_id), Json(idp_id));
713}
714
715/// Allow cross-signing key replacement without UIAA for the next 10 minutes.
716/// Returns the expiry timestamp in milliseconds.
717#[allow(clippy::must_use_candidate)]
718#[implement(super::Service)]
719pub fn allow_cross_signing_replacement(&self, user_id: &UserId) -> SystemTime {
720	let duration = Duration::from_mins(10);
721	let expires = timepoint_from_now(duration).expect("failed to create timepoint from now");
722
723	self.db
724		.oidccskeybypass_userid
725		.raw_put(user_id, Cbor(expires));
726
727	expires
728}
729
730/// Check if the user is allowed to replace cross-signing keys without UIAA.
731#[implement(super::Service)]
732pub async fn can_replace_cross_signing_keys(&self, user_id: &UserId) -> bool {
733	let Ok(expires): Result<SystemTime, _> = self
734		.db
735		.oidccskeybypass_userid
736		.get(user_id)
737		.await
738		.deserialized::<Cbor<_>>()
739		.map(at!(0))
740	else {
741		return false;
742	};
743
744	if !timepoint_has_passed(expires) {
745		return true;
746	}
747
748	self.db.oidccskeybypass_userid.remove(user_id);
749	false
750}
751
752#[implement(super::Service)]
753pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
754	self.db
755		.userid_devicelistversion
756		.get(user_id)
757		.await
758		.deserialized()
759}
760
761#[implement(super::Service)]
762pub fn all_devices_metadata<'a>(
763	&'a self,
764	user_id: &'a UserId,
765) -> impl Stream<Item = Device> + Send + 'a {
766	let key = (user_id, Interfix);
767	self.db
768		.userdeviceid_metadata
769		.stream_prefix(&key)
770		.ignore_err()
771		.map(|(_, val): (Ignore, Device)| val)
772}
773
774//TODO: this is an ABA
775fn increment(db: &Arc<Map>, key: &[u8]) {
776	let old = db.get_blocking(key);
777	let new = utils::increment(old.ok().as_deref());
778	db.insert(key, new);
779}