Skip to main content

tuwunel_service/uiaa/
mod.rs

1use std::{
2	collections::BTreeMap,
3	ops::ControlFlow,
4	sync::{Arc, RwLock},
5};
6
7use futures::{TryStreamExt, pin_mut};
8use ruma::{
9	CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
10	api::{
11		client::uiaa::{
12			AuthData, AuthType, EmailIdentity, Password, ThirdpartyIdCredentials, UiaaInfo,
13			UserIdentifier,
14		},
15		error::{ErrorKind, StandardErrorBody},
16	},
17};
18use tuwunel_core::{
19	Err, Result, err, error, extract, implement,
20	utils::{self, BoolExt, hash, string::EMPTY},
21};
22use tuwunel_database::{Deserialized, Json, Map};
23
24use crate::users::PASSWORD_SENTINEL;
25
26pub struct Service {
27	userdevicesessionid_uiaarequest: RwLock<RequestMap>,
28	userdevicesessionid_threepid: RwLock<ThreepidMap>,
29	db: Data,
30	services: Arc<crate::services::OnceServices>,
31}
32
33struct Data {
34	userdevicesessionid_uiaainfo: Arc<Map>,
35}
36
37type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>;
38type ThreepidMap = BTreeMap<RequestKey, ThirdpartyIdCredentials>;
39type RequestKey = (OwnedUserId, OwnedDeviceId, String);
40
/// Length handed to `utils::random_string` when a new UIAA session id is
/// generated.
pub const SESSION_ID_LENGTH: usize = 32;
42
43impl crate::Service for Service {
44	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
45		Ok(Arc::new(Self {
46			userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
47			userdevicesessionid_threepid: RwLock::new(ThreepidMap::new()),
48			db: Data {
49				userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
50			},
51			services: args.services.clone(),
52		}))
53	}
54
55	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
56}
57
58/// Creates a new Uiaa session. Make sure the session token is unique.
59#[implement(Service)]
60pub fn create(
61	&self,
62	user_id: &UserId,
63	device_id: &DeviceId,
64	uiaainfo: &UiaaInfo,
65	json_body: &CanonicalJsonValue,
66) {
67	// TODO: better session error handling (why is uiaainfo.session optional in
68	// ruma?)
69	let session = uiaainfo
70		.session
71		.as_ref()
72		.expect("session should be set");
73
74	self.set_uiaa_request(user_id, device_id, session, json_body);
75
76	self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo));
77}
78
79#[implement(Service)]
80pub async fn try_auth(
81	&self,
82	user_id: &UserId,
83	device_id: &DeviceId,
84	auth: &AuthData,
85	uiaainfo: &UiaaInfo,
86) -> Result<(bool, UiaaInfo)> {
87	let mut uiaainfo = if let Some(session) = auth.session() {
88		self.get_uiaa_session(user_id, device_id, session)
89			.await?
90	} else {
91		uiaainfo.clone()
92	};
93
94	if uiaainfo.session.is_none() {
95		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
96	}
97
98	match auth {
99		// Find out what the user completed
100		| AuthData::Password(password) => {
101			if let ControlFlow::Break(authed) = self
102				.verify_password(user_id, &mut uiaainfo, password)
103				.await?
104			{
105				return Ok((authed, uiaainfo));
106			}
107		},
108		| AuthData::RegistrationToken(t) => {
109			let token = t.token.trim();
110			if self
111				.services
112				.registration_tokens
113				.try_consume(token)
114				.await
115				.is_ok()
116			{
117				uiaainfo
118					.completed
119					.push(AuthType::RegistrationToken);
120			} else {
121				uiaainfo.auth_error = Some(StandardErrorBody {
122					kind: ErrorKind::forbidden(),
123					message: "Invalid registration token.".to_owned(),
124				});
125
126				return Ok((false, uiaainfo));
127			}
128		},
129		| AuthData::FallbackAcknowledgement(_session) => {
130			// A fallback acknowledgement is a session re-poll. The fallback
131			// web handler (e.g. the SSO callback) is what records completion.
132		},
133		| AuthData::OAuth(_) => {
134			// MSC4312: OAuth cross-signing reset uses SSO re-authentication.
135			// If a bypass was granted via SSO re-auth, mark OAuth as completed.
136			if !uiaainfo.completed.contains(&AuthType::OAuth) {
137				if self
138					.services
139					.users
140					.can_replace_cross_signing_keys(user_id)
141					.await
142				{
143					uiaainfo.completed.push(AuthType::OAuth);
144				} else {
145					uiaainfo.auth_error = Some(StandardErrorBody {
146						kind: ErrorKind::forbidden(),
147						message: "OAuth cross-signing reset not approved for this session."
148							.to_owned(),
149					});
150
151					return Ok((false, uiaainfo));
152				}
153			}
154		},
155		| AuthData::Dummy(_) => {
156			uiaainfo.completed.push(AuthType::Dummy);
157		},
158		| AuthData::Terms(_) => {
159			// MSC1692: an empty auth dict accepts every presented policy.
160			uiaainfo.completed.push(AuthType::Terms);
161		},
162		| AuthData::EmailIdentity(EmailIdentity { thirdparty_id_creds, .. }) => {
163			// A stray id_server is tolerated and id_access_token is never required.
164			let validated = self
165				.services
166				.threepid
167				.session_validated(
168					thirdparty_id_creds.sid.as_str(),
169					thirdparty_id_creds.client_secret.as_str(),
170				)
171				.await;
172
173			if !validated {
174				uiaainfo.auth_error = Some(StandardErrorBody {
175					kind: ErrorKind::forbidden(),
176					message: "Email address has not been validated.".to_owned(),
177				});
178
179				return Ok((false, uiaainfo));
180			}
181
182			uiaainfo.completed.push(AuthType::EmailIdentity);
183
184			// Retain the validated credentials so a later stage can still bind the email
185			// (MSC2263).
186			if let Some(session) = uiaainfo.session.as_deref() {
187				self.set_uiaa_threepid(user_id, device_id, session, thirdparty_id_creds);
188			}
189		},
190		| auth => error!("AuthData type not supported: {auth:?}"),
191	}
192
193	// Check if a flow now succeeds
194	let mut completed = false;
195	'flows: for flow in &mut uiaainfo.flows {
196		for stage in &flow.stages {
197			if !uiaainfo.completed.contains(stage) {
198				continue 'flows;
199			}
200		}
201		// We didn't break, so this flow succeeded!
202		completed = true;
203	}
204
205	let session = uiaainfo
206		.session
207		.as_ref()
208		.expect("session is always set");
209
210	if !completed {
211		self.update_uiaa_session(user_id, device_id, session, Some(&uiaainfo));
212
213		return Ok((false, uiaainfo));
214	}
215
216	// UIAA was successful! Remove this session and return true
217	self.update_uiaa_session(user_id, device_id, session, None);
218
219	Ok((true, uiaainfo))
220}
221
222#[implement(Service)]
223#[allow(clippy::useless_let_if_seq)]
224async fn verify_password(
225	&self,
226	user_id: &UserId,
227	uiaainfo: &mut UiaaInfo,
228	password: &Password,
229) -> Result<ControlFlow<bool>> {
230	let Password { identifier, password, user, .. } = password;
231
232	let username = extract!(identifier, x in Some(UserIdentifier::Matrix(ruma::api::client::uiaa::MatrixUserIdentifier { user: x, .. })))
233		.or_else(|| cfg!(feature = "element_hacks").and(user.as_ref()))
234		.ok_or(err!(Request(Unrecognized("Identifier type not recognized."))))?;
235
236	let user_id_from_username =
237		UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
238			.map_err(|_| err!(Request(InvalidParam("User ID is invalid."))))?;
239
240	// Check if the access token being used matches the credentials used for UIAA
241	if user_id.localpart() != user_id_from_username.localpart() {
242		return Err!(Request(Forbidden("User ID and access token mismatch.")));
243	}
244
245	let user_id = user_id_from_username;
246	let mut password_verified = false;
247	let mut password_sentinel = false;
248
249	// First try local password hash verification
250	if let Ok(hash) = self.services.users.password_hash(&user_id).await {
251		password_sentinel = hash == PASSWORD_SENTINEL;
252		password_verified = hash::verify_password(password, &hash).is_ok();
253	}
254
255	// Only LDAP-origin accounts fall back to LDAP; others would trigger a
256	// directory-wide search.
257	#[cfg(feature = "ldap")]
258	if !password_verified
259		&& self.services.server.config.ldap.enable
260		&& self
261			.services
262			.users
263			.origin(&user_id)
264			.await
265			.is_ok_and(|origin| origin == "ldap")
266		&& let Ok(dns) = self.services.users.search_ldap(&user_id).await
267		&& let Some((user_dn, _is_admin)) = dns.first()
268	{
269		password_verified = self
270			.services
271			.users
272			.auth_ldap(user_dn, password)
273			.await
274			.is_ok();
275	}
276
277	// For SSO users that have never set a password, allow.
278	if !password_verified
279		&& password_sentinel
280		&& self
281			.services
282			.oauth
283			.sessions
284			.exists_for_user(&user_id)
285			.await
286	{
287		return Ok(ControlFlow::Break(true));
288	}
289
290	if !password_verified {
291		uiaainfo.auth_error = Some(StandardErrorBody {
292			kind: ErrorKind::forbidden(),
293			message: "Invalid username or password.".to_owned(),
294		});
295
296		return Ok(ControlFlow::Break(false));
297	}
298
299	uiaainfo.completed.push(AuthType::Password);
300
301	Ok(ControlFlow::Continue(()))
302}
303
304#[implement(Service)]
305fn set_uiaa_request(
306	&self,
307	user_id: &UserId,
308	device_id: &DeviceId,
309	session: &str,
310	request: &CanonicalJsonValue,
311) {
312	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
313
314	self.userdevicesessionid_uiaarequest
315		.write()
316		.expect("locked for writing")
317		.insert(key, request.to_owned());
318}
319
320#[implement(Service)]
321pub fn get_uiaa_request(
322	&self,
323	user_id: &UserId,
324	device_id: Option<&DeviceId>,
325	session: &str,
326) -> Option<CanonicalJsonValue> {
327	let device_id = device_id.unwrap_or_else(|| EMPTY.into());
328	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
329
330	self.userdevicesessionid_uiaarequest
331		.read()
332		.expect("locked for reading")
333		.get(&key)
334		.cloned()
335}
336
337#[implement(Service)]
338fn set_uiaa_threepid(
339	&self,
340	user_id: &UserId,
341	device_id: &DeviceId,
342	session: &str,
343	creds: &ThirdpartyIdCredentials,
344) {
345	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
346
347	self.userdevicesessionid_threepid
348		.write()
349		.expect("locked for writing")
350		.insert(key, creds.to_owned());
351}
352
353#[implement(Service)]
354pub fn take_uiaa_threepid(
355	&self,
356	user_id: &UserId,
357	device_id: &DeviceId,
358	session: &str,
359) -> Option<ThirdpartyIdCredentials> {
360	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
361
362	self.userdevicesessionid_threepid
363		.write()
364		.expect("locked for writing")
365		.remove(&key)
366}
367
368#[implement(Service)]
369pub fn update_uiaa_session(
370	&self,
371	user_id: &UserId,
372	device_id: &DeviceId,
373	session: &str,
374	uiaainfo: Option<&UiaaInfo>,
375) {
376	let key = (user_id, device_id, session);
377
378	if let Some(uiaainfo) = uiaainfo {
379		self.db
380			.userdevicesessionid_uiaainfo
381			.put(key, Json(uiaainfo));
382	} else {
383		self.db.userdevicesessionid_uiaainfo.del(key);
384	}
385}
386
387#[implement(Service)]
388async fn get_uiaa_session(
389	&self,
390	user_id: &UserId,
391	device_id: &DeviceId,
392	session: &str,
393) -> Result<UiaaInfo> {
394	let key = (user_id, device_id, session);
395
396	self.db
397		.userdevicesessionid_uiaainfo
398		.qry(&key)
399		.await
400		.deserialized()
401		.map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
402}
403
404#[implement(Service)]
405pub async fn get_uiaa_session_by_session_id(
406	&self,
407	session_id: &str,
408) -> Option<(OwnedUserId, OwnedDeviceId, UiaaInfo)> {
409	// Iterate over keys only (fastest way without a secondary index)
410	let stream = self
411		.db
412		.userdevicesessionid_uiaainfo
413		.keys::<(OwnedUserId, OwnedDeviceId, String)>();
414
415	pin_mut!(stream);
416	while let Ok(Some((user_id, device_id, session))) = stream.try_next().await {
417		if session == session_id {
418			// Found the key, now fetch the actual UiaaInfo
419			if let Ok(uiaainfo) = self
420				.get_uiaa_session(&user_id, &device_id, session_id)
421				.await
422			{
423				return Some((user_id, device_id, uiaainfo));
424			}
425		}
426	}
427
428	None
429}