Skip to main content

tuwunel_api/client/session/
sso.rs

1mod uiaa;
2
3use std::{borrow::Cow, collections::BTreeMap, net::IpAddr, time::Duration};
4
5use axum::extract::State;
6use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
8use futures::{FutureExt, TryFutureExt, future::try_join};
9use reqwest::header::{CONTENT_TYPE, HeaderValue};
10use ruma::{
11	Mxc, OwnedMxcUri, OwnedUserId, ServerName, UserId,
12	api::client::{
13		session::{SsoRedirectAction, sso_callback, sso_login, sso_login_with_provider},
14		uiaa::AuthType,
15	},
16};
17use serde::{Deserialize, Serialize};
18use serde_json::Value as JsonValue;
19use tuwunel_core::{
20	Err, Result, at,
21	config::IdentityProvider,
22	debug::INFO_SPAN_LEVEL,
23	debug_info, debug_warn, err, info, is_not_equal_to,
24	itertools::Itertools,
25	utils,
26	utils::{
27		OptionExt,
28		content_disposition::make_content_disposition,
29		hash::sha256,
30		result::{FlatOk, LogErr},
31		string::{EMPTY, truncate_deterministic},
32		timepoint_from_now, timepoint_has_passed,
33	},
34	warn,
35};
36use tuwunel_service::{
37	Services,
38	client::read_response_capped,
39	media::MXC_LENGTH,
40	oauth::{
41		CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, TokenResponse, UserInfo,
42		unique_id_sub,
43	},
44	users::{PASSWORD_SENTINEL, Register},
45};
46use url::Url;
47
48pub(crate) use self::uiaa::sso_fallback_route;
49use super::TOKEN_LENGTH;
50use crate::{ClientIp, Ruma};
51
/// Grant phase query string.
///
/// Serialized with `serde_html_form` into the query component of the
/// provider's authorization endpoint by `handle_sso_login`. Field names are the
/// wire parameter names and declaration order is the serialization order, so
/// neither should be changed casually.
#[derive(Debug, Serialize)]
struct GrantQuery<'a> {
	/// OAuth client identifier configured for the provider.
	client_id: &'a str,
	/// Opaque value echoed back on the callback; `handle_sso_login` sets it to
	/// the session ID.
	state: &'a str,
	/// Random nonce sent to the provider (stored in the session as
	/// `query_nonce`).
	nonce: &'a str,
	/// Space-separated scopes requested from the provider.
	scope: &'a str,
	/// `handle_sso_login` always sends `"code"` (authorization code grant).
	response_type: &'a str,
	/// `handle_sso_login` always sends `"online"`.
	access_type: &'a str,
	/// PKCE challenge method; `handle_sso_login` always sends `"S256"`.
	code_challenge_method: &'a str,
	/// PKCE challenge: unpadded base64url of the SHA-256 of the code verifier.
	code_challenge: &'a str,
	/// Callback URL configured for the provider, if any.
	// NOTE(review): unlike `prompt` this has no `skip_serializing_if`; whether a
	// `None` is omitted relies on the serializer's `Option` handling — confirm.
	redirect_uri: Option<&'a str>,
	/// Optional `prompt` parameter; omitted from the query when `None`.
	#[serde(skip_serializing_if = "Option::is_none")]
	prompt: Option<&'a str>,
}
67
/// Grant state carried in the `GRANT_SESSION_COOKIE` cookie.
///
/// Serialized with `serde_html_form` when the grant begins and deserialized in
/// `validate_session_cookie` on the callback, where `client_id`, `nonce` and
/// `state` are compared against the provider and the stored session. `Cow`
/// allows both borrowed construction and deserialization.
#[derive(Debug, Deserialize, Serialize)]
struct GrantCookie<'a> {
	/// Provider client ID the grant was started for.
	client_id: Cow<'a, str>,
	/// Session ID; the same value sent to the provider as `state`.
	state: Cow<'a, str>,
	/// Cookie nonce; distinct from the nonce sent to the provider.
	nonce: Cow<'a, str>,
	/// Client redirect URL supplied to the login endpoint.
	// NOTE(review): not compared by `validate_session_cookie`.
	redirect_uri: Cow<'a, str>,
}
75
/// Name of the cookie that carries the serialized `GrantCookie` from the
/// redirect phase of an SSO login to its callback phase.
const GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
77
78fn decode_apple_userinfo_from_id_token(session: &Session) -> Result<UserInfo> {
79	let id_token = session.id_token.as_deref().ok_or_else(|| {
80		err!(Request(Unauthorized("Missing Apple id_token in token response.")))
81	})?;
82
83	let payload_b64 = id_token
84		.split('.')
85		.nth(1)
86		.ok_or_else(|| err!(Request(Unauthorized("Apple id_token is malformed."))))?;
87
88	let payload = b64
89		.decode(payload_b64)
90		.map_err(|_| err!(Request(Unauthorized("Apple id_token payload is invalid base64."))))?;
91
92	let payload: JsonValue = serde_json::from_slice(&payload)
93		.map_err(|_| err!(Request(Unauthorized("Apple id_token payload is not valid JSON."))))?;
94
95	let sub = payload
96		.get("sub")
97		.and_then(JsonValue::as_str)
98		.ok_or_else(|| {
99			err!(Request(Unauthorized("Apple id_token missing required sub claim.")))
100		})?;
101
102	let email = payload
103		.get("email")
104		.and_then(JsonValue::as_str)
105		.map(ToOwned::to_owned);
106
107	let preferred_username = email
108		.as_deref()
109		.and_then(|value| value.split_once('@'))
110		.map(at!(0))
111		.map(ToOwned::to_owned);
112
113	Ok(UserInfo {
114		sub: sub.to_owned(),
115		preferred_username: preferred_username.clone(),
116		username: preferred_username,
117		nickname: None,
118		name: payload
119			.get("name")
120			.and_then(JsonValue::as_str)
121			.map(ToOwned::to_owned),
122		given_name: payload
123			.get("given_name")
124			.and_then(JsonValue::as_str)
125			.map(ToOwned::to_owned),
126		family_name: payload
127			.get("family_name")
128			.and_then(JsonValue::as_str)
129			.map(ToOwned::to_owned),
130		email,
131		avatar_url: None,
132		picture: None,
133	})
134}
135
136/// # `GET /_matrix/client/v3/login/sso/redirect`
137///
138/// A web-based Matrix client should instruct the user’s browser to navigate to
139/// this endpoint in order to log in via SSO.
140#[tracing::instrument(
141	name = "sso_login",
142	level = "debug",
143	skip_all,
144	fields(%client),
145)]
146pub(crate) async fn sso_login_route(
147	State(services): State<crate::State>,
148	ClientIp(client): ClientIp,
149	body: Ruma<sso_login::v3::Request>,
150) -> Result<sso_login::v3::Response> {
151	if services.config.sso_custom_providers_page {
152		return Err!(Request(NotImplemented(
153			"sso_custom_providers_page has been enabled but this URL has not been overridden \
154			 with any custom page listing the available providers..."
155		)));
156	}
157
158	let redirect_url = body.body.redirect_url;
159	let action = body.body.action;
160	let default_idp_id = services
161		.oauth
162		.providers
163		.get_default_id()
164		.unwrap_or_default();
165
166	handle_sso_login(&services, &client, default_idp_id, redirect_url, None, action)
167		.map_ok(|response| sso_login::v3::Response {
168			location: response.location,
169			cookie: response.cookie,
170		})
171		.await
172}
173
174/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
175///
176/// This endpoint is the same as /login/sso/redirect, though with an IdP ID from
177/// the original identity_providers array to inform the server of which IdP the
178/// client/user would like to continue with.
179#[tracing::instrument(
180	name = "sso_login_with_provider",
181	level = "info",
182	skip_all,
183	ret(level = "debug")
184	fields(
185		%client,
186		idp_id = body.body.idp_id,
187	),
188)]
189pub(crate) async fn sso_login_with_provider_route(
190	State(services): State<crate::State>,
191	ClientIp(client): ClientIp,
192	body: Ruma<sso_login_with_provider::v3::Request>,
193) -> Result<sso_login_with_provider::v3::Response> {
194	let idp_id = body.body.idp_id;
195	let redirect_url = body.body.redirect_url;
196	let login_token = body.body.login_token;
197	let action = body.body.action;
198
199	handle_sso_login(&services, &client, idp_id, redirect_url, login_token, action).await
200}
201
202async fn handle_sso_login(
203	services: &Services,
204	_client: &IpAddr,
205	idp_id: String,
206	redirect_url: String,
207	login_token: Option<String>,
208	action: Option<SsoRedirectAction>,
209) -> Result<sso_login_with_provider::v3::Response> {
210	let redirect_url: Url = redirect_url.parse().map_err(|e| {
211		err!(Request(InvalidParam(debug_warn!(
212			?e,
213			?redirect_url,
214			"Failed to parse redirect_url.",
215		))))
216	})?;
217
218	let provider = services.oauth.providers.get(&idp_id).await?;
219	let sess_id = utils::random_string(SESSION_ID_LENGTH);
220	let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
221	let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
222	let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
223	let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
224	let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
225	let scope = provider.scope.iter().join(" ");
226	let prompt = action
227		.filter(|_| provider.forward_action_prompt)
228		.and_then(|action| matches!(action, SsoRedirectAction::Register).then_some("create"));
229
230	let query = GrantQuery {
231		client_id: &provider.client_id,
232		state: &sess_id,
233		nonce: &query_nonce,
234		access_type: "online",
235		response_type: "code",
236		code_challenge_method: "S256",
237		code_challenge: &code_challenge,
238		redirect_uri: callback_uri,
239		prompt,
240		scope: scope
241			.is_empty()
242			.then_some("openid email profile")
243			.unwrap_or(scope.as_str()),
244	};
245
246	let location = provider
247		.authorization_url
248		.clone()
249		.map(|mut location| {
250			let query = serde_html_form::to_string(&query).ok();
251			location.set_query(query.as_deref());
252			if !provider.extra_authorization_parameters.is_empty() {
253				// Base wins on key collision so extras cannot disable CSRF/PKCE.
254				let merged: BTreeMap<String, String> = provider
255					.extra_authorization_parameters
256					.clone()
257					.into_iter()
258					.chain(
259						location
260							.query_pairs()
261							.map(|(k, v)| (k.into_owned(), v.into_owned())),
262					)
263					.collect();
264
265				location.set_query(None);
266				location.query_pairs_mut().extend_pairs(&merged);
267			}
268			location
269		})
270		.ok_or_else(|| {
271			err!(Config("authorization_url", "Missing required IdentityProvider config"))
272		})?;
273
274	let cookie_val = GrantCookie {
275		client_id: query.client_id.into(),
276		state: query.state.into(),
277		nonce: cookie_nonce.as_str().into(),
278		redirect_uri: redirect_url.as_str().into(),
279	};
280
281	let cookie_path = provider
282		.callback_url
283		.as_ref()
284		.map(Url::path)
285		.unwrap_or("/");
286
287	let cookie_max_age = provider
288		.grant_session_duration
289		.map(Duration::from_secs)
290		.expect("Defaulted to Some value during configure_idp()")
291		.try_into()
292		.expect("std::time::Duration to time::Duration conversion failure");
293
294	let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
295		.path(cookie_path)
296		.max_age(cookie_max_age)
297		.same_site(SameSite::None)
298		.secure(true)
299		.http_only(true)
300		.build()
301		.to_string()
302		.into();
303
304	let session = Session {
305		idp_id: Some(idp_id),
306		sess_id: Some(sess_id.clone()),
307		redirect_url: Some(redirect_url),
308		code_verifier: Some(code_verifier),
309		query_nonce: Some(query_nonce),
310		cookie_nonce: Some(cookie_nonce),
311		authorize_expires_at: provider
312			.grant_session_duration
313			.map(Duration::from_secs)
314			.map(timepoint_from_now)
315			.transpose()?,
316
317		user_id: login_token
318			.as_deref()
319			.map_async(|token| services.users.find_from_login_token(token))
320			.map(FlatOk::flat_ok)
321			.await,
322
323		..Default::default()
324	};
325
326	services.oauth.sessions.put(&session).await;
327
328	Ok(sso_login_with_provider::v3::Response {
329		location: location.into(),
330		cookie: Some(cookie),
331	})
332}
333
/// Handle the identity provider's redirect back to this server after the user
/// has authorized the grant started by `handle_sso_login`.
///
/// The callback is first checked against the stored grant session: session ID,
/// provider, expiry and, when `check_cookie` is set, the grant cookie. Then:
///
/// - `code` is exchanged for tokens.
/// - The user's claims are fetched. For the `appleoidc` brand a failed
///   userinfo request falls back to the `id_token`.
/// - The local account is resolved, or registered if it does not exist, and
///   the session is persisted.
/// - The browser is redirected. `uiaa:` redirect URLs go to the UIAA fallback
///   page; otherwise the target is the next chained provider or the client's
///   redirect URL with a `loginToken`.
///
/// The grant cookie is cleared in the response.
#[tracing::instrument(
	name = "sso_callback"
	level = "debug",
	skip_all,
	fields(
		%client,
		cookie = ?body.cookie,
		body = ?body.body,
	),
)]
pub(crate) async fn sso_callback_route(
	State(services): State<crate::State>,
	ClientIp(client): ClientIp,
	body: Ruma<sso_callback::unstable::Request>,
) -> Result<sso_callback::unstable::Response> {
	// The `state` parameter carries the session ID generated when the grant began.
	let sess_id = body
		.body
		.state
		.as_deref()
		.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;

	let code = body
		.body
		.code
		.as_deref()
		.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;

	// Load the grant session and the provider concurrently; an unknown session
	// is reported as an invalid `state`.
	let session = services
		.oauth
		.sessions
		.get(sess_id)
		.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));

	let provider = services
		.oauth
		.providers
		.get(body.body.idp_id.as_str());

	let (provider, session) = try_join(provider, session).await.log_err()?;
	let idp_id = provider.id();

	// Defensive checks that the stored session really belongs to this callback.
	if session.sess_id.as_deref() != Some(sess_id) {
		return Err!(Request(Unauthorized("Session ID {sess_id:?} not recognized.")));
	}

	if session.idp_id.as_deref() != Some(idp_id) {
		return Err!(Request(Unauthorized(
			"Identity Provider {idp_id:?} session not recognized."
		)));
	}

	if session
		.authorize_expires_at
		.is_some_and(timepoint_has_passed)
	{
		return Err!(Request(Unauthorized("Authorization grant session has expired.")));
	}

	if provider.check_cookie {
		validate_session_cookie(&body.cookie, &provider, &session, sess_id)?;
	}

	let token_response = services
		.oauth
		.request_token((&provider, &session), code)
		.await?;

	let session = apply_token_response(session, token_response)?;

	// For the `appleoidc` brand a failed userinfo request falls back to the
	// `id_token` claims; if that also fails the original error is returned.
	let userinfo = services
		.oauth
		.request_userinfo((&provider, &session))
		.await
		.or_else(|error| {
			if provider.brand != "appleoidc" {
				return Err(error);
			}

			debug_warn!(
				?error,
				idp_id = provider.id(),
				"Failed to fetch Apple userinfo endpoint; falling back to id_token claims.",
			);

			decode_apple_userinfo_from_id_token(&session).map_err(|decode_error| {
				debug_warn!(
					?decode_error,
					idp_id = provider.id(),
					"Failed to decode Apple id_token fallback.",
				);
				error
			})
		})?;

	let unique_id = unique_id_sub((&provider, &userinfo.sub))?;

	// Any earlier session for this upstream identity: its user association is
	// reused below and the session itself is superseded by this one.
	let (old_user_id, old_sess_id) = existing_identity_session(&services, &unique_id).await?;

	let session = Session {
		user_info: Some(userinfo.clone()),
		..session
	};

	// Precedence: the session's own `user_id` (set from a login token when the
	// grant began), then the user previously bound to this identity, otherwise
	// a newly chosen user ID.
	let user_id = match (session.user_id, old_user_id) {
		| (Some(user_id), ..) | (None, Some(user_id)) => user_id,
		| (None, None) => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
	};

	let session = Session {
		user_id: Some(user_id.clone()),
		..session
	};

	// Unknown accounts are created only when the provider allows registration.
	if !services.users.exists(&user_id).await {
		if !provider.registration {
			return Err!(Request(Forbidden("Registration from this provider is disabled")));
		}

		register_user(&services, &provider, &session, &userinfo, &user_id).await?;
	}

	services.oauth.sessions.put(&session).await;

	// Remove the superseded session so one session remains per identity.
	if let Some(old_sess_id) = old_sess_id
		.as_deref()
		.filter(is_not_equal_to!(&sess_id))
	{
		services.oauth.sessions.delete(old_sess_id).await;
	}

	// NOTE(review): this runs after the session has been persisted above —
	// confirm that ordering is intended for deactivated accounts.
	if !services.users.is_active_local(&user_id).await {
		return Err!(Request(UserDeactivated("This user has been deactivated.")));
	}

	// Expire the grant cookie now that the callback has consumed it.
	let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
		.removal()
		.build()
		.to_string()
		.into();

	// `uiaa:` redirect URLs complete a UIAA stage instead of logging in.
	if let Some(redirect_url) = session
		.redirect_url
		.as_ref()
		.filter(|url| url.scheme() == "uiaa")
	{
		return handle_uiaa(&services, &user_id, cookie, redirect_url).await;
	}

	let next_idp_url = chain_next_idp_url(&services, &provider, &session, idp_id);

	let location = finalize_login_redirect(&services, &session, next_idp_url, &user_id)?;

	Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
}
488
489fn validate_session_cookie(
490	cookies: &CookieJar,
491	provider: &Provider,
492	session: &Session,
493	sess_id: &str,
494) -> Result {
495	let client_id = &provider.client_id;
496	let cookie = cookies
497		.get(GRANT_SESSION_COOKIE)
498		.map(Cookie::value)
499		.map(serde_html_form::from_str::<GrantCookie<'_>>)
500		.transpose()?
501		.ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
502
503	if cookie.client_id.as_ref() != client_id.as_str() {
504		return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
505	}
506
507	if Some(cookie.nonce.as_ref()) != session.cookie_nonce.as_deref() {
508		return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
509	}
510
511	if cookie.state.as_ref() != sess_id {
512		return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
513	}
514
515	Ok(())
516}
517
518fn apply_token_response(session: Session, token: TokenResponse) -> Result<Session> {
519	let expires_at = token
520		.expires_in
521		.map(Duration::from_secs)
522		.map(timepoint_from_now)
523		.transpose()?;
524
525	let refresh_token_expires_at = token
526		.refresh_token_expires_in
527		.map(Duration::from_secs)
528		.map(timepoint_from_now)
529		.transpose()?;
530
531	Ok(Session {
532		scope: token.scope,
533		token_type: token.token_type,
534		access_token: token.access_token,
535		id_token: token.id_token,
536		expires_at,
537		refresh_token: token.refresh_token,
538		refresh_token_expires_at,
539		..session
540	})
541}
542
543/// Locate any prior session bound to the same upstream identity, to preserve
544/// one session and its `user_id` association per identity.
545async fn existing_identity_session(
546	services: &Services,
547	unique_id: &str,
548) -> Result<(Option<OwnedUserId>, Option<String>)> {
549	match services
550		.oauth
551		.sessions
552		.get_by_unique_id(unique_id)
553		.await
554	{
555		| Ok(session) => Ok((session.user_id, session.sess_id)),
556		| Err(error) if !error.is_not_found() => Err(error),
557		| Err(_) => Ok((None, None)),
558	}
559}
560
561fn chain_next_idp_url(
562	services: &Services,
563	provider: &Provider,
564	session: &Session,
565	idp_id: &str,
566) -> Option<Url> {
567	services
568		.config
569		.identity_provider
570		.values()
571		.filter(|idp| idp.default || services.config.single_sso)
572		.skip_while(|idp| idp.id() != idp_id)
573		.nth(1)
574		.map(IdentityProvider::id)
575		.and_then(|next_idp| {
576			provider.callback_url.clone().map(|mut url| {
577				let path = format!("/_matrix/client/v3/login/sso/redirect/{next_idp}");
578				url.set_path(&path);
579
580				if let Some(redirect_url) = session.redirect_url.as_ref() {
581					url.query_pairs_mut()
582						.append_pair("redirectUrl", redirect_url.as_str());
583				}
584
585				url
586			})
587		})
588}
589
590fn finalize_login_redirect(
591	services: &Services,
592	session: &Session,
593	next_idp_url: Option<Url>,
594	user_id: &UserId,
595) -> Result<String> {
596	let login_token = utils::random_string(TOKEN_LENGTH);
597	let _login_token_expires_in = services
598		.users
599		.create_login_token(user_id, &login_token);
600
601	let location = next_idp_url
602		.or_else(|| session.redirect_url.clone())
603		.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
604		.query_pairs_mut()
605		.append_pair("loginToken", &login_token)
606		.finish()
607		.to_string();
608
609	Ok(location)
610}
611
612async fn handle_uiaa(
613	services: &Services,
614	user_id: &UserId,
615	cookie: Cow<'static, str>,
616	redirect_url: &Url,
617) -> Result<sso_callback::unstable::Response> {
618	let uiaa_session_id = redirect_url.path();
619
620	// Find the UIAA session by its ID. SECURITY: Ensure the user authenticating via
621	// SSO is the owner of the UIAA session
622	let (user_id, device_id, mut uiaainfo) = services
623		.uiaa
624		.get_uiaa_session_by_session_id(uiaa_session_id)
625		.await
626		.filter(|(db_user_id, ..)| user_id.eq(db_user_id))
627		.ok_or_else(|| err!(Request(Forbidden("UIAA session not found."))))?;
628
629	// MSC4312 m.oauth flow → mark OAuth.
630	let has_oauth_flow = uiaainfo
631		.flows
632		.iter()
633		.any(|f| f.stages.contains(&AuthType::OAuth));
634
635	// Mark the completed step based on the UIAA session's flow.
636	if has_oauth_flow && !uiaainfo.completed.contains(&AuthType::OAuth) {
637		// Grant 10-minute bypass for cross-signing key replacement (like Synapse).
638		services
639			.users
640			.allow_cross_signing_replacement(&user_id);
641
642		uiaainfo.completed.push(AuthType::OAuth);
643	}
644
645	// Legacy m.login.sso flow → mark Sso.
646	let has_sso_flow = uiaainfo
647		.flows
648		.iter()
649		.any(|f| f.stages.contains(&AuthType::Sso));
650
651	if has_sso_flow && !uiaainfo.completed.contains(&AuthType::Sso) {
652		uiaainfo.completed.push(AuthType::Sso);
653	}
654
655	services
656		.uiaa
657		.update_uiaa_session(&user_id, &device_id, uiaa_session_id, Some(&uiaainfo));
658
659	// Redirect back to the fallback page to render the success HTML
660	let location =
661		format!("/_matrix/client/v3/auth/m.login.sso/fallback/web?session={uiaa_session_id}");
662
663	Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
664}
665
666#[tracing::instrument(
667	name = "register",
668	level = INFO_SPAN_LEVEL,
669	skip_all,
670	fields(user_id, userinfo)
671)]
672async fn register_user(
673	services: &Services,
674	provider: &Provider,
675	session: &Session,
676	userinfo: &UserInfo,
677	user_id: &UserId,
678) -> Result {
679	debug_info!(%user_id, "Creating new user account...");
680
681	services
682		.users
683		.full_register(Register {
684			user_id: Some(user_id),
685			password: Some(PASSWORD_SENTINEL),
686			origin: Some("sso"),
687			displayname: userinfo.name.as_deref(),
688			grant_first_user_admin: true,
689			..Default::default()
690		})
691		.await?;
692
693	if let Some(avatar_url) = userinfo
694		.avatar_url
695		.as_deref()
696		.or(userinfo.picture.as_deref())
697	{
698		set_avatar(services, provider, session, userinfo, user_id, avatar_url)
699			.await
700			.ok();
701	}
702
703	let idp_id = provider.id();
704	let idp_name = provider
705		.name
706		.as_deref()
707		.unwrap_or(provider.brand.as_str());
708
709	// log in conduit admin channel if a non-guest user registered
710	let notice =
711		format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})");
712
713	info!("{notice}");
714	if services.server.config.admin_room_notices {
715		services.admin.notice(&notice).await;
716	}
717
718	Ok(())
719}
720
721#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
722async fn set_avatar(
723	services: &Services,
724	_provider: &Provider,
725	_session: &Session,
726	_userinfo: &UserInfo,
727	user_id: &UserId,
728	avatar_url: &str,
729) -> Result {
730	use reqwest::Response;
731
732	let response = services
733		.client
734		.default
735		.get(avatar_url)
736		.send()
737		.await
738		.and_then(Response::error_for_status)?;
739
740	let content_type = response
741		.headers()
742		.get(CONTENT_TYPE)
743		.map(HeaderValue::to_str)
744		.flat_ok()
745		.map(ToOwned::to_owned);
746
747	let mxc = Mxc {
748		server_name: services.globals.server_name(),
749		media_id: &utils::random_string(MXC_LENGTH),
750	};
751
752	let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
753	let limit = services.server.config.max_response_size;
754	let bytes = read_response_capped(response, limit).await?;
755	services
756		.media
757		.create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
758		.await?;
759
760	let mxc_uri: OwnedMxcUri = mxc.to_string().into();
761	services
762		.profile
763		.set_avatar_url(user_id, Some(&mxc_uri), None)
764		.await?;
765
766	Ok(())
767}
768
769#[tracing::instrument(
770	level = "debug",
771	ret(level = "debug")
772	skip_all,
773	fields(user),
774)]
775async fn decide_user_id(
776	services: &Services,
777	provider: &Provider,
778	userinfo: &UserInfo,
779	unique_id: &str,
780) -> Result<OwnedUserId> {
781	if let Some(user_id) = services
782		.oauth
783		.sessions
784		.find_user_association_pending(provider.id(), userinfo)
785	{
786		debug_info!(
787			provider = ?provider.id(),
788			?user_id,
789			?userinfo,
790			"Matched pending association"
791		);
792
793		return Ok(user_id);
794	}
795
796	let explicit = |claim: &str| provider.userid_claims.contains(claim);
797
798	let allowed = |claim: &str| provider.userid_claims.is_empty() || explicit(claim);
799
800	let choices = [
801		explicit("sub")
802			.then_some(userinfo.sub.as_str())
803			.map(str::to_lowercase),
804		userinfo
805			.preferred_username
806			.as_deref()
807			.map(str::to_lowercase)
808			.filter(|_| allowed("preferred_username")),
809		userinfo
810			.username
811			.as_deref()
812			.map(str::to_lowercase)
813			.filter(|_| allowed("username")),
814		userinfo
815			.nickname
816			.as_deref()
817			.map(str::to_lowercase)
818			.filter(|_| allowed("nickname")),
819		provider
820			.brand
821			.eq(&"github")
822			.then_some(userinfo.sub.as_str())
823			.map(str::to_lowercase)
824			.filter(|_| allowed("login")),
825		userinfo
826			.email
827			.as_deref()
828			.and_then(|email| email.split_once('@'))
829			.map(at!(0))
830			.map(str::to_lowercase)
831			.filter(|_| allowed("email")),
832	];
833
834	for choice in choices.into_iter().flatten() {
835		if let Some(user_id) = try_user_id(services, provider, &choice, false).await {
836			return Ok(user_id);
837		}
838	}
839
840	let length = Some(15..23);
841	let unique_id = truncate_deterministic(unique_id, length).to_lowercase();
842	if let Some(user_id) = try_user_id(services, provider, &unique_id, true).await {
843		return Ok(user_id);
844	}
845
846	Err!(Request(UserInUse("User ID is not available.")))
847}
848
849#[tracing::instrument(level = "debug", skip_all, fields(username))]
850async fn try_user_id(
851	services: &Services,
852	provider: &Provider,
853	username: &str,
854	unique_id: bool,
855) -> Option<OwnedUserId> {
856	let server_name = services.globals.server_name();
857	let user_id = parse_user_id(server_name, username)
858		.inspect_err(|e| warn!(?username, "Username invalid: {e}"))
859		.ok()?;
860
861	if services
862		.config
863		.forbidden_usernames
864		.is_match(username)
865	{
866		warn!(?username, "Username forbidden.");
867		return None;
868	}
869
870	if services.users.exists(&user_id).await {
871		if provider.trusted {
872			info!(
873				?username,
874				provider = ?provider.brand,
875				"Authorizing trusted provider access to existing account."
876			);
877
878			return Some(user_id);
879		}
880
881		if services
882			.users
883			.origin(&user_id)
884			.await
885			.ok()
886			.is_none_or(|origin| origin != "sso")
887		{
888			debug_warn!(?username, "Existing username has non-sso origin.");
889			return None;
890		}
891
892		if !unique_id {
893			debug_warn!(?username, "Username exists.");
894			return None;
895		}
896	} else if unique_id && !provider.unique_id_fallbacks {
897		debug_warn!(
898			?username,
899			provider = ?provider.brand,
900			"Unique ID fallbacks disabled.",
901		);
902
903		return None;
904	}
905
906	Some(user_id)
907}
908
909fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
910	match UserId::parse_with_server_name(username, server_name) {
911		| Err(e) => {
912			Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}"))))
913		},
914		| Ok(user_id) => match user_id.validate_strict() {
915			| Ok(()) => Ok(user_id),
916			| Err(e) => Err!(Request(InvalidUsername(debug_error!(
917				"Username {username} contains disallowed characters or spaces: {e}"
918			)))),
919		},
920	}
921}
922
#[cfg(test)]
mod tests {
	use serde_json::json;

	use super::*;

	/// Build a session whose `id_token` carries `claims` as an unsigned,
	/// base64url-encoded payload between dummy header and signature segments.
	fn apple_session_with_claims(claims: &serde_json::Value) -> Session {
		let payload = b64.encode(serde_json::to_vec(claims).expect("serialize claims"));

		Session {
			id_token: Some(format!("header.{payload}.signature")),
			..Default::default()
		}
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_extracts_expected_claims() {
		let session = apple_session_with_claims(&json!({
			"sub": "apple-user-123",
			"email": "alice@example.com",
			"name": "Alice Example",
			"given_name": "Alice",
			"family_name": "Example"
		}));

		let userinfo =
			decode_apple_userinfo_from_id_token(&session).expect("decode Apple id_token claims");

		assert_eq!(userinfo.sub, "apple-user-123");
		assert_eq!(userinfo.email.as_deref(), Some("alice@example.com"));
		assert_eq!(userinfo.preferred_username.as_deref(), Some("alice"));
		assert_eq!(userinfo.username.as_deref(), Some("alice"));
		assert_eq!(userinfo.name.as_deref(), Some("Alice Example"));
		assert_eq!(userinfo.given_name.as_deref(), Some("Alice"));
		assert_eq!(userinfo.family_name.as_deref(), Some("Example"));
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_requires_sub_claim() {
		let session = apple_session_with_claims(&json!({
			"email": "alice@example.com"
		}));

		let error = decode_apple_userinfo_from_id_token(&session)
			.expect_err("missing sub claim should fail");

		let message = format!("{error}");
		assert!(message.contains("sub claim"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_requires_id_token() {
		let session = Session::default();

		let error = decode_apple_userinfo_from_id_token(&session)
			.expect_err("missing id_token should fail");

		let message = format!("{error}");
		assert!(message.contains("Missing Apple id_token"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_rejects_invalid_payload() {
		let session = Session {
			id_token: Some("header.!.signature".to_owned()),
			..Default::default()
		};

		let error = decode_apple_userinfo_from_id_token(&session)
			.expect_err("invalid id_token payload should fail");

		let message = format!("{error}");
		assert!(message.contains("invalid base64"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_rejects_non_json_payload() {
		let payload = b64.encode(b"not json");
		let session = Session {
			id_token: Some(format!("header.{payload}.signature")),
			..Default::default()
		};

		let error = decode_apple_userinfo_from_id_token(&session)
			.expect_err("non-JSON id_token payload should fail");

		let message = format!("{error}");
		assert!(message.contains("not valid JSON"), "unexpected error: {message}");
	}

	#[test]
	fn decode_apple_userinfo_from_id_token_without_email_has_no_username() {
		let session = apple_session_with_claims(&json!({
			"sub": "apple-user-123"
		}));

		let userinfo = decode_apple_userinfo_from_id_token(&session)
			.expect("decode Apple id_token without email");

		assert_eq!(userinfo.sub, "apple-user-123");
		assert!(userinfo.email.is_none());
		assert!(userinfo.preferred_username.is_none());
		assert!(userinfo.username.is_none());
	}
}