1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use futures::{TryStreamExt, pin_mut};
7use ruma::{
8 CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
9 api::{
10 client::uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier},
11 error::{ErrorKind, StandardErrorBody},
12 },
13};
14use tuwunel_core::{
15 Err, Result, err, error, extract, implement,
16 utils::{self, BoolExt, hash, string::EMPTY},
17};
18use tuwunel_database::{Deserialized, Json, Map};
19
20use crate::users::PASSWORD_SENTINEL;
21
22pub struct Service {
23 userdevicesessionid_uiaarequest: RwLock<RequestMap>,
24 db: Data,
25 services: Arc<crate::services::OnceServices>,
26}
27
28struct Data {
29 userdevicesessionid_uiaainfo: Arc<Map>,
30}
31
32type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>;
33type RequestKey = (OwnedUserId, OwnedDeviceId, String);
34
35pub const SESSION_ID_LENGTH: usize = 32;
36
37impl crate::Service for Service {
38 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
39 Ok(Arc::new(Self {
40 userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
41 db: Data {
42 userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
43 },
44 services: args.services.clone(),
45 }))
46 }
47
48 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
49}
50
51#[implement(Service)]
53pub fn create(
54 &self,
55 user_id: &UserId,
56 device_id: &DeviceId,
57 uiaainfo: &UiaaInfo,
58 json_body: &CanonicalJsonValue,
59) {
60 let session = uiaainfo
63 .session
64 .as_ref()
65 .expect("session should be set");
66
67 self.set_uiaa_request(user_id, device_id, session, json_body);
68
69 self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo));
70}
71
72#[implement(Service)]
73#[allow(clippy::useless_let_if_seq)]
74pub async fn try_auth(
75 &self,
76 user_id: &UserId,
77 device_id: &DeviceId,
78 auth: &AuthData,
79 uiaainfo: &UiaaInfo,
80) -> Result<(bool, UiaaInfo)> {
81 let mut uiaainfo = if let Some(session) = auth.session() {
82 self.get_uiaa_session(user_id, device_id, session)
83 .await?
84 } else {
85 uiaainfo.clone()
86 };
87
88 if uiaainfo.session.is_none() {
89 uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
90 }
91
92 match auth {
93 | AuthData::Password(Password { identifier, password, user, .. }) => {
95 let username = extract!(identifier, x in Some(UserIdentifier::Matrix(ruma::api::client::uiaa::MatrixUserIdentifier { user: x, .. })))
96 .or_else(|| cfg!(feature = "element_hacks").and(user.as_ref()))
97 .ok_or(err!(Request(Unrecognized("Identifier type not recognized."))))?;
98
99 let user_id_from_username = UserId::parse_with_server_name(
100 username.clone(),
101 self.services.globals.server_name(),
102 )
103 .map_err(|_| err!(Request(InvalidParam("User ID is invalid."))))?;
104
105 if user_id.localpart() != user_id_from_username.localpart() {
107 return Err!(Request(Forbidden("User ID and access token mismatch.")));
108 }
109
110 let user_id = user_id_from_username;
112 let mut password_verified = false;
113 let mut password_sentinel = false;
114
115 if let Ok(hash) = self.services.users.password_hash(&user_id).await {
117 password_sentinel = hash == PASSWORD_SENTINEL;
118 password_verified = hash::verify_password(password, &hash).is_ok();
119 }
120
121 #[cfg(feature = "ldap")]
123 if !password_verified && self.services.server.config.ldap.enable {
124 if let Ok(dns) = self.services.users.search_ldap(&user_id).await
126 && let Some((user_dn, _is_admin)) = dns.first()
127 {
128 password_verified = self
130 .services
131 .users
132 .auth_ldap(user_dn, password)
133 .await
134 .is_ok();
135 }
136 }
137
138 if !password_verified
140 && password_sentinel
141 && self
142 .services
143 .oauth
144 .sessions
145 .exists_for_user(&user_id)
146 .await
147 {
148 return Ok((true, uiaainfo));
149 }
150
151 if !password_verified {
152 uiaainfo.auth_error = Some(StandardErrorBody {
153 kind: ErrorKind::forbidden(),
154 message: "Invalid username or password.".to_owned(),
155 });
156
157 return Ok((false, uiaainfo));
158 }
159
160 uiaainfo.completed.push(AuthType::Password);
162 },
163 | AuthData::RegistrationToken(t) => {
164 let token = t.token.trim();
165 if self
166 .services
167 .registration_tokens
168 .try_consume(token)
169 .await
170 .is_ok()
171 {
172 uiaainfo
173 .completed
174 .push(AuthType::RegistrationToken);
175 } else {
176 uiaainfo.auth_error = Some(StandardErrorBody {
177 kind: ErrorKind::forbidden(),
178 message: "Invalid registration token.".to_owned(),
179 });
180
181 return Ok((false, uiaainfo));
182 }
183 },
184 | AuthData::FallbackAcknowledgement(_session) => {
185 },
188 | AuthData::OAuth(_) => {
189 if !uiaainfo.completed.contains(&AuthType::OAuth) {
192 if self
193 .services
194 .users
195 .can_replace_cross_signing_keys(user_id)
196 .await
197 {
198 uiaainfo.completed.push(AuthType::OAuth);
199 } else {
200 uiaainfo.auth_error = Some(StandardErrorBody {
201 kind: ErrorKind::forbidden(),
202 message: "OAuth cross-signing reset not approved for this session."
203 .to_owned(),
204 });
205
206 return Ok((false, uiaainfo));
207 }
208 }
209 },
210 | AuthData::Dummy(_) => {
211 uiaainfo.completed.push(AuthType::Dummy);
212 },
213 | auth => error!("AuthData type not supported: {auth:?}"),
214 }
215
216 let mut completed = false;
218 'flows: for flow in &mut uiaainfo.flows {
219 for stage in &flow.stages {
220 if !uiaainfo.completed.contains(stage) {
221 continue 'flows;
222 }
223 }
224 completed = true;
226 }
227
228 let session = uiaainfo
229 .session
230 .as_ref()
231 .expect("session is always set");
232
233 if !completed {
234 self.update_uiaa_session(user_id, device_id, session, Some(&uiaainfo));
235
236 return Ok((false, uiaainfo));
237 }
238
239 self.update_uiaa_session(user_id, device_id, session, None);
241
242 Ok((true, uiaainfo))
243}
244
245#[implement(Service)]
246fn set_uiaa_request(
247 &self,
248 user_id: &UserId,
249 device_id: &DeviceId,
250 session: &str,
251 request: &CanonicalJsonValue,
252) {
253 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
254
255 self.userdevicesessionid_uiaarequest
256 .write()
257 .expect("locked for writing")
258 .insert(key, request.to_owned());
259}
260
261#[implement(Service)]
262pub fn get_uiaa_request(
263 &self,
264 user_id: &UserId,
265 device_id: Option<&DeviceId>,
266 session: &str,
267) -> Option<CanonicalJsonValue> {
268 let device_id = device_id.unwrap_or_else(|| EMPTY.into());
269 let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
270
271 self.userdevicesessionid_uiaarequest
272 .read()
273 .expect("locked for reading")
274 .get(&key)
275 .cloned()
276}
277
278#[implement(Service)]
279pub fn update_uiaa_session(
280 &self,
281 user_id: &UserId,
282 device_id: &DeviceId,
283 session: &str,
284 uiaainfo: Option<&UiaaInfo>,
285) {
286 let key = (user_id, device_id, session);
287
288 if let Some(uiaainfo) = uiaainfo {
289 self.db
290 .userdevicesessionid_uiaainfo
291 .put(key, Json(uiaainfo));
292 } else {
293 self.db.userdevicesessionid_uiaainfo.del(key);
294 }
295}
296
297#[implement(Service)]
298async fn get_uiaa_session(
299 &self,
300 user_id: &UserId,
301 device_id: &DeviceId,
302 session: &str,
303) -> Result<UiaaInfo> {
304 let key = (user_id, device_id, session);
305
306 self.db
307 .userdevicesessionid_uiaainfo
308 .qry(&key)
309 .await
310 .deserialized()
311 .map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
312}
313
314#[implement(Service)]
315pub async fn get_uiaa_session_by_session_id(
316 &self,
317 session_id: &str,
318) -> Option<(OwnedUserId, OwnedDeviceId, UiaaInfo)> {
319 let stream = self
321 .db
322 .userdevicesessionid_uiaainfo
323 .keys::<(OwnedUserId, OwnedDeviceId, String)>();
324
325 pin_mut!(stream);
326 while let Ok(Some((user_id, device_id, session))) = stream.try_next().await {
327 if session == session_id {
328 if let Ok(uiaainfo) = self
330 .get_uiaa_session(&user_id, &device_id, session_id)
331 .await
332 {
333 return Some((user_id, device_id, uiaainfo));
334 }
335 }
336 }
337
338 None
339}