Skip to main content

tuwunel_service/users/
mod.rs

1mod dehydrated_device;
2pub mod device;
3mod keys;
4mod ldap;
5mod register;
6
7use std::sync::Arc;
8
9use futures::{Stream, StreamExt, TryFutureExt};
10use ruma::{
11	MilliSecondsSinceUnixEpoch, OwnedUserId, UserId,
12	api::client::filter::FilterDefinition,
13	events::{
14		GlobalAccountDataEventType,
15		ignored_user_list::IgnoredUserListEvent,
16		invite_permission_config::{InvitePermissionAction, InvitePermissionConfigEvent},
17	},
18};
19use serde::{Deserialize, Serialize};
20use tuwunel_core::{
21	Err, Result, debug_warn, err, is_equal_to, trace,
22	utils::{self, ReadyExt, stream::TryIgnore},
23};
24use tuwunel_database::{Deserialized, Json, Map};
25
26pub use self::{keys::parse_master_key, register::Register};
27
/// Placeholder stored in place of a password hash when `set_password` is
/// called with this exact value; `has_password` reports `false` for it.
pub const PASSWORD_SENTINEL: &str = "*";
/// Empty value stored by `set_password(.., None)`; an empty stored password
/// is what `is_deactivated` treats as a deactivated account.
pub const PASSWORD_DISABLED: &str = "";
30
/// Forensic record for a moderation action (MSC3823 suspend, MSC3939 lock).
/// Presence of the row is the load-bearing fact; this body is written but
/// never read on the hot path.
///
/// Written as JSON by `set_suspended` / `set_locked` and read back by
/// `get_suspension` / `get_lock`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Moderation {
	/// Timestamp taken when the record was created.
	pub when: MilliSecondsSinceUnixEpoch,
	/// User passed as `by` to the setter, i.e. who performed the action.
	pub by: OwnedUserId,
}
39
/// Users service: account state, passwords, filters and short-lived tokens.
pub struct Service {
	/// Handle used to reach sibling services (account data, globals, oauth,
	/// server config) from the methods below.
	services: Arc<crate::services::OnceServices>,
	/// Database maps owned by this service.
	db: Data,
}
44
45struct Data {
46	keychangeid_userid: Arc<Map>,
47	keyid_key: Arc<Map>,
48	onetimekeyid4225_otk: Option<Arc<Map>>,
49	openidtoken_expiresatuserid: Arc<Map>,
50	logintoken_expiresatuserid: Arc<Map>,
51	todeviceid_events: Arc<Map>,
52	spentrefresh_userdeviceid: Arc<Map>,
53	token_userdeviceid: Arc<Map>,
54	userdeviceid_metadata: Arc<Map>,
55	userdeviceid_token: Arc<Map>,
56	userdeviceidtoken_index: Arc<Map>,
57	userdeviceid_refresh: Arc<Map>,
58	userdeviceid_spentrefresh: Arc<Map>,
59	userdeviceidalgorithm_fallback: Arc<Map>,
60	oidcdevice_userdeviceid: Arc<Map>,
61	oidccskeybypass_userid: Arc<Map>,
62	userfilterid_filter: Arc<Map>,
63	userid_dehydrateddevice: Arc<Map>,
64	userid_devicelistversion: Arc<Map>,
65	userid_lastonetimekeyupdate: Arc<Map>,
66	userid_locked: Arc<Map>,
67	userid_masterkeyid: Arc<Map>,
68	userid_password: Arc<Map>,
69	userid_origin: Arc<Map>,
70	userid_selfsigningkeyid: Arc<Map>,
71	userid_suspended: Arc<Map>,
72	userid_usersigningkeyid: Arc<Map>,
73}
74
75impl crate::Service for Service {
76	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
77		Ok(Arc::new(Self {
78			services: args.services.clone(),
79			db: Data {
80				keychangeid_userid: args.db["keychangeid_userid"].clone(),
81				keyid_key: args.db["keyid_key"].clone(),
82				onetimekeyid4225_otk: args.db.get("onetimekeyid4225_otk").ok().cloned(),
83				openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(),
84				logintoken_expiresatuserid: args.db["logintoken_expiresatuserid"].clone(),
85				oidcdevice_userdeviceid: args.db["oidcdevice_userdeviceid"].clone(),
86				oidccskeybypass_userid: args.db["oidccskeybypass_userid"].clone(),
87				todeviceid_events: args.db["todeviceid_events"].clone(),
88				spentrefresh_userdeviceid: args.db["spentrefresh_userdeviceid"].clone(),
89				token_userdeviceid: args.db["token_userdeviceid"].clone(),
90				userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(),
91				userdeviceid_token: args.db["userdeviceid_token"].clone(),
92				userdeviceidtoken_index: args.db["userdeviceidtoken_index"].clone(),
93				userdeviceid_refresh: args.db["userdeviceid_refresh"].clone(),
94				userdeviceid_spentrefresh: args.db["userdeviceid_spentrefresh"].clone(),
95				userdeviceidalgorithm_fallback: args.db["userdeviceidalgorithm_fallback"].clone(),
96				userfilterid_filter: args.db["userfilterid_filter"].clone(),
97				userid_dehydrateddevice: args.db["userid_dehydrateddevice"].clone(),
98				userid_devicelistversion: args.db["userid_devicelistversion"].clone(),
99				userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
100				userid_locked: args.db["userid_locked"].clone(),
101				userid_masterkeyid: args.db["userid_masterkeyid"].clone(),
102				userid_password: args.db["userid_password"].clone(),
103				userid_origin: args.db["userid_origin"].clone(),
104				userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
105				userid_suspended: args.db["userid_suspended"].clone(),
106				userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
107			},
108		}))
109	}
110
111	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
112}
113
114impl Service {
115	/// Returns true/false based on whether the recipient/receiving user has
116	/// blocked the sender
117	pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool {
118		self.services
119			.account_data
120			.get_global(recipient_user, GlobalAccountDataEventType::IgnoredUserList)
121			.await
122			.is_ok_and(|ignored: IgnoredUserListEvent| {
123				ignored
124					.content
125					.ignored_users
126					.keys()
127					.any(|blocked_user| blocked_user == sender_user)
128			})
129	}
130
131	/// MSC4380: `m.invite_permission_config.default_action == "block"`.
132	pub async fn invites_blocked(&self, user_id: &UserId) -> bool {
133		self.services
134			.account_data
135			.get_global(user_id, GlobalAccountDataEventType::InvitePermissionConfig)
136			.await
137			.is_ok_and(|event: InvitePermissionConfigEvent| {
138				matches!(event.content.default_action, Some(InvitePermissionAction::Block))
139			})
140	}
141
142	/// Create a new user account on this homeserver.
143	///
144	/// User origin is by default "password" (meaning that it will login using
145	/// its user_id/password). Users with other origins (currently only "ldap"
146	/// is available) have special login processes.
147	#[inline]
148	pub async fn create(
149		&self,
150		user_id: &UserId,
151		password: Option<&str>,
152		origin: Option<&str>,
153	) -> Result {
154		let origin = origin.unwrap_or("password");
155		self.db.userid_origin.insert(user_id, origin);
156		self.set_password(user_id, password).await
157	}
158
159	/// Deactivate account
160	pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
161		// Revoke any SSO authorizations
162		self.services
163			.oauth
164			.revoke_user_tokens(user_id)
165			.await;
166
167		// Remove all associated devices
168		self.all_device_ids(user_id)
169			.for_each(|device_id| self.remove_device(user_id, device_id))
170			.await;
171
172		// Set the password to "" to indicate a deactivated account. Hashes will never
173		// result in an empty string, so the user will not be able to log in again.
174		// Systems like changing the password without logging in should check if the
175		// account is deactivated.
176		self.set_password(user_id, None).await?;
177
178		// TODO: Unhook 3PID
179		Ok(())
180	}
181
182	/// Check if a user has an account on this homeserver.
183	#[inline]
184	pub async fn exists(&self, user_id: &UserId) -> bool {
185		self.db.userid_password.get(user_id).await.is_ok()
186	}
187
188	/// Check if account is deactivated
189	pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
190		self.db
191			.userid_password
192			.get(user_id)
193			.map_ok(|val| val.is_empty())
194			.map_err(|_| err!(Request(NotFound("User does not exist."))))
195			.await
196	}
197
198	/// Check if account is active, infallible
199	pub async fn is_active(&self, user_id: &UserId) -> bool {
200		!self.is_deactivated(user_id).await.unwrap_or(true)
201	}
202
203	/// Check if account is active, infallible
204	pub async fn is_active_local(&self, user_id: &UserId) -> bool {
205		self.services.globals.user_is_local(user_id) && self.is_active(user_id).await
206	}
207
208	/// MSC3823: account is suspended (read-mostly mode, sessions retained).
209	pub async fn is_suspended(&self, user_id: &UserId) -> bool {
210		self.db
211			.userid_suspended
212			.get(user_id)
213			.await
214			.is_ok()
215	}
216
217	/// MSC3939: account is locked (401 + soft_logout, sessions retained).
218	pub async fn is_locked(&self, user_id: &UserId) -> bool {
219		self.db.userid_locked.get(user_id).await.is_ok()
220	}
221
222	/// MSC3823: forensic record for the active suspension, if any.
223	pub async fn get_suspension(&self, user_id: &UserId) -> Option<Moderation> {
224		self.db
225			.userid_suspended
226			.get(user_id)
227			.await
228			.deserialized::<Json<_>>()
229			.map(|Json(m)| m)
230			.ok()
231	}
232
233	/// MSC3939: forensic record for the active lock, if any.
234	pub async fn get_lock(&self, user_id: &UserId) -> Option<Moderation> {
235		self.db
236			.userid_locked
237			.get(user_id)
238			.await
239			.deserialized::<Json<_>>()
240			.map(|Json(m)| m)
241			.ok()
242	}
243
244	pub fn set_suspended(&self, user_id: &UserId, by: &UserId) {
245		let entry = Moderation {
246			when: MilliSecondsSinceUnixEpoch::now(),
247			by: by.to_owned(),
248		};
249
250		self.db
251			.userid_suspended
252			.raw_put(user_id, Json(entry));
253	}
254
255	pub fn clear_suspended(&self, user_id: &UserId) { self.db.userid_suspended.remove(user_id); }
256
257	pub fn set_locked(&self, user_id: &UserId, by: &UserId) {
258		let entry = Moderation {
259			when: MilliSecondsSinceUnixEpoch::now(),
260			by: by.to_owned(),
261		};
262
263		self.db
264			.userid_locked
265			.raw_put(user_id, Json(entry));
266	}
267
268	pub fn clear_locked(&self, user_id: &UserId) { self.db.userid_locked.remove(user_id); }
269
270	/// Returns the number of users registered on this server.
271	#[inline]
272	pub async fn count(&self) -> usize { self.db.userid_password.count().await }
273
274	/// Returns an iterator over all users on this homeserver.
275	pub fn stream(&self) -> impl Stream<Item = &UserId> + Send {
276		self.db.userid_password.keys().ignore_err()
277	}
278
279	/// Returns a list of local users as list of usernames.
280	///
281	/// A user account is considered `local` if the length of it's password is
282	/// greater then zero.
283	pub fn list_local_users(&self) -> impl Stream<Item = &UserId> + Send + '_ {
284		self.db
285			.userid_password
286			.stream()
287			.ignore_err()
288			.ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u))
289	}
290
291	/// Returns the origin of the user (password/LDAP/...).
292	pub async fn origin(&self, user_id: &UserId) -> Result<String> {
293		self.db
294			.userid_origin
295			.get(user_id)
296			.await
297			.deserialized()
298	}
299
300	/// Returns whether the user has a password. Disabled accounts and
301	/// registrations setting a sentinel password will return false here.
302	pub async fn has_password(&self, user_id: &UserId) -> Result<bool> {
303		self.password_hash(user_id)
304			.map_ok(|value| value != PASSWORD_DISABLED && value != PASSWORD_SENTINEL)
305			.await
306	}
307
308	/// Returns the password hash for the given user.
309	pub async fn password_hash(&self, user_id: &UserId) -> Result<String> {
310		self.db
311			.userid_password
312			.get(user_id)
313			.await
314			.deserialized()
315	}
316
317	/// Hash and set the user's password to the Argon2 hash
318	pub async fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result {
319		// Cannot change the password of a LDAP user. There are two special cases :
320		// - a `None` password can be used to deactivate a LDAP user
321		// - a "*" password is used as the default password of an active LDAP user
322		//
323		// The above now applies to all non-password origin users by default unless an
324		// exception is made for that origin in the condition below. Note that users
325		// with no origin are also password-origin users.
326		let allowed_origins = ["password", "sso"];
327		if password.is_some() && password != Some(PASSWORD_SENTINEL) {
328			let origin = self.origin(user_id).await;
329			let origin = origin.as_deref().unwrap_or("password");
330
331			if !allowed_origins.iter().any(is_equal_to!(&origin)) {
332				return Err!(Request(InvalidParam(
333					"Cannot change password of an {origin:?} user."
334				)));
335			}
336		}
337
338		match password.map(utils::hash::password) {
339			| None => {
340				self.db
341					.userid_password
342					.insert(user_id, PASSWORD_DISABLED);
343			},
344			| Some(Ok(_)) if password == Some(PASSWORD_SENTINEL) => {
345				self.db
346					.userid_password
347					.insert(user_id, PASSWORD_SENTINEL);
348			},
349			| Some(Ok(hash)) => {
350				self.db.userid_password.insert(user_id, hash);
351				self.db.userid_origin.insert(user_id, "password");
352			},
353			| Some(Err(e)) => {
354				return Err!(Request(InvalidParam(
355					"Password does not meet the requirements: {e}"
356				)));
357			},
358		}
359
360		Ok(())
361	}
362
363	/// Creates a new sync filter. Returns the filter id.
364	#[must_use]
365	pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
366		let filter_id = utils::random_string(4);
367
368		let key = (user_id, &filter_id);
369		self.db.userfilterid_filter.put(key, Json(filter));
370
371		filter_id
372	}
373
374	pub async fn get_filter(
375		&self,
376		user_id: &UserId,
377		filter_id: &str,
378	) -> Result<FilterDefinition> {
379		let key = (user_id, filter_id);
380		self.db
381			.userfilterid_filter
382			.qry(&key)
383			.await
384			.deserialized()
385	}
386
387	/// Creates an OpenID token, which can be used to prove that a user has
388	/// access to an account (primarily for integrations)
389	pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
390		use std::num::Saturating as Sat;
391
392		let expires_in = self.services.server.config.openid_token_ttl;
393		let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
394
395		let mut value = expires_at.0.to_be_bytes().to_vec();
396		value.extend_from_slice(user_id.as_bytes());
397
398		self.db
399			.openidtoken_expiresatuserid
400			.insert(token.as_bytes(), value.as_slice());
401
402		Ok(expires_in)
403	}
404
405	/// Find out which user an OpenID access token belongs to.
406	pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
407		let Ok(value) = self
408			.db
409			.openidtoken_expiresatuserid
410			.get(token)
411			.await
412		else {
413			return Err!(Request(Unauthorized("OpenID token is unrecognised")));
414		};
415
416		let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
417		let expires_at =
418			u64::from_be_bytes(expires_at_bytes.try_into().map_err(|e| {
419				err!(Database("expires_at in openid_userid is invalid u64. {e}"))
420			})?);
421
422		if expires_at < utils::millis_since_unix_epoch() {
423			debug_warn!("OpenID token is expired, removing");
424			self.db
425				.openidtoken_expiresatuserid
426				.remove(token.as_bytes());
427
428			return Err!(Request(Unauthorized("OpenID token is expired")));
429		}
430
431		let user_string = utils::string_from_bytes(user_bytes)
432			.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
433
434		OwnedUserId::try_from(user_string)
435			.map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
436	}
437
438	/// Creates a short-lived login token, which can be used to log in using the
439	/// `m.login.token` mechanism.
440	#[must_use]
441	pub fn create_login_token(&self, user_id: &UserId, token: &str) -> u64 {
442		use std::num::Saturating as Sat;
443
444		let expires_in = self.services.server.config.login_token_ttl;
445		let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in);
446
447		let value = (expires_at.0, user_id);
448		self.db
449			.logintoken_expiresatuserid
450			.raw_put(token, value);
451
452		expires_in
453	}
454
455	/// Verify a login token is valid and return its owner without consuming it.
456	/// Unlike `find_from_login_token`, the token remains in the database
457	/// after this call and can still be consumed later.
458	pub async fn peek_login_token(&self, token: &str) -> Result<OwnedUserId> {
459		let Ok(value) = self
460			.db
461			.logintoken_expiresatuserid
462			.get(token)
463			.await
464		else {
465			return Err!(Request(Forbidden("Login token is unrecognised")));
466		};
467		let (expires_at, user_id): (u64, OwnedUserId) = value.deserialized()?;
468
469		if expires_at < utils::millis_since_unix_epoch() {
470			trace!(?user_id, ?token, "Removing expired login token");
471			self.db.logintoken_expiresatuserid.remove(token);
472			return Err!(Request(Forbidden("Login token is expired")));
473		}
474
475		Ok(user_id)
476	}
477
478	/// Find out which user a login token belongs to.
479	/// Removes the token to prevent double-use attacks.
480	pub async fn find_from_login_token(&self, token: &str) -> Result<OwnedUserId> {
481		let Ok(value) = self
482			.db
483			.logintoken_expiresatuserid
484			.get(token)
485			.await
486		else {
487			return Err!(Request(Forbidden("Login token is unrecognised")));
488		};
489		let (expires_at, user_id): (u64, OwnedUserId) = value.deserialized()?;
490
491		if expires_at < utils::millis_since_unix_epoch() {
492			trace!(?user_id, ?token, "Removing expired login token");
493
494			self.db.logintoken_expiresatuserid.remove(token);
495
496			return Err!(Request(Forbidden("Login token is expired")));
497		}
498
499		self.db.logintoken_expiresatuserid.remove(token);
500
501		Ok(user_id)
502	}
503
	/// Stub compiled when the `ldap` feature is disabled; always fails with a
	/// `FeatureDisabled("ldap")` error.
	#[cfg(not(feature = "ldap"))]
	#[expect(clippy::unused_async)]
	pub async fn search_ldap(&self, _user_id: &UserId) -> Result<Vec<(String, bool)>> {
		Err!(FeatureDisabled("ldap"))
	}
509
	/// Stub compiled when the `ldap` feature is disabled; always fails with a
	/// `FeatureDisabled("ldap")` error.
	#[cfg(not(feature = "ldap"))]
	#[expect(clippy::unused_async)]
	pub async fn auth_ldap(&self, _user_dn: &str, _password: &str) -> Result {
		Err!(FeatureDisabled("ldap"))
	}
515}