1use std::{
2 collections::BTreeMap,
3 ops::ControlFlow,
4 sync::{Arc, RwLock},
5};
6
7use futures::{TryStreamExt, pin_mut};
8use ruma::{
9 CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
10 api::{
11 client::uiaa::{
12 AuthData, AuthType, EmailIdentity, Password, ThirdpartyIdCredentials, UiaaInfo,
13 UserIdentifier,
14 },
15 error::{ErrorKind, StandardErrorBody},
16 },
17};
18use tuwunel_core::{
19 Err, Result, err, error, extract, implement,
20 utils::{self, BoolExt, hash, string::EMPTY},
21};
22use tuwunel_database::{Deserialized, Json, Map};
23
24use crate::users::PASSWORD_SENTINEL;
25
/// User-interactive authentication (UIAA) session service.
///
/// Sessions are identified by a (user id, device id, session id) triple.
pub struct Service {
	/// In-memory copy of the JSON request body saved when a session is
	/// created.
	userdevicesessionid_uiaarequest: RwLock<RequestMap>,
	/// In-memory third-party ID credentials recorded after an email identity
	/// stage has been validated; an entry is removed when it is taken.
	userdevicesessionid_threepid: RwLock<ThreepidMap>,
	/// Database-backed session storage.
	db: Data,
	/// Handle used to reach the other services.
	services: Arc<crate::services::OnceServices>,
}
32
/// Database maps owned by this service.
struct Data {
	/// `UiaaInfo` stored as JSON under a (user id, device id, session id) key.
	userdevicesessionid_uiaainfo: Arc<Map>,
}
36
/// Saved request bodies, keyed by [`RequestKey`].
type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>;
/// Third-party ID credentials, keyed by [`RequestKey`].
type ThreepidMap = BTreeMap<RequestKey, ThirdpartyIdCredentials>;
/// (user id, device id, session id) triple identifying one UIAA session.
type RequestKey = (OwnedUserId, OwnedDeviceId, String);
40
/// Length passed to `utils::random_string` when a new session id is
/// generated.
pub const SESSION_ID_LENGTH: usize = 32;
42
43impl crate::Service for Service {
44 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
45 Ok(Arc::new(Self {
46 userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
47 userdevicesessionid_threepid: RwLock::new(ThreepidMap::new()),
48 db: Data {
49 userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
50 },
51 services: args.services.clone(),
52 }))
53 }
54
55 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
56}
57
58#[implement(Service)]
60pub fn create(
61 &self,
62 user_id: &UserId,
63 device_id: &DeviceId,
64 uiaainfo: &UiaaInfo,
65 json_body: &CanonicalJsonValue,
66) {
67 let session = uiaainfo
70 .session
71 .as_ref()
72 .expect("session should be set");
73
74 self.set_uiaa_request(user_id, device_id, session, json_body);
75
76 self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo));
77}
78
79#[implement(Service)]
80pub async fn try_auth(
81 &self,
82 user_id: &UserId,
83 device_id: &DeviceId,
84 auth: &AuthData,
85 uiaainfo: &UiaaInfo,
86) -> Result<(bool, UiaaInfo)> {
87 let mut uiaainfo = if let Some(session) = auth.session() {
88 self.get_uiaa_session(user_id, device_id, session)
89 .await?
90 } else {
91 uiaainfo.clone()
92 };
93
94 if uiaainfo.session.is_none() {
95 uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
96 }
97
98 match auth {
99 | AuthData::Password(password) => {
101 if let ControlFlow::Break(authed) = self
102 .verify_password(user_id, &mut uiaainfo, password)
103 .await?
104 {
105 return Ok((authed, uiaainfo));
106 }
107 },
108 | AuthData::RegistrationToken(t) => {
109 let token = t.token.trim();
110 if self
111 .services
112 .registration_tokens
113 .try_consume(token)
114 .await
115 .is_ok()
116 {
117 uiaainfo
118 .completed
119 .push(AuthType::RegistrationToken);
120 } else {
121 uiaainfo.auth_error = Some(StandardErrorBody {
122 kind: ErrorKind::forbidden(),
123 message: "Invalid registration token.".to_owned(),
124 });
125
126 return Ok((false, uiaainfo));
127 }
128 },
129 | AuthData::FallbackAcknowledgement(_session) => {
130 },
133 | AuthData::OAuth(_) => {
134 if !uiaainfo.completed.contains(&AuthType::OAuth) {
137 if self
138 .services
139 .users
140 .can_replace_cross_signing_keys(user_id)
141 .await
142 {
143 uiaainfo.completed.push(AuthType::OAuth);
144 } else {
145 uiaainfo.auth_error = Some(StandardErrorBody {
146 kind: ErrorKind::forbidden(),
147 message: "OAuth cross-signing reset not approved for this session."
148 .to_owned(),
149 });
150
151 return Ok((false, uiaainfo));
152 }
153 }
154 },
155 | AuthData::Dummy(_) => {
156 uiaainfo.completed.push(AuthType::Dummy);
157 },
158 | AuthData::Terms(_) => {
159 uiaainfo.completed.push(AuthType::Terms);
161 },
162 | AuthData::EmailIdentity(EmailIdentity { thirdparty_id_creds, .. }) => {
163 let validated = self
165 .services
166 .threepid
167 .session_validated(
168 thirdparty_id_creds.sid.as_str(),
169 thirdparty_id_creds.client_secret.as_str(),
170 )
171 .await;
172
173 if !validated {
174 uiaainfo.auth_error = Some(StandardErrorBody {
175 kind: ErrorKind::forbidden(),
176 message: "Email address has not been validated.".to_owned(),
177 });
178
179 return Ok((false, uiaainfo));
180 }
181
182 uiaainfo.completed.push(AuthType::EmailIdentity);
183
184 if let Some(session) = uiaainfo.session.as_deref() {
187 self.set_uiaa_threepid(user_id, device_id, session, thirdparty_id_creds);
188 }
189 },
190 | auth => error!("AuthData type not supported: {auth:?}"),
191 }
192
193 let mut completed = false;
195 'flows: for flow in &mut uiaainfo.flows {
196 for stage in &flow.stages {
197 if !uiaainfo.completed.contains(stage) {
198 continue 'flows;
199 }
200 }
201 completed = true;
203 }
204
205 let session = uiaainfo
206 .session
207 .as_ref()
208 .expect("session is always set");
209
210 if !completed {
211 self.update_uiaa_session(user_id, device_id, session, Some(&uiaainfo));
212
213 return Ok((false, uiaainfo));
214 }
215
216 self.update_uiaa_session(user_id, device_id, session, None);
218
219 Ok((true, uiaainfo))
220}
221
222#[implement(Service)]
223#[allow(clippy::useless_let_if_seq)]
224async fn verify_password(
225 &self,
226 user_id: &UserId,
227 uiaainfo: &mut UiaaInfo,
228 password: &Password,
229) -> Result<ControlFlow<bool>> {
230 let Password { identifier, password, user, .. } = password;
231
232 let username = extract!(identifier, x in Some(UserIdentifier::Matrix(ruma::api::client::uiaa::MatrixUserIdentifier { user: x, .. })))
233 .or_else(|| cfg!(feature = "element_hacks").and(user.as_ref()))
234 .ok_or(err!(Request(Unrecognized("Identifier type not recognized."))))?;
235
236 let user_id_from_username =
237 UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
238 .map_err(|_| err!(Request(InvalidParam("User ID is invalid."))))?;
239
240 if user_id.localpart() != user_id_from_username.localpart() {
242 return Err!(Request(Forbidden("User ID and access token mismatch.")));
243 }
244
245 let user_id = user_id_from_username;
246 let mut password_verified = false;
247 let mut password_sentinel = false;
248
249 if let Ok(hash) = self.services.users.password_hash(&user_id).await {
251 password_sentinel = hash == PASSWORD_SENTINEL;
252 password_verified = hash::verify_password(password, &hash).is_ok();
253 }
254
255 #[cfg(feature = "ldap")]
258 if !password_verified
259 && self.services.server.config.ldap.enable
260 && self
261 .services
262 .users
263 .origin(&user_id)
264 .await
265 .is_ok_and(|origin| origin == "ldap")
266 && let Ok(dns) = self.services.users.search_ldap(&user_id).await
267 && let Some((user_dn, _is_admin)) = dns.first()
268 {
269 password_verified = self
270 .services
271 .users
272 .auth_ldap(user_dn, password)
273 .await
274 .is_ok();
275 }
276
277 if !password_verified
279 && password_sentinel
280 && self
281 .services
282 .oauth
283 .sessions
284 .exists_for_user(&user_id)
285 .await
286 {
287 return Ok(ControlFlow::Break(true));
288 }
289
290 if !password_verified {
291 uiaainfo.auth_error = Some(StandardErrorBody {
292 kind: ErrorKind::forbidden(),
293 message: "Invalid username or password.".to_owned(),
294 });
295
296 return Ok(ControlFlow::Break(false));
297 }
298
299 uiaainfo.completed.push(AuthType::Password);
300
301 Ok(ControlFlow::Continue(()))
302}
303
304#[implement(Service)]
305fn set_uiaa_request(
306 &self,
307 user_id: &UserId,
308 device_id: &DeviceId,
309 session: &str,
310 request: &CanonicalJsonValue,
311) {
312 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
313
314 self.userdevicesessionid_uiaarequest
315 .write()
316 .expect("locked for writing")
317 .insert(key, request.to_owned());
318}
319
320#[implement(Service)]
321pub fn get_uiaa_request(
322 &self,
323 user_id: &UserId,
324 device_id: Option<&DeviceId>,
325 session: &str,
326) -> Option<CanonicalJsonValue> {
327 let device_id = device_id.unwrap_or_else(|| EMPTY.into());
328 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
329
330 self.userdevicesessionid_uiaarequest
331 .read()
332 .expect("locked for reading")
333 .get(&key)
334 .cloned()
335}
336
337#[implement(Service)]
338fn set_uiaa_threepid(
339 &self,
340 user_id: &UserId,
341 device_id: &DeviceId,
342 session: &str,
343 creds: &ThirdpartyIdCredentials,
344) {
345 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
346
347 self.userdevicesessionid_threepid
348 .write()
349 .expect("locked for writing")
350 .insert(key, creds.to_owned());
351}
352
353#[implement(Service)]
354pub fn take_uiaa_threepid(
355 &self,
356 user_id: &UserId,
357 device_id: &DeviceId,
358 session: &str,
359) -> Option<ThirdpartyIdCredentials> {
360 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
361
362 self.userdevicesessionid_threepid
363 .write()
364 .expect("locked for writing")
365 .remove(&key)
366}
367
368#[implement(Service)]
369pub fn update_uiaa_session(
370 &self,
371 user_id: &UserId,
372 device_id: &DeviceId,
373 session: &str,
374 uiaainfo: Option<&UiaaInfo>,
375) {
376 let key = (user_id, device_id, session);
377
378 if let Some(uiaainfo) = uiaainfo {
379 self.db
380 .userdevicesessionid_uiaainfo
381 .put(key, Json(uiaainfo));
382 } else {
383 self.db.userdevicesessionid_uiaainfo.del(key);
384 }
385}
386
387#[implement(Service)]
388async fn get_uiaa_session(
389 &self,
390 user_id: &UserId,
391 device_id: &DeviceId,
392 session: &str,
393) -> Result<UiaaInfo> {
394 let key = (user_id, device_id, session);
395
396 self.db
397 .userdevicesessionid_uiaainfo
398 .qry(&key)
399 .await
400 .deserialized()
401 .map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
402}
403
404#[implement(Service)]
405pub async fn get_uiaa_session_by_session_id(
406 &self,
407 session_id: &str,
408) -> Option<(OwnedUserId, OwnedDeviceId, UiaaInfo)> {
409 let stream = self
411 .db
412 .userdevicesessionid_uiaainfo
413 .keys::<(OwnedUserId, OwnedDeviceId, String)>();
414
415 pin_mut!(stream);
416 while let Ok(Some((user_id, device_id, session))) = stream.try_next().await {
417 if session == session_id {
418 if let Ok(uiaainfo) = self
420 .get_uiaa_session(&user_id, &device_id, session_id)
421 .await
422 {
423 return Some((user_id, device_id, uiaainfo));
424 }
425 }
426 }
427
428 None
429}