1mod uiaa;
2
3use std::{borrow::Cow, collections::BTreeMap, net::IpAddr, time::Duration};
4
5use axum::extract::State;
6use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
8use futures::{FutureExt, TryFutureExt, future::try_join};
9use reqwest::header::{CONTENT_TYPE, HeaderValue};
10use ruma::{
11 Mxc, OwnedMxcUri, OwnedUserId, ServerName, UserId,
12 api::client::{
13 session::{SsoRedirectAction, sso_callback, sso_login, sso_login_with_provider},
14 uiaa::AuthType,
15 },
16};
17use serde::{Deserialize, Serialize};
18use serde_json::Value as JsonValue;
19use tuwunel_core::{
20 Err, Result, at,
21 config::IdentityProvider,
22 debug::INFO_SPAN_LEVEL,
23 debug_info, debug_warn, err, info, is_not_equal_to,
24 itertools::Itertools,
25 utils,
26 utils::{
27 OptionExt,
28 content_disposition::make_content_disposition,
29 hash::sha256,
30 result::{FlatOk, LogErr},
31 string::{EMPTY, truncate_deterministic},
32 timepoint_from_now, timepoint_has_passed,
33 },
34 warn,
35};
36use tuwunel_service::{
37 Services,
38 client::read_response_capped,
39 media::MXC_LENGTH,
40 oauth::{
41 CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, TokenResponse, UserInfo,
42 unique_id_sub,
43 },
44 users::{PASSWORD_SENTINEL, Register},
45};
46use url::Url;
47
48pub(crate) use self::uiaa::sso_fallback_route;
49use super::TOKEN_LENGTH;
50use crate::{ClientIp, Ruma};
51
/// Query parameters appended to the identity provider's authorization
/// endpoint when starting an authorization-code grant; serialized with
/// `serde_html_form` in `handle_sso_login`.
///
/// NOTE(review): unlike `prompt`, `redirect_uri` carries no
/// `skip_serializing_if`; confirm how `serde_html_form` renders `None` here.
#[derive(Debug, Serialize)]
struct GrantQuery<'a> {
	/// OAuth client ID registered with the provider.
	client_id: &'a str,
	/// Opaque value echoed back on the callback; `handle_sso_login` sets it to
	/// the grant session ID.
	state: &'a str,
	/// Random nonce sent to the provider (distinct from the cookie nonce).
	nonce: &'a str,
	/// Space-separated list of requested scopes.
	scope: &'a str,
	/// Set to `"code"` by `handle_sso_login` (authorization-code flow).
	response_type: &'a str,
	/// Set to `"online"` by `handle_sso_login`.
	access_type: &'a str,
	/// PKCE challenge method; `handle_sso_login` sends `"S256"`.
	code_challenge_method: &'a str,
	/// Unpadded base64url SHA-256 digest of the PKCE code verifier.
	code_challenge: &'a str,
	/// The provider's configured callback URL, when one is configured.
	redirect_uri: Option<&'a str>,
	/// `"create"` for registration requests when the provider forwards action
	/// prompts; omitted otherwise.
	#[serde(skip_serializing_if = "Option::is_none")]
	prompt: Option<&'a str>,
}
67
/// Grant-session state placed in the `GRANT_SESSION_COOKIE` cookie by
/// `handle_sso_login` and compared against the server-side session by
/// `validate_session_cookie` when the provider enables `check_cookie`.
#[derive(Debug, Deserialize, Serialize)]
struct GrantCookie<'a> {
	/// Provider client ID; must match the provider handling the callback.
	client_id: Cow<'a, str>,
	/// Grant session ID; must match the callback's `state` parameter.
	state: Cow<'a, str>,
	/// Cookie-only nonce; must match the session's `cookie_nonce`.
	nonce: Cow<'a, str>,
	/// Client redirect URL captured at login start (not checked by
	/// `validate_session_cookie`).
	redirect_uri: Cow<'a, str>,
}
75
/// Name of the cookie carrying the serialized `GrantCookie` between the login
/// redirect and the provider callback.
const GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
77
78fn decode_apple_userinfo_from_id_token(session: &Session) -> Result<UserInfo> {
79 let id_token = session.id_token.as_deref().ok_or_else(|| {
80 err!(Request(Unauthorized("Missing Apple id_token in token response.")))
81 })?;
82
83 let payload_b64 = id_token
84 .split('.')
85 .nth(1)
86 .ok_or_else(|| err!(Request(Unauthorized("Apple id_token is malformed."))))?;
87
88 let payload = b64
89 .decode(payload_b64)
90 .map_err(|_| err!(Request(Unauthorized("Apple id_token payload is invalid base64."))))?;
91
92 let payload: JsonValue = serde_json::from_slice(&payload)
93 .map_err(|_| err!(Request(Unauthorized("Apple id_token payload is not valid JSON."))))?;
94
95 let sub = payload
96 .get("sub")
97 .and_then(JsonValue::as_str)
98 .ok_or_else(|| {
99 err!(Request(Unauthorized("Apple id_token missing required sub claim.")))
100 })?;
101
102 let email = payload
103 .get("email")
104 .and_then(JsonValue::as_str)
105 .map(ToOwned::to_owned);
106
107 let preferred_username = email
108 .as_deref()
109 .and_then(|value| value.split_once('@'))
110 .map(at!(0))
111 .map(ToOwned::to_owned);
112
113 Ok(UserInfo {
114 sub: sub.to_owned(),
115 preferred_username: preferred_username.clone(),
116 username: preferred_username,
117 nickname: None,
118 name: payload
119 .get("name")
120 .and_then(JsonValue::as_str)
121 .map(ToOwned::to_owned),
122 given_name: payload
123 .get("given_name")
124 .and_then(JsonValue::as_str)
125 .map(ToOwned::to_owned),
126 family_name: payload
127 .get("family_name")
128 .and_then(JsonValue::as_str)
129 .map(ToOwned::to_owned),
130 email,
131 avatar_url: None,
132 picture: None,
133 })
134}
135
136#[tracing::instrument(
141 name = "sso_login",
142 level = "debug",
143 skip_all,
144 fields(%client),
145)]
146pub(crate) async fn sso_login_route(
147 State(services): State<crate::State>,
148 ClientIp(client): ClientIp,
149 body: Ruma<sso_login::v3::Request>,
150) -> Result<sso_login::v3::Response> {
151 if services.config.sso_custom_providers_page {
152 return Err!(Request(NotImplemented(
153 "sso_custom_providers_page has been enabled but this URL has not been overridden \
154 with any custom page listing the available providers..."
155 )));
156 }
157
158 let redirect_url = body.body.redirect_url;
159 let action = body.body.action;
160 let default_idp_id = services
161 .oauth
162 .providers
163 .get_default_id()
164 .unwrap_or_default();
165
166 handle_sso_login(&services, &client, default_idp_id, redirect_url, None, action)
167 .map_ok(|response| sso_login::v3::Response {
168 location: response.location,
169 cookie: response.cookie,
170 })
171 .await
172}
173
174#[tracing::instrument(
180 name = "sso_login_with_provider",
181 level = "info",
182 skip_all,
183 ret(level = "debug")
184 fields(
185 %client,
186 idp_id = body.body.idp_id,
187 ),
188)]
189pub(crate) async fn sso_login_with_provider_route(
190 State(services): State<crate::State>,
191 ClientIp(client): ClientIp,
192 body: Ruma<sso_login_with_provider::v3::Request>,
193) -> Result<sso_login_with_provider::v3::Response> {
194 let idp_id = body.body.idp_id;
195 let redirect_url = body.body.redirect_url;
196 let login_token = body.body.login_token;
197 let action = body.body.action;
198
199 handle_sso_login(&services, &client, idp_id, redirect_url, login_token, action).await
200}
201
202async fn handle_sso_login(
203 services: &Services,
204 _client: &IpAddr,
205 idp_id: String,
206 redirect_url: String,
207 login_token: Option<String>,
208 action: Option<SsoRedirectAction>,
209) -> Result<sso_login_with_provider::v3::Response> {
210 let redirect_url: Url = redirect_url.parse().map_err(|e| {
211 err!(Request(InvalidParam(debug_warn!(
212 ?e,
213 ?redirect_url,
214 "Failed to parse redirect_url.",
215 ))))
216 })?;
217
218 let provider = services.oauth.providers.get(&idp_id).await?;
219 let sess_id = utils::random_string(SESSION_ID_LENGTH);
220 let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
221 let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
222 let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
223 let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
224 let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
225 let scope = provider.scope.iter().join(" ");
226 let prompt = action
227 .filter(|_| provider.forward_action_prompt)
228 .and_then(|action| matches!(action, SsoRedirectAction::Register).then_some("create"));
229
230 let query = GrantQuery {
231 client_id: &provider.client_id,
232 state: &sess_id,
233 nonce: &query_nonce,
234 access_type: "online",
235 response_type: "code",
236 code_challenge_method: "S256",
237 code_challenge: &code_challenge,
238 redirect_uri: callback_uri,
239 prompt,
240 scope: scope
241 .is_empty()
242 .then_some("openid email profile")
243 .unwrap_or(scope.as_str()),
244 };
245
246 let location = provider
247 .authorization_url
248 .clone()
249 .map(|mut location| {
250 let query = serde_html_form::to_string(&query).ok();
251 location.set_query(query.as_deref());
252 if !provider.extra_authorization_parameters.is_empty() {
253 let merged: BTreeMap<String, String> = provider
255 .extra_authorization_parameters
256 .clone()
257 .into_iter()
258 .chain(
259 location
260 .query_pairs()
261 .map(|(k, v)| (k.into_owned(), v.into_owned())),
262 )
263 .collect();
264
265 location.set_query(None);
266 location.query_pairs_mut().extend_pairs(&merged);
267 }
268 location
269 })
270 .ok_or_else(|| {
271 err!(Config("authorization_url", "Missing required IdentityProvider config"))
272 })?;
273
274 let cookie_val = GrantCookie {
275 client_id: query.client_id.into(),
276 state: query.state.into(),
277 nonce: cookie_nonce.as_str().into(),
278 redirect_uri: redirect_url.as_str().into(),
279 };
280
281 let cookie_path = provider
282 .callback_url
283 .as_ref()
284 .map(Url::path)
285 .unwrap_or("/");
286
287 let cookie_max_age = provider
288 .grant_session_duration
289 .map(Duration::from_secs)
290 .expect("Defaulted to Some value during configure_idp()")
291 .try_into()
292 .expect("std::time::Duration to time::Duration conversion failure");
293
294 let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
295 .path(cookie_path)
296 .max_age(cookie_max_age)
297 .same_site(SameSite::None)
298 .secure(true)
299 .http_only(true)
300 .build()
301 .to_string()
302 .into();
303
304 let session = Session {
305 idp_id: Some(idp_id),
306 sess_id: Some(sess_id.clone()),
307 redirect_url: Some(redirect_url),
308 code_verifier: Some(code_verifier),
309 query_nonce: Some(query_nonce),
310 cookie_nonce: Some(cookie_nonce),
311 authorize_expires_at: provider
312 .grant_session_duration
313 .map(Duration::from_secs)
314 .map(timepoint_from_now)
315 .transpose()?,
316
317 user_id: login_token
318 .as_deref()
319 .map_async(|token| services.users.find_from_login_token(token))
320 .map(FlatOk::flat_ok)
321 .await,
322
323 ..Default::default()
324 };
325
326 services.oauth.sessions.put(&session).await;
327
328 Ok(sso_login_with_provider::v3::Response {
329 location: location.into(),
330 cookie: Some(cookie),
331 })
332}
333
/// Handles the identity provider's redirect back to this server after the
/// grant started by `handle_sso_login`.
///
/// Steps, in order:
/// 1. Validate the callback against the stored grant session (session ID,
///    provider, expiry, and the grant cookie when `check_cookie` is set).
/// 2. Exchange the authorization `code` for tokens and fetch the user's claims.
/// 3. Resolve or register the Matrix account and persist the session.
/// 4. Redirect to one of:
///    - the UIAA fallback page, for `uiaa:` redirect URLs;
///    - the next chained provider;
///    - the client's `redirect_url`, carrying a `loginToken`.
// NOTE(review): the span fields record the callback body (including the
// authorization `code`) and the cookie jar at debug level — confirm intended.
#[tracing::instrument(
	name = "sso_callback"
	level = "debug",
	skip_all,
	fields(
		%client,
		cookie = ?body.cookie,
		body = ?body.body,
	),
)]
pub(crate) async fn sso_callback_route(
	State(services): State<crate::State>,
	ClientIp(client): ClientIp,
	body: Ruma<sso_callback::unstable::Request>,
) -> Result<sso_callback::unstable::Response> {
	// `handle_sso_login` sent the session ID as the OAuth `state`, which the
	// provider echoes back here.
	let sess_id = body
		.body
		.state
		.as_deref()
		.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;

	let code = body
		.body
		.code
		.as_deref()
		.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;

	let session = services
		.oauth
		.sessions
		.get(sess_id)
		.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));

	let provider = services
		.oauth
		.providers
		.get(body.body.idp_id.as_str());

	// Resolve the provider and the stored grant session concurrently.
	let (provider, session) = try_join(provider, session).await.log_err()?;
	let idp_id = provider.id();

	if session.sess_id.as_deref() != Some(sess_id) {
		return Err!(Request(Unauthorized("Session ID {sess_id:?} not recognized.")));
	}

	// The session must have been started for this same provider.
	if session.idp_id.as_deref() != Some(idp_id) {
		return Err!(Request(Unauthorized(
			"Identity Provider {idp_id:?} session not recognized."
		)));
	}

	if session
		.authorize_expires_at
		.is_some_and(timepoint_has_passed)
	{
		return Err!(Request(Unauthorized("Authorization grant session has expired.")));
	}

	// Optionally bind the callback to the browser that started the login.
	if provider.check_cookie {
		validate_session_cookie(&body.cookie, &provider, &session, sess_id)?;
	}

	let token_response = services
		.oauth
		.request_token((&provider, &session), code)
		.await?;

	let session = apply_token_response(session, token_response)?;

	// For Apple only, fall back to the id_token claims when the userinfo
	// request fails; if that also fails, the original userinfo error is kept.
	let userinfo = services
		.oauth
		.request_userinfo((&provider, &session))
		.await
		.or_else(|error| {
			if provider.brand != "appleoidc" {
				return Err(error);
			}

			debug_warn!(
				?error,
				idp_id = provider.id(),
				"Failed to fetch Apple userinfo endpoint; falling back to id_token claims.",
			);

			decode_apple_userinfo_from_id_token(&session).map_err(|decode_error| {
				debug_warn!(
					?decode_error,
					idp_id = provider.id(),
					"Failed to decode Apple id_token fallback.",
				);
				error
			})
		})?;

	let unique_id = unique_id_sub((&provider, &userinfo.sub))?;

	// Any session previously associated with this provider identity.
	let (old_user_id, old_sess_id) = existing_identity_session(&services, &unique_id).await?;

	let session = Session {
		user_info: Some(userinfo.clone()),
		..session
	};

	// Priority:
	// 1. the user recorded at login start (from a login_token);
	// 2. the identity's prior association;
	// 3. a freshly decided ID.
	let user_id = match (session.user_id, old_user_id) {
		| (Some(user_id), ..) | (None, Some(user_id)) => user_id,
		| (None, None) => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
	};

	let session = Session {
		user_id: Some(user_id.clone()),
		..session
	};

	if !services.users.exists(&user_id).await {
		if !provider.registration {
			return Err!(Request(Forbidden("Registration from this provider is disabled")));
		}

		register_user(&services, &provider, &session, &userinfo, &user_id).await?;
	}

	services.oauth.sessions.put(&session).await;

	// Drop the identity's previous session unless it is this very session.
	if let Some(old_sess_id) = old_sess_id
		.as_deref()
		.filter(is_not_equal_to!(&sess_id))
	{
		services.oauth.sessions.delete(old_sess_id).await;
	}

	// Checked after the session is persisted; a deactivated user still errors.
	if !services.users.is_active_local(&user_id).await {
		return Err!(Request(UserDeactivated("This user has been deactivated.")));
	}

	// Clear the grant cookie now that the grant has been consumed.
	let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
		.removal()
		.build()
		.to_string()
		.into();

	// A `uiaa:` redirect URL means this SSO round-trip completes a UIAA stage.
	if let Some(redirect_url) = session
		.redirect_url
		.as_ref()
		.filter(|url| url.scheme() == "uiaa")
	{
		return handle_uiaa(&services, &user_id, cookie, redirect_url).await;
	}

	let next_idp_url = chain_next_idp_url(&services, &provider, &session, idp_id);

	let location = finalize_login_redirect(&services, &session, next_idp_url, &user_id)?;

	Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
}
488
489fn validate_session_cookie(
490 cookies: &CookieJar,
491 provider: &Provider,
492 session: &Session,
493 sess_id: &str,
494) -> Result {
495 let client_id = &provider.client_id;
496 let cookie = cookies
497 .get(GRANT_SESSION_COOKIE)
498 .map(Cookie::value)
499 .map(serde_html_form::from_str::<GrantCookie<'_>>)
500 .transpose()?
501 .ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
502
503 if cookie.client_id.as_ref() != client_id.as_str() {
504 return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
505 }
506
507 if Some(cookie.nonce.as_ref()) != session.cookie_nonce.as_deref() {
508 return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
509 }
510
511 if cookie.state.as_ref() != sess_id {
512 return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
513 }
514
515 Ok(())
516}
517
518fn apply_token_response(session: Session, token: TokenResponse) -> Result<Session> {
519 let expires_at = token
520 .expires_in
521 .map(Duration::from_secs)
522 .map(timepoint_from_now)
523 .transpose()?;
524
525 let refresh_token_expires_at = token
526 .refresh_token_expires_in
527 .map(Duration::from_secs)
528 .map(timepoint_from_now)
529 .transpose()?;
530
531 Ok(Session {
532 scope: token.scope,
533 token_type: token.token_type,
534 access_token: token.access_token,
535 id_token: token.id_token,
536 expires_at,
537 refresh_token: token.refresh_token,
538 refresh_token_expires_at,
539 ..session
540 })
541}
542
543async fn existing_identity_session(
546 services: &Services,
547 unique_id: &str,
548) -> Result<(Option<OwnedUserId>, Option<String>)> {
549 match services
550 .oauth
551 .sessions
552 .get_by_unique_id(unique_id)
553 .await
554 {
555 | Ok(session) => Ok((session.user_id, session.sess_id)),
556 | Err(error) if !error.is_not_found() => Err(error),
557 | Err(_) => Ok((None, None)),
558 }
559}
560
561fn chain_next_idp_url(
562 services: &Services,
563 provider: &Provider,
564 session: &Session,
565 idp_id: &str,
566) -> Option<Url> {
567 services
568 .config
569 .identity_provider
570 .values()
571 .filter(|idp| idp.default || services.config.single_sso)
572 .skip_while(|idp| idp.id() != idp_id)
573 .nth(1)
574 .map(IdentityProvider::id)
575 .and_then(|next_idp| {
576 provider.callback_url.clone().map(|mut url| {
577 let path = format!("/_matrix/client/v3/login/sso/redirect/{next_idp}");
578 url.set_path(&path);
579
580 if let Some(redirect_url) = session.redirect_url.as_ref() {
581 url.query_pairs_mut()
582 .append_pair("redirectUrl", redirect_url.as_str());
583 }
584
585 url
586 })
587 })
588}
589
590fn finalize_login_redirect(
591 services: &Services,
592 session: &Session,
593 next_idp_url: Option<Url>,
594 user_id: &UserId,
595) -> Result<String> {
596 let login_token = utils::random_string(TOKEN_LENGTH);
597 let _login_token_expires_in = services
598 .users
599 .create_login_token(user_id, &login_token);
600
601 let location = next_idp_url
602 .or_else(|| session.redirect_url.clone())
603 .ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
604 .query_pairs_mut()
605 .append_pair("loginToken", &login_token)
606 .finish()
607 .to_string();
608
609 Ok(location)
610}
611
612async fn handle_uiaa(
613 services: &Services,
614 user_id: &UserId,
615 cookie: Cow<'static, str>,
616 redirect_url: &Url,
617) -> Result<sso_callback::unstable::Response> {
618 let uiaa_session_id = redirect_url.path();
619
620 let (user_id, device_id, mut uiaainfo) = services
623 .uiaa
624 .get_uiaa_session_by_session_id(uiaa_session_id)
625 .await
626 .filter(|(db_user_id, ..)| user_id.eq(db_user_id))
627 .ok_or_else(|| err!(Request(Forbidden("UIAA session not found."))))?;
628
629 let has_oauth_flow = uiaainfo
631 .flows
632 .iter()
633 .any(|f| f.stages.contains(&AuthType::OAuth));
634
635 if has_oauth_flow && !uiaainfo.completed.contains(&AuthType::OAuth) {
637 services
639 .users
640 .allow_cross_signing_replacement(&user_id);
641
642 uiaainfo.completed.push(AuthType::OAuth);
643 }
644
645 let has_sso_flow = uiaainfo
647 .flows
648 .iter()
649 .any(|f| f.stages.contains(&AuthType::Sso));
650
651 if has_sso_flow && !uiaainfo.completed.contains(&AuthType::Sso) {
652 uiaainfo.completed.push(AuthType::Sso);
653 }
654
655 services
656 .uiaa
657 .update_uiaa_session(&user_id, &device_id, uiaa_session_id, Some(&uiaainfo));
658
659 let location =
661 format!("/_matrix/client/v3/auth/m.login.sso/fallback/web?session={uiaa_session_id}");
662
663 Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
664}
665
666#[tracing::instrument(
667 name = "register",
668 level = INFO_SPAN_LEVEL,
669 skip_all,
670 fields(user_id, userinfo)
671)]
672async fn register_user(
673 services: &Services,
674 provider: &Provider,
675 session: &Session,
676 userinfo: &UserInfo,
677 user_id: &UserId,
678) -> Result {
679 debug_info!(%user_id, "Creating new user account...");
680
681 services
682 .users
683 .full_register(Register {
684 user_id: Some(user_id),
685 password: Some(PASSWORD_SENTINEL),
686 origin: Some("sso"),
687 displayname: userinfo.name.as_deref(),
688 grant_first_user_admin: true,
689 ..Default::default()
690 })
691 .await?;
692
693 if let Some(avatar_url) = userinfo
694 .avatar_url
695 .as_deref()
696 .or(userinfo.picture.as_deref())
697 {
698 set_avatar(services, provider, session, userinfo, user_id, avatar_url)
699 .await
700 .ok();
701 }
702
703 let idp_id = provider.id();
704 let idp_name = provider
705 .name
706 .as_deref()
707 .unwrap_or(provider.brand.as_str());
708
709 let notice =
711 format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})");
712
713 info!("{notice}");
714 if services.server.config.admin_room_notices {
715 services.admin.notice(¬ice).await;
716 }
717
718 Ok(())
719}
720
721#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
722async fn set_avatar(
723 services: &Services,
724 _provider: &Provider,
725 _session: &Session,
726 _userinfo: &UserInfo,
727 user_id: &UserId,
728 avatar_url: &str,
729) -> Result {
730 use reqwest::Response;
731
732 let response = services
733 .client
734 .default
735 .get(avatar_url)
736 .send()
737 .await
738 .and_then(Response::error_for_status)?;
739
740 let content_type = response
741 .headers()
742 .get(CONTENT_TYPE)
743 .map(HeaderValue::to_str)
744 .flat_ok()
745 .map(ToOwned::to_owned);
746
747 let mxc = Mxc {
748 server_name: services.globals.server_name(),
749 media_id: &utils::random_string(MXC_LENGTH),
750 };
751
752 let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
753 let limit = services.server.config.max_response_size;
754 let bytes = read_response_capped(response, limit).await?;
755 services
756 .media
757 .create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
758 .await?;
759
760 let mxc_uri: OwnedMxcUri = mxc.to_string().into();
761 services
762 .profile
763 .set_avatar_url(user_id, Some(&mxc_uri), None)
764 .await?;
765
766 Ok(())
767}
768
769#[tracing::instrument(
770 level = "debug",
771 ret(level = "debug")
772 skip_all,
773 fields(user),
774)]
775async fn decide_user_id(
776 services: &Services,
777 provider: &Provider,
778 userinfo: &UserInfo,
779 unique_id: &str,
780) -> Result<OwnedUserId> {
781 if let Some(user_id) = services
782 .oauth
783 .sessions
784 .find_user_association_pending(provider.id(), userinfo)
785 {
786 debug_info!(
787 provider = ?provider.id(),
788 ?user_id,
789 ?userinfo,
790 "Matched pending association"
791 );
792
793 return Ok(user_id);
794 }
795
796 let explicit = |claim: &str| provider.userid_claims.contains(claim);
797
798 let allowed = |claim: &str| provider.userid_claims.is_empty() || explicit(claim);
799
800 let choices = [
801 explicit("sub")
802 .then_some(userinfo.sub.as_str())
803 .map(str::to_lowercase),
804 userinfo
805 .preferred_username
806 .as_deref()
807 .map(str::to_lowercase)
808 .filter(|_| allowed("preferred_username")),
809 userinfo
810 .username
811 .as_deref()
812 .map(str::to_lowercase)
813 .filter(|_| allowed("username")),
814 userinfo
815 .nickname
816 .as_deref()
817 .map(str::to_lowercase)
818 .filter(|_| allowed("nickname")),
819 provider
820 .brand
821 .eq(&"github")
822 .then_some(userinfo.sub.as_str())
823 .map(str::to_lowercase)
824 .filter(|_| allowed("login")),
825 userinfo
826 .email
827 .as_deref()
828 .and_then(|email| email.split_once('@'))
829 .map(at!(0))
830 .map(str::to_lowercase)
831 .filter(|_| allowed("email")),
832 ];
833
834 for choice in choices.into_iter().flatten() {
835 if let Some(user_id) = try_user_id(services, provider, &choice, false).await {
836 return Ok(user_id);
837 }
838 }
839
840 let length = Some(15..23);
841 let unique_id = truncate_deterministic(unique_id, length).to_lowercase();
842 if let Some(user_id) = try_user_id(services, provider, &unique_id, true).await {
843 return Ok(user_id);
844 }
845
846 Err!(Request(UserInUse("User ID is not available.")))
847}
848
849#[tracing::instrument(level = "debug", skip_all, fields(username))]
850async fn try_user_id(
851 services: &Services,
852 provider: &Provider,
853 username: &str,
854 unique_id: bool,
855) -> Option<OwnedUserId> {
856 let server_name = services.globals.server_name();
857 let user_id = parse_user_id(server_name, username)
858 .inspect_err(|e| warn!(?username, "Username invalid: {e}"))
859 .ok()?;
860
861 if services
862 .config
863 .forbidden_usernames
864 .is_match(username)
865 {
866 warn!(?username, "Username forbidden.");
867 return None;
868 }
869
870 if services.users.exists(&user_id).await {
871 if provider.trusted {
872 info!(
873 ?username,
874 provider = ?provider.brand,
875 "Authorizing trusted provider access to existing account."
876 );
877
878 return Some(user_id);
879 }
880
881 if services
882 .users
883 .origin(&user_id)
884 .await
885 .ok()
886 .is_none_or(|origin| origin != "sso")
887 {
888 debug_warn!(?username, "Existing username has non-sso origin.");
889 return None;
890 }
891
892 if !unique_id {
893 debug_warn!(?username, "Username exists.");
894 return None;
895 }
896 } else if unique_id && !provider.unique_id_fallbacks {
897 debug_warn!(
898 ?username,
899 provider = ?provider.brand,
900 "Unique ID fallbacks disabled.",
901 );
902
903 return None;
904 }
905
906 Some(user_id)
907}
908
909fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
910 match UserId::parse_with_server_name(username, server_name) {
911 | Err(e) => {
912 Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}"))))
913 },
914 | Ok(user_id) => match user_id.validate_strict() {
915 | Ok(()) => Ok(user_id),
916 | Err(e) => Err!(Request(InvalidUsername(debug_error!(
917 "Username {username} contains disallowed characters or spaces: {e}"
918 )))),
919 },
920 }
921}
922
#[cfg(test)]
mod tests {
	use serde_json::json;

	use super::*;

	/// Wraps `claims` as the payload segment of a fake (unsigned) JWT stored in
	/// an otherwise default session.
	fn apple_session_with_claims(claims: &serde_json::Value) -> Session {
		let claims_json = serde_json::to_vec(claims).expect("serialize claims");
		let id_token = format!("header.{}.signature", b64.encode(claims_json));

		Session { id_token: Some(id_token), ..Default::default() }
	}

	/// Decodes `session`, asserting failure, and returns the error's text.
	fn decode_failure(session: &Session, why: &str) -> String {
		decode_apple_userinfo_from_id_token(session)
			.expect_err(why)
			.to_string()
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_extracts_expected_claims() {
		let session = apple_session_with_claims(&json!({
			"sub": "apple-user-123",
			"email": "alice@example.com",
			"name": "Alice Example",
			"given_name": "Alice",
			"family_name": "Example"
		}));

		let userinfo =
			decode_apple_userinfo_from_id_token(&session).expect("decode Apple id_token claims");

		assert_eq!(userinfo.sub, "apple-user-123");
		assert_eq!(userinfo.email.as_deref(), Some("alice@example.com"));
		assert_eq!(userinfo.preferred_username.as_deref(), Some("alice"));
		assert_eq!(userinfo.username.as_deref(), Some("alice"));
		assert_eq!(userinfo.name.as_deref(), Some("Alice Example"));
		assert_eq!(userinfo.given_name.as_deref(), Some("Alice"));
		assert_eq!(userinfo.family_name.as_deref(), Some("Example"));
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_requires_sub_claim() {
		let session = apple_session_with_claims(&json!({
			"email": "alice@example.com"
		}));

		let message = decode_failure(&session, "missing sub claim should fail");
		assert!(message.contains("sub claim"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_requires_id_token() {
		let message = decode_failure(&Session::default(), "missing id_token should fail");
		assert!(message.contains("Missing Apple id_token"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_rejects_invalid_payload() {
		let session = Session {
			id_token: Some("header.!.signature".to_owned()),
			..Default::default()
		};

		let message = decode_failure(&session, "invalid id_token payload should fail");
		assert!(message.contains("invalid base64"), "unexpected error: {message}");
	}
}