Skip to main content

tuwunel_service/uiaa/
mod.rs

1use std::{
2	collections::BTreeMap,
3	sync::{Arc, RwLock},
4};
5
6use futures::{TryStreamExt, pin_mut};
7use ruma::{
8	CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
9	api::{
10		client::uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier},
11		error::{ErrorKind, StandardErrorBody},
12	},
13};
14use tuwunel_core::{
15	Err, Result, err, error, extract, implement,
16	utils::{self, BoolExt, hash, string::EMPTY},
17};
18use tuwunel_database::{Deserialized, Json, Map};
19
20use crate::users::PASSWORD_SENTINEL;
21
/// User-Interactive Authentication (UIAA) service state.
pub struct Service {
	/// In-memory map of the request bodies recorded by `set_uiaa_request`,
	/// keyed by `(user id, device id, session id)`.
	userdevicesessionid_uiaarequest: RwLock<RequestMap>,
	/// Database handles used by this service.
	db: Data,
	/// Shared handle to the other services (users, globals, oauth, ...).
	services: Arc<crate::services::OnceServices>,
}
27
/// Database maps owned by the UIAA service.
struct Data {
	/// Stores the serialized `UiaaInfo` of each session, keyed by
	/// `(user id, device id, session id)`.
	userdevicesessionid_uiaainfo: Arc<Map>,
}
31
/// In-memory store of the JSON request body recorded for each UIAA session.
type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>;
/// `(user id, device id, session id)` triple identifying one UIAA session.
type RequestKey = (OwnedUserId, OwnedDeviceId, String);
34
/// Length passed to `utils::random_string` when `try_auth` has to generate a
/// new session id.
pub const SESSION_ID_LENGTH: usize = 32;
36
37impl crate::Service for Service {
38	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
39		Ok(Arc::new(Self {
40			userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
41			db: Data {
42				userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
43			},
44			services: args.services.clone(),
45		}))
46	}
47
48	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
49}
50
51/// Creates a new Uiaa session. Make sure the session token is unique.
52#[implement(Service)]
53pub fn create(
54	&self,
55	user_id: &UserId,
56	device_id: &DeviceId,
57	uiaainfo: &UiaaInfo,
58	json_body: &CanonicalJsonValue,
59) {
60	// TODO: better session error handling (why is uiaainfo.session optional in
61	// ruma?)
62	let session = uiaainfo
63		.session
64		.as_ref()
65		.expect("session should be set");
66
67	self.set_uiaa_request(user_id, device_id, session, json_body);
68
69	self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo));
70}
71
72#[implement(Service)]
73#[allow(clippy::useless_let_if_seq)]
74pub async fn try_auth(
75	&self,
76	user_id: &UserId,
77	device_id: &DeviceId,
78	auth: &AuthData,
79	uiaainfo: &UiaaInfo,
80) -> Result<(bool, UiaaInfo)> {
81	let mut uiaainfo = if let Some(session) = auth.session() {
82		self.get_uiaa_session(user_id, device_id, session)
83			.await?
84	} else {
85		uiaainfo.clone()
86	};
87
88	if uiaainfo.session.is_none() {
89		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
90	}
91
92	match auth {
93		// Find out what the user completed
94		| AuthData::Password(Password { identifier, password, user, .. }) => {
95			let username = extract!(identifier, x in Some(UserIdentifier::Matrix(ruma::api::client::uiaa::MatrixUserIdentifier { user: x, .. })))
96				.or_else(|| cfg!(feature = "element_hacks").and(user.as_ref()))
97				.ok_or(err!(Request(Unrecognized("Identifier type not recognized."))))?;
98
99			let user_id_from_username = UserId::parse_with_server_name(
100				username.clone(),
101				self.services.globals.server_name(),
102			)
103			.map_err(|_| err!(Request(InvalidParam("User ID is invalid."))))?;
104
105			// Check if the access token being used matches the credentials used for UIAA
106			if user_id.localpart() != user_id_from_username.localpart() {
107				return Err!(Request(Forbidden("User ID and access token mismatch.")));
108			}
109
110			// Check if password is correct
111			let user_id = user_id_from_username;
112			let mut password_verified = false;
113			let mut password_sentinel = false;
114
115			// First try local password hash verification
116			if let Ok(hash) = self.services.users.password_hash(&user_id).await {
117				password_sentinel = hash == PASSWORD_SENTINEL;
118				password_verified = hash::verify_password(password, &hash).is_ok();
119			}
120
121			// If local password verification failed, try LDAP authentication
122			#[cfg(feature = "ldap")]
123			if !password_verified && self.services.server.config.ldap.enable {
124				// Search for user in LDAP to get their DN
125				if let Ok(dns) = self.services.users.search_ldap(&user_id).await
126					&& let Some((user_dn, _is_admin)) = dns.first()
127				{
128					// Try to authenticate with LDAP
129					password_verified = self
130						.services
131						.users
132						.auth_ldap(user_dn, password)
133						.await
134						.is_ok();
135				}
136			}
137
138			// For SSO users that have never set a password, allow.
139			if !password_verified
140				&& password_sentinel
141				&& self
142					.services
143					.oauth
144					.sessions
145					.exists_for_user(&user_id)
146					.await
147			{
148				return Ok((true, uiaainfo));
149			}
150
151			if !password_verified {
152				uiaainfo.auth_error = Some(StandardErrorBody {
153					kind: ErrorKind::forbidden(),
154					message: "Invalid username or password.".to_owned(),
155				});
156
157				return Ok((false, uiaainfo));
158			}
159
160			// Password was correct! Let's add it to `completed`
161			uiaainfo.completed.push(AuthType::Password);
162		},
163		| AuthData::RegistrationToken(t) => {
164			let token = t.token.trim();
165			if self
166				.services
167				.registration_tokens
168				.try_consume(token)
169				.await
170				.is_ok()
171			{
172				uiaainfo
173					.completed
174					.push(AuthType::RegistrationToken);
175			} else {
176				uiaainfo.auth_error = Some(StandardErrorBody {
177					kind: ErrorKind::forbidden(),
178					message: "Invalid registration token.".to_owned(),
179				});
180
181				return Ok((false, uiaainfo));
182			}
183		},
184		| AuthData::FallbackAcknowledgement(_session) => {
185			// A fallback acknowledgement is a session re-poll. The fallback
186			// web handler (e.g. the SSO callback) is what records completion.
187		},
188		| AuthData::OAuth(_) => {
189			// MSC4312: OAuth cross-signing reset uses SSO re-authentication.
190			// If a bypass was granted via SSO re-auth, mark OAuth as completed.
191			if !uiaainfo.completed.contains(&AuthType::OAuth) {
192				if self
193					.services
194					.users
195					.can_replace_cross_signing_keys(user_id)
196					.await
197				{
198					uiaainfo.completed.push(AuthType::OAuth);
199				} else {
200					uiaainfo.auth_error = Some(StandardErrorBody {
201						kind: ErrorKind::forbidden(),
202						message: "OAuth cross-signing reset not approved for this session."
203							.to_owned(),
204					});
205
206					return Ok((false, uiaainfo));
207				}
208			}
209		},
210		| AuthData::Dummy(_) => {
211			uiaainfo.completed.push(AuthType::Dummy);
212		},
213		| auth => error!("AuthData type not supported: {auth:?}"),
214	}
215
216	// Check if a flow now succeeds
217	let mut completed = false;
218	'flows: for flow in &mut uiaainfo.flows {
219		for stage in &flow.stages {
220			if !uiaainfo.completed.contains(stage) {
221				continue 'flows;
222			}
223		}
224		// We didn't break, so this flow succeeded!
225		completed = true;
226	}
227
228	let session = uiaainfo
229		.session
230		.as_ref()
231		.expect("session is always set");
232
233	if !completed {
234		self.update_uiaa_session(user_id, device_id, session, Some(&uiaainfo));
235
236		return Ok((false, uiaainfo));
237	}
238
239	// UIAA was successful! Remove this session and return true
240	self.update_uiaa_session(user_id, device_id, session, None);
241
242	Ok((true, uiaainfo))
243}
244
245#[implement(Service)]
246fn set_uiaa_request(
247	&self,
248	user_id: &UserId,
249	device_id: &DeviceId,
250	session: &str,
251	request: &CanonicalJsonValue,
252) {
253	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
254
255	self.userdevicesessionid_uiaarequest
256		.write()
257		.expect("locked for writing")
258		.insert(key, request.to_owned());
259}
260
261#[implement(Service)]
262pub fn get_uiaa_request(
263	&self,
264	user_id: &UserId,
265	device_id: Option<&DeviceId>,
266	session: &str,
267) -> Option<CanonicalJsonValue> {
268	let device_id = device_id.unwrap_or_else(|| EMPTY.into());
269	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
270
271	self.userdevicesessionid_uiaarequest
272		.read()
273		.expect("locked for reading")
274		.get(&key)
275		.cloned()
276}
277
278#[implement(Service)]
279pub fn update_uiaa_session(
280	&self,
281	user_id: &UserId,
282	device_id: &DeviceId,
283	session: &str,
284	uiaainfo: Option<&UiaaInfo>,
285) {
286	let key = (user_id, device_id, session);
287
288	if let Some(uiaainfo) = uiaainfo {
289		self.db
290			.userdevicesessionid_uiaainfo
291			.put(key, Json(uiaainfo));
292	} else {
293		self.db.userdevicesessionid_uiaainfo.del(key);
294	}
295}
296
297#[implement(Service)]
298async fn get_uiaa_session(
299	&self,
300	user_id: &UserId,
301	device_id: &DeviceId,
302	session: &str,
303) -> Result<UiaaInfo> {
304	let key = (user_id, device_id, session);
305
306	self.db
307		.userdevicesessionid_uiaainfo
308		.qry(&key)
309		.await
310		.deserialized()
311		.map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
312}
313
314#[implement(Service)]
315pub async fn get_uiaa_session_by_session_id(
316	&self,
317	session_id: &str,
318) -> Option<(OwnedUserId, OwnedDeviceId, UiaaInfo)> {
319	// Iterate over keys only (fastest way without a secondary index)
320	let stream = self
321		.db
322		.userdevicesessionid_uiaainfo
323		.keys::<(OwnedUserId, OwnedDeviceId, String)>();
324
325	pin_mut!(stream);
326	while let Ok(Some((user_id, device_id, session))) = stream.try_next().await {
327		if session == session_id {
328			// Found the key, now fetch the actual UiaaInfo
329			if let Ok(uiaainfo) = self
330				.get_uiaa_session(&user_id, &device_id, session_id)
331				.await
332			{
333				return Some((user_id, device_id, uiaainfo));
334			}
335		}
336	}
337
338	None
339}