1use std::{net::IpAddr, time::SystemTime};
2
3use axum::{
4 extract::State,
5 response::{IntoResponse, Redirect},
6};
7use serde::Deserialize;
8use tuwunel_core::{
9 Err, Result, err, utils,
10 utils::{BoolExt, result::FlatOk},
11};
12use tuwunel_service::{
13 Services,
14 oauth::server::{AUTH_REQUEST_LIFETIME, AuthRequest},
15};
16use url::Url;
17
18use super::{OIDC_REQ_ID_LENGTH, url_encode};
19use crate::ClientIp;
20
/// Query parameters accepted by the authorization endpoint.
///
/// Deserialized from the raw request query string with `serde_html_form`.
#[derive(Debug, Deserialize)]
pub(crate) struct AuthorizeParams {
	/// Client identifier; used to look up the client's registered redirect URIs.
	client_id: String,
	/// Where the user is sent afterwards; must match a URI registered for the client.
	redirect_uri: String,
	/// Requested response type; only `code` is accepted.
	response_type: String,
	/// `query` (assumed when absent) or `fragment`; other values are rejected.
	response_mode: Option<String>,
	/// Requested scopes; carried verbatim into the stored auth request.
	scope: String,
	/// Opaque client value; carried verbatim into the stored auth request.
	state: Option<String>,
	/// Carried verbatim into the stored auth request.
	nonce: Option<String>,
	/// PKCE challenge; mandatory when the `oidc_require_pkce` option is set.
	code_challenge: Option<String>,
	/// PKCE transform; must be `S256` whenever a challenge is supplied.
	code_challenge_method: Option<String>,
	/// Explicit identity provider to use instead of the default; unknown ids are rejected.
	#[serde(default)]
	idp_id: Option<String>,
	/// OIDC `prompt`; a `create` request favours the native registration view.
	#[serde(default)]
	prompt: Option<String>,
}
37
38pub(crate) async fn authorize_route(
39 State(services): State<crate::State>,
40 ClientIp(client): ClientIp,
41 request: axum::extract::Request,
42) -> Result<impl IntoResponse> {
43 let oidc = services.oauth.get_server()?;
44 services.oauth.check_rate_limit(client)?;
45
46 let query = request.uri().query().unwrap_or_default();
47 let params: AuthorizeParams = serde_html_form::from_str(query)?;
48
49 if params.response_type != "code" {
50 return Err!(Request(InvalidParam("Only response_type=code is supported")));
51 }
52
53 let response_mode = params.response_mode.as_deref().unwrap_or("query");
54 if !matches!(response_mode, "query" | "fragment") {
55 return Err!(Request(InvalidParam(
56 "Only response_mode=query or response_mode=fragment is supported"
57 )));
58 }
59
60 match (¶ms.code_challenge, params.code_challenge_method.as_deref()) {
63 | (None, _) if services.config.oidc_require_pkce =>
64 return Err!(Request(InvalidParam("code_challenge is required (PKCE with S256)"))),
65
66 | (Some(_), method) if method != Some("S256") =>
67 return Err!(Request(InvalidParam("Only code_challenge_method=S256 is supported"))),
68
69 | _ => {},
70 }
71
72 validate_redirect_uri(&services, ¶ms).await?;
73
74 let now = SystemTime::now();
75 let req_id = utils::random_string(OIDC_REQ_ID_LENGTH);
76 let base = oidc.issuer_url()?;
77 let base = base.trim_end_matches('/');
78
79 let resolved_idp: Option<String> = match params.idp_id.as_deref() {
80 | Some(requested) => services
81 .oauth
82 .providers
83 .get_config(requested)
84 .map(|provider| Some(provider.id().to_owned()))
85 .map_err(|_| err!(Request(InvalidParam("Unrecognized identity provider"))))?,
86
87 | None => services.oauth.providers.get_default_id(),
88 };
89
90 let serve_native = params.idp_id.is_none()
93 && should_serve_native(
94 services.config.oidc_native_auth,
95 resolved_idp.is_some(),
96 params.prompt.as_deref() == Some("create"),
97 );
98
99 let idp_id = match (serve_native, resolved_idp) {
100 | (true, _) => None,
101 | (false, Some(idp_id)) => Some(idp_id),
102 | (false, None) =>
103 return Err!(Config("identity_provider", "No identity provider configured")),
104 };
105
106 let auth_req = AuthRequest {
107 client_id: params.client_id,
108 redirect_uri: params.redirect_uri,
109 scope: params.scope,
110 state: params.state,
111 nonce: params.nonce,
112 code_challenge: params.code_challenge,
113 code_challenge_method: params.code_challenge_method,
114 idp_id: idp_id.clone(),
117 response_mode: params.response_mode,
118 created_at: now,
119 expires_at: now
120 .checked_add(AUTH_REQUEST_LIFETIME)
121 .unwrap_or(now),
122 };
123
124 oidc.store_auth_request(&req_id, &auth_req);
125
126 let Some(idp_id) = idp_id else {
127 let view = match params.prompt.as_deref() {
128 | Some("create") => "register",
129 | _ => "login",
130 };
131
132 let native_url = Url::parse(&format!("{base}/_tuwunel/oidc/native"))
133 .map_err(|_| err!(error!("Failed to build native auth URL")))
134 .map(|mut url| {
135 url.query_pairs_mut()
136 .append_pair("oidc_req_id", &req_id)
137 .append_pair("view", view);
138
139 url
140 })?;
141
142 return Ok(Redirect::temporary(native_url.as_str()));
143 };
144
145 let complete_url = Url::parse(&format!("{base}/_tuwunel/oidc/_complete"))
146 .map_err(|_| err!(error!("Failed to build complete URL")))
147 .map(|mut url| {
148 url.query_pairs_mut()
149 .append_pair("oidc_req_id", &req_id);
150
151 url
152 })?;
153
154 let idp_id_enc = url_encode(&idp_id);
155 let sso_url =
156 Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id_enc}"))
157 .map_err(|_| err!(error!("Failed to build SSO URL")))
158 .map(|mut url| {
159 url.query_pairs_mut()
160 .append_pair("redirectUrl", complete_url.as_str());
161
162 url
163 })?;
164
165 Ok(Redirect::temporary(sso_url.as_str()))
166}
167
/// Decide whether an authorization request is handled by the built-in (native)
/// login/registration views rather than an upstream identity provider.
///
/// Nothing is served natively unless `native_enabled` is set. When it is set,
/// native auth is used if the client asked to create an account, or if there
/// is no default identity provider to send the user to.
fn should_serve_native(native_enabled: bool, has_default_idp: bool, wants_create: bool) -> bool {
	if !native_enabled {
		return false;
	}

	wants_create || !has_default_idp
}
175
176async fn validate_redirect_uri(services: &Services, params: &AuthorizeParams) -> Result {
177 services
178 .oauth
179 .get_server()
180 .expect("OIDC already configured")
181 .get_client(¶ms.client_id)
182 .await?
183 .redirect_uris
184 .iter()
185 .any(|uri| redirect_uri_matches(uri, ¶ms.redirect_uri))
186 .into_option()
187 .ok_or_else(|| err!(Request(InvalidParam("redirect_uri not registered for this client"))))
188}
189
190fn redirect_uri_matches(registered: &str, requested: &str) -> bool {
191 match (Url::parse(registered), Url::parse(requested)) {
192 | (..) if registered == requested => true,
193 | (Ok(reg), Ok(req)) if is_loopback_redirect(®) && is_loopback_redirect(&req) =>
194 reg.scheme() == req.scheme()
195 && reg.host_str() == req.host_str()
196 && reg.path() == req.path()
197 && reg.query() == req.query()
198 && reg.fragment() == req.fragment(),
199
200 | _ => false,
201 }
202}
203
204fn is_loopback_redirect(uri: &Url) -> bool {
205 let addr = || uri.host_str().map(str::parse::<IpAddr>).flat_ok();
206
207 uri.scheme() == "http" && matches!(addr(), Some(ip) if ip.is_loopback())
208}
209
#[cfg(test)]
mod tests {
	use super::should_serve_native;

	/// Exhaustively checks all eight input combinations of `should_serve_native`.
	#[test]
	fn native_decision_truth_table() {
		// ((native_enabled, has_default_idp, wants_create), expected)
		let cases = [
			// Native auth disabled: never served natively.
			((false, false, false), false),
			((false, false, true), false),
			((false, true, false), false),
			((false, true, true), false),
			// Native auth enabled, no default provider: always native.
			((true, false, false), true),
			((true, false, true), true),
			// Native auth enabled with a default provider: native only for registration.
			((true, true, false), false),
			((true, true, true), true),
		];

		for ((native_enabled, has_default_idp, wants_create), expected) in cases {
			assert_eq!(
				should_serve_native(native_enabled, has_default_idp, wants_create),
				expected,
				"native_enabled={native_enabled} has_default_idp={has_default_idp} \
				 wants_create={wants_create}",
			);
		}
	}
}