1mod uiaa;
2
3use std::{borrow::Cow, collections::BTreeMap, net::IpAddr, time::Duration};
4
5use axum::extract::State;
6use axum_extra::extract::cookie::{Cookie, SameSite};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
8use futures::{FutureExt, StreamExt, TryFutureExt, future::try_join};
9use reqwest::header::{CONTENT_TYPE, HeaderValue};
10use ruma::{
11 Mxc, OwnedMxcUri, OwnedRoomId, OwnedUserId, ServerName, UserId,
12 api::client::{
13 session::{sso_callback, sso_login, sso_login_with_provider},
14 uiaa::AuthType,
15 },
16};
17use serde::{Deserialize, Serialize};
18use tuwunel_core::{
19 Err, Result, at,
20 config::IdentityProvider,
21 debug::INFO_SPAN_LEVEL,
22 debug_info, debug_warn, err, info, is_not_equal_to,
23 itertools::Itertools,
24 utils,
25 utils::{
26 OptionExt,
27 content_disposition::make_content_disposition,
28 hash::sha256,
29 result::{FlatOk, LogErr},
30 string::{EMPTY, truncate_deterministic},
31 timepoint_from_now, timepoint_has_passed,
32 },
33 warn,
34};
35use tuwunel_service::{
36 Services,
37 media::MXC_LENGTH,
38 oauth::{
39 CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id_sub,
40 },
41 users::{PASSWORD_SENTINEL, Register, propagation_default},
42};
43use url::Url;
44
45pub(crate) use self::uiaa::sso_fallback_route;
46use super::TOKEN_LENGTH;
47use crate::{ClientIp, Ruma};
48
/// Query string appended to the provider's `authorization_url` when a login
/// starts an authorization-code grant (see `handle_sso_login()`).
///
/// serde emits the fields in declaration order, which is therefore the order
/// of the parameters in the generated URL.
#[derive(Debug, Serialize)]
struct GrantQuery<'a> {
	client_id: &'a str,
	// Random session id; the callback receives it back as `state` and uses it
	// to look up the stored `Session`.
	state: &'a str,
	nonce: &'a str,
	scope: &'a str,
	response_type: &'a str,
	access_type: &'a str,
	// PKCE method; `handle_sso_login()` always sends "S256".
	code_challenge_method: &'a str,
	code_challenge: &'a str,
	// The provider's configured callback URL, when one is configured.
	redirect_uri: Option<&'a str>,
}
62
/// Form-encoded value of the `GRANT_SESSION_COOKIE` set by
/// `handle_sso_login()`. When the provider enables `check_cookie`,
/// `sso_callback_route()` decodes it and compares it with the callback and the
/// stored session.
#[derive(Debug, Deserialize, Serialize)]
struct GrantCookie<'a> {
	// Compared against the provider's `client_id` in the callback.
	client_id: Cow<'a, str>,
	// Session id; compared against the callback's `state` parameter.
	state: Cow<'a, str>,
	// Random nonce distinct from the one sent to the provider; compared
	// against the session's `cookie_nonce`.
	nonce: Cow<'a, str>,
	// The client's redirect URL. NOTE(review): written here but not compared
	// by the callback in this file.
	redirect_uri: Cow<'a, str>,
}
70
/// Name of the cookie binding the browser to an in-flight authorization grant
/// between the login redirect and the provider callback.
static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
72
73#[tracing::instrument(
78 name = "sso_login",
79 level = "debug",
80 skip_all,
81 fields(%client),
82)]
83pub(crate) async fn sso_login_route(
84 State(services): State<crate::State>,
85 ClientIp(client): ClientIp,
86 body: Ruma<sso_login::v3::Request>,
87) -> Result<sso_login::v3::Response> {
88 if services.config.sso_custom_providers_page {
89 return Err!(Request(NotImplemented(
90 "sso_custom_providers_page has been enabled but this URL has not been overridden \
91 with any custom page listing the available providers..."
92 )));
93 }
94
95 let redirect_url = body.body.redirect_url;
96 let default_idp_id = services
97 .oauth
98 .providers
99 .get_default_id()
100 .unwrap_or_default();
101
102 handle_sso_login(&services, &client, default_idp_id, redirect_url, None)
103 .map_ok(|response| sso_login::v3::Response {
104 location: response.location,
105 cookie: response.cookie,
106 })
107 .await
108}
109
110#[tracing::instrument(
116 name = "sso_login_with_provider",
117 level = "info",
118 skip_all,
119 ret(level = "debug")
120 fields(
121 %client,
122 idp_id = body.body.idp_id,
123 ),
124)]
125pub(crate) async fn sso_login_with_provider_route(
126 State(services): State<crate::State>,
127 ClientIp(client): ClientIp,
128 body: Ruma<sso_login_with_provider::v3::Request>,
129) -> Result<sso_login_with_provider::v3::Response> {
130 let idp_id = body.body.idp_id;
131 let redirect_url = body.body.redirect_url;
132 let login_token = body.body.login_token;
133
134 handle_sso_login(&services, &client, idp_id, redirect_url, login_token).await
135}
136
137async fn handle_sso_login(
138 services: &Services,
139 _client: &IpAddr,
140 idp_id: String,
141 redirect_url: String,
142 login_token: Option<String>,
143) -> Result<sso_login_with_provider::v3::Response> {
144 let redirect_url: Url = redirect_url.parse().map_err(|e| {
145 err!(Request(InvalidParam(debug_warn!(
146 ?e,
147 ?redirect_url,
148 "Failed to parse redirect_url.",
149 ))))
150 })?;
151
152 let provider = services.oauth.providers.get(&idp_id).await?;
153 let sess_id = utils::random_string(SESSION_ID_LENGTH);
154 let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
155 let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
156 let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
157 let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
158 let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
159 let scope = provider.scope.iter().join(" ");
160
161 let query = GrantQuery {
162 client_id: &provider.client_id,
163 state: &sess_id,
164 nonce: &query_nonce,
165 access_type: "online",
166 response_type: "code",
167 code_challenge_method: "S256",
168 code_challenge: &code_challenge,
169 redirect_uri: callback_uri,
170 scope: scope
171 .is_empty()
172 .then_some("openid email profile")
173 .unwrap_or(scope.as_str()),
174 };
175
176 let location = provider
177 .authorization_url
178 .clone()
179 .map(|mut location| {
180 let query = serde_html_form::to_string(&query).ok();
181 location.set_query(query.as_deref());
182 if !provider.extra_authorization_parameters.is_empty() {
183 let merged: BTreeMap<String, String> = location
186 .query_pairs()
187 .map(|(k, v)| (k.into_owned(), v.into_owned()))
188 .chain(provider.extra_authorization_parameters.clone())
189 .collect();
190 location.set_query(None);
191 location.query_pairs_mut().extend_pairs(&merged);
192 }
193 location
194 })
195 .ok_or_else(|| {
196 err!(Config("authorization_url", "Missing required IdentityProvider config"))
197 })?;
198
199 let cookie_val = GrantCookie {
200 client_id: query.client_id.into(),
201 state: query.state.into(),
202 nonce: cookie_nonce.as_str().into(),
203 redirect_uri: redirect_url.as_str().into(),
204 };
205
206 let cookie_path = provider
207 .callback_url
208 .as_ref()
209 .map(Url::path)
210 .unwrap_or("/");
211
212 let cookie_max_age = provider
213 .grant_session_duration
214 .map(Duration::from_secs)
215 .expect("Defaulted to Some value during configure_idp()")
216 .try_into()
217 .expect("std::time::Duration to time::Duration conversion failure");
218
219 let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
220 .path(cookie_path)
221 .max_age(cookie_max_age)
222 .same_site(SameSite::None)
223 .secure(true)
224 .http_only(true)
225 .build()
226 .to_string()
227 .into();
228
229 let session = Session {
230 idp_id: Some(idp_id),
231 sess_id: Some(sess_id.clone()),
232 redirect_url: Some(redirect_url),
233 code_verifier: Some(code_verifier),
234 query_nonce: Some(query_nonce),
235 cookie_nonce: Some(cookie_nonce),
236 authorize_expires_at: provider
237 .grant_session_duration
238 .map(Duration::from_secs)
239 .map(timepoint_from_now)
240 .transpose()?,
241
242 user_id: login_token
243 .as_deref()
244 .map_async(|token| services.users.find_from_login_token(token))
245 .map(FlatOk::flat_ok)
246 .await,
247
248 ..Default::default()
249 };
250
251 services.oauth.sessions.put(&session).await;
252
253 Ok(sso_login_with_provider::v3::Response {
254 location: location.into(),
255 cookie: Some(cookie),
256 })
257}
258
259#[tracing::instrument(
260 name = "sso_callback"
261 level = "debug",
262 skip_all,
263 fields(
264 %client,
265 cookie = ?body.cookie,
266 body = ?body.body,
267 ),
268)]
269pub(crate) async fn sso_callback_route(
270 State(services): State<crate::State>,
271 ClientIp(client): ClientIp,
272 body: Ruma<sso_callback::unstable::Request>,
273) -> Result<sso_callback::unstable::Response> {
274 let sess_id = body
275 .body
276 .state
277 .as_deref()
278 .ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;
279
280 let code = body
281 .body
282 .code
283 .as_deref()
284 .ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;
285
286 let session = services
287 .oauth
288 .sessions
289 .get(sess_id)
290 .map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
291
292 let provider = services
293 .oauth
294 .providers
295 .get(body.body.idp_id.as_str());
296
297 let (provider, session) = try_join(provider, session).await.log_err()?;
298 let client_id = &provider.client_id;
299 let idp_id = provider.id();
300
301 if session.sess_id.as_deref() != Some(sess_id) {
302 return Err!(Request(Unauthorized("Session ID {sess_id:?} not recognized.")));
303 }
304
305 if session.idp_id.as_deref() != Some(idp_id) {
306 return Err!(Request(Unauthorized(
307 "Identity Provider {idp_id:?} session not recognized."
308 )));
309 }
310
311 if session
312 .authorize_expires_at
313 .map(timepoint_has_passed)
314 .unwrap_or(false)
315 {
316 return Err!(Request(Unauthorized("Authorization grant session has expired.")));
317 }
318
319 if provider.check_cookie {
320 let cookie = body
321 .cookie
322 .get(GRANT_SESSION_COOKIE)
323 .map(Cookie::value)
324 .map(serde_html_form::from_str::<GrantCookie<'_>>)
325 .transpose()?
326 .ok_or_else(|| {
327 err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}")))
328 })?;
329
330 if cookie.client_id.as_ref() != client_id.as_str() {
331 return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
332 }
333
334 if Some(cookie.nonce.as_ref()) != session.cookie_nonce.as_deref() {
335 return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
336 }
337
338 if cookie.state.as_ref() != sess_id {
339 return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
340 }
341 }
342
343 let token_response = services
345 .oauth
346 .request_token((&provider, &session), code)
347 .await?;
348
349 let token_expires_at = token_response
350 .expires_in
351 .map(Duration::from_secs)
352 .map(timepoint_from_now)
353 .transpose()?;
354
355 let refresh_token_expires_at = token_response
356 .refresh_token_expires_in
357 .map(Duration::from_secs)
358 .map(timepoint_from_now)
359 .transpose()?;
360
361 let session = Session {
363 scope: token_response.scope,
364 token_type: token_response.token_type,
365 access_token: token_response.access_token,
366 expires_at: token_expires_at,
367 refresh_token: token_response.refresh_token,
368 refresh_token_expires_at,
369 ..session
370 };
371
372 let userinfo = services
374 .oauth
375 .request_userinfo((&provider, &session))
376 .await?;
377
378 let unique_id = unique_id_sub((&provider, &userinfo.sub))?;
379
380 let (old_user_id, old_sess_id) = match services
384 .oauth
385 .sessions
386 .get_by_unique_id(&unique_id)
387 .await
388 {
389 | Ok(session) => (session.user_id, session.sess_id),
390 | Err(error) if !error.is_not_found() => return Err(error),
391 | Err(_) => (None, None),
392 };
393
394 let session = Session {
396 user_info: Some(userinfo.clone()),
397 ..session
398 };
399
400 let user_id = match (session.user_id, old_user_id) {
402 | (Some(user_id), ..) | (None, Some(user_id)) => user_id,
403 | (None, None) => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
404 };
405
406 let session = Session {
408 user_id: Some(user_id.clone()),
409 ..session
410 };
411
412 if !services.users.exists(&user_id).await {
414 if !provider.registration {
415 return Err!(Request(Forbidden("Registration from this provider is disabled")));
416 }
417
418 register_user(&services, &provider, &session, &userinfo, &user_id).await?;
419 }
420
421 services.oauth.sessions.put(&session).await;
423
424 if let Some(old_sess_id) = old_sess_id
426 .as_deref()
427 .filter(is_not_equal_to!(&sess_id))
428 {
429 services.oauth.sessions.delete(old_sess_id).await;
430 }
431
432 if !services.users.is_active_local(&user_id).await {
433 return Err!(Request(UserDeactivated("This user has been deactivated.")));
434 }
435
436 let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
437 .removal()
438 .build()
439 .to_string()
440 .into();
441
442 if let Some(redirect_url) = session
444 .redirect_url
445 .as_ref()
446 .filter(|url| url.scheme() == "uiaa")
447 {
448 return handle_uiaa(&services, &user_id, cookie, redirect_url).await;
449 }
450
451 let next_idp_url = services
453 .config
454 .identity_provider
455 .values()
456 .filter(|idp| idp.default || services.config.single_sso)
457 .skip_while(|idp| idp.id() != idp_id)
458 .nth(1)
459 .map(IdentityProvider::id)
460 .and_then(|next_idp| {
461 provider.callback_url.clone().map(|mut url| {
462 let path = format!("/_matrix/client/v3/login/sso/redirect/{next_idp}");
463 url.set_path(&path);
464
465 if let Some(redirect_url) = session.redirect_url.as_ref() {
466 url.query_pairs_mut()
467 .append_pair("redirectUrl", redirect_url.as_str());
468 }
469
470 url
471 })
472 });
473
474 let login_token = utils::random_string(TOKEN_LENGTH);
476 let _login_token_expires_in = services
477 .users
478 .create_login_token(&user_id, &login_token);
479
480 let location = next_idp_url
481 .or(session.redirect_url)
482 .as_ref()
483 .ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
484 .clone()
485 .query_pairs_mut()
486 .append_pair("loginToken", &login_token)
487 .finish()
488 .to_string();
489
490 Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
491}
492
493async fn handle_uiaa(
494 services: &Services,
495 user_id: &UserId,
496 cookie: Cow<'static, str>,
497 redirect_url: &Url,
498) -> Result<sso_callback::unstable::Response> {
499 let uiaa_session_id = redirect_url.path();
500
501 let (user_id, device_id, mut uiaainfo) = services
504 .uiaa
505 .get_uiaa_session_by_session_id(uiaa_session_id)
506 .await
507 .filter(|(db_user_id, ..)| user_id.eq(db_user_id))
508 .ok_or_else(|| err!(Request(Forbidden("UIAA session not found."))))?;
509
510 let has_oauth_flow = uiaainfo
512 .flows
513 .iter()
514 .any(|f| f.stages.contains(&AuthType::OAuth));
515
516 if has_oauth_flow && !uiaainfo.completed.contains(&AuthType::OAuth) {
518 services
520 .users
521 .allow_cross_signing_replacement(&user_id);
522
523 uiaainfo.completed.push(AuthType::OAuth);
524 }
525
526 let has_sso_flow = uiaainfo
528 .flows
529 .iter()
530 .any(|f| f.stages.contains(&AuthType::Sso));
531
532 if has_sso_flow && !uiaainfo.completed.contains(&AuthType::Sso) {
533 uiaainfo.completed.push(AuthType::Sso);
534 }
535
536 services
537 .uiaa
538 .update_uiaa_session(&user_id, &device_id, uiaa_session_id, Some(&uiaainfo));
539
540 let location =
542 format!("/_matrix/client/v3/auth/m.login.sso/fallback/web?session={uiaa_session_id}");
543
544 Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
545}
546
547#[tracing::instrument(
548 name = "register",
549 level = INFO_SPAN_LEVEL,
550 skip_all,
551 fields(user_id, userinfo)
552)]
553async fn register_user(
554 services: &Services,
555 provider: &Provider,
556 session: &Session,
557 userinfo: &UserInfo,
558 user_id: &UserId,
559) -> Result {
560 debug_info!(%user_id, "Creating new user account...");
561
562 services
563 .users
564 .full_register(Register {
565 user_id: Some(user_id),
566 password: Some(PASSWORD_SENTINEL),
567 origin: Some("sso"),
568 displayname: userinfo.name.as_deref(),
569 grant_first_user_admin: true,
570 ..Default::default()
571 })
572 .await?;
573
574 if let Some(avatar_url) = userinfo
575 .avatar_url
576 .as_deref()
577 .or(userinfo.picture.as_deref())
578 {
579 set_avatar(services, provider, session, userinfo, user_id, avatar_url)
580 .await
581 .ok();
582 }
583
584 let idp_id = provider.id();
585 let idp_name = provider
586 .name
587 .as_deref()
588 .unwrap_or(provider.brand.as_str());
589
590 let notice =
592 format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})");
593
594 info!("{notice}");
595 if services.server.config.admin_room_notices {
596 services.admin.notice(¬ice).await;
597 }
598
599 Ok(())
600}
601
602#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
603async fn set_avatar(
604 services: &Services,
605 _provider: &Provider,
606 _session: &Session,
607 _userinfo: &UserInfo,
608 user_id: &UserId,
609 avatar_url: &str,
610) -> Result {
611 use reqwest::Response;
612
613 let response = services
614 .client
615 .default
616 .get(avatar_url)
617 .send()
618 .await
619 .and_then(Response::error_for_status)?;
620
621 let content_type = response
622 .headers()
623 .get(CONTENT_TYPE)
624 .map(HeaderValue::to_str)
625 .flat_ok()
626 .map(ToOwned::to_owned);
627
628 let mxc = Mxc {
629 server_name: services.globals.server_name(),
630 media_id: &utils::random_string(MXC_LENGTH),
631 };
632
633 let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
634 let bytes = response.bytes().await?;
635 services
636 .media
637 .create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
638 .await?;
639
640 let all_joined_rooms: Vec<OwnedRoomId> = services
641 .state_cache
642 .rooms_joined(user_id)
643 .map(ToOwned::to_owned)
644 .collect()
645 .await;
646
647 let mxc_uri: OwnedMxcUri = mxc.to_string().into();
648 services
649 .users
650 .update_avatar_url(
651 user_id,
652 Some(&mxc_uri),
653 None,
654 &all_joined_rooms,
655 propagation_default(
656 services
657 .server
658 .config
659 .preserve_room_profile_overrides,
660 ),
661 )
662 .await;
663
664 Ok(())
665}
666
667#[tracing::instrument(
668 level = "debug",
669 ret(level = "debug")
670 skip_all,
671 fields(user),
672)]
673async fn decide_user_id(
674 services: &Services,
675 provider: &Provider,
676 userinfo: &UserInfo,
677 unique_id: &str,
678) -> Result<OwnedUserId> {
679 if let Some(user_id) = services
680 .oauth
681 .sessions
682 .find_user_association_pending(provider.id(), userinfo)
683 {
684 debug_info!(
685 provider = ?provider.id(),
686 ?user_id,
687 ?userinfo,
688 "Matched pending association"
689 );
690
691 return Ok(user_id);
692 }
693
694 let explicit = |claim: &str| provider.userid_claims.contains(claim);
695
696 let allowed = |claim: &str| provider.userid_claims.is_empty() || explicit(claim);
697
698 let choices = [
699 explicit("sub")
700 .then_some(userinfo.sub.as_str())
701 .map(str::to_lowercase),
702 userinfo
703 .preferred_username
704 .as_deref()
705 .map(str::to_lowercase)
706 .filter(|_| allowed("preferred_username")),
707 userinfo
708 .username
709 .as_deref()
710 .map(str::to_lowercase)
711 .filter(|_| allowed("username")),
712 userinfo
713 .nickname
714 .as_deref()
715 .map(str::to_lowercase)
716 .filter(|_| allowed("nickname")),
717 provider
718 .brand
719 .eq(&"github")
720 .then_some(userinfo.sub.as_str())
721 .map(str::to_lowercase)
722 .filter(|_| allowed("login")),
723 userinfo
724 .email
725 .as_deref()
726 .and_then(|email| email.split_once('@'))
727 .map(at!(0))
728 .map(str::to_lowercase)
729 .filter(|_| allowed("email")),
730 ];
731
732 for choice in choices.into_iter().flatten() {
733 if let Some(user_id) = try_user_id(services, provider, &choice, false).await {
734 return Ok(user_id);
735 }
736 }
737
738 let length = Some(15..23);
739 let unique_id = truncate_deterministic(unique_id, length).to_lowercase();
740 if let Some(user_id) = try_user_id(services, provider, &unique_id, true).await {
741 return Ok(user_id);
742 }
743
744 Err!(Request(UserInUse("User ID is not available.")))
745}
746
747#[tracing::instrument(level = "debug", skip_all, fields(username))]
748async fn try_user_id(
749 services: &Services,
750 provider: &Provider,
751 username: &str,
752 unique_id: bool,
753) -> Option<OwnedUserId> {
754 let server_name = services.globals.server_name();
755 let user_id = parse_user_id(server_name, username)
756 .inspect_err(|e| warn!(?username, "Username invalid: {e}"))
757 .ok()?;
758
759 if services
760 .config
761 .forbidden_usernames
762 .is_match(username)
763 {
764 warn!(?username, "Username forbidden.");
765 return None;
766 }
767
768 if services.users.exists(&user_id).await {
769 if provider.trusted {
770 info!(
771 ?username,
772 provider = ?provider.brand,
773 "Authorizing trusted provider access to existing account."
774 );
775
776 return Some(user_id);
777 }
778
779 if services
780 .users
781 .origin(&user_id)
782 .await
783 .ok()
784 .is_none_or(|origin| origin != "sso")
785 {
786 debug_warn!(?username, "Existing username has non-sso origin.");
787 return None;
788 }
789
790 if !unique_id {
791 debug_warn!(?username, "Username exists.");
792 return None;
793 }
794 } else if unique_id && !provider.unique_id_fallbacks {
795 debug_warn!(
796 ?username,
797 provider = ?provider.brand,
798 "Unique ID fallbacks disabled.",
799 );
800
801 return None;
802 }
803
804 Some(user_id)
805}
806
807fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
808 match UserId::parse_with_server_name(username, server_name) {
809 | Err(e) => {
810 Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}"))))
811 },
812 | Ok(user_id) => match user_id.validate_strict() {
813 | Ok(()) => Ok(user_id),
814 | Err(e) => Err!(Request(InvalidUsername(debug_error!(
815 "Username {username} contains disallowed characters or spaces: {e}"
816 )))),
817 },
818 }
819}