tuwunel_api/oidc/
authorize.rs1use std::{net::IpAddr, time::SystemTime};
2
3use axum::{
4 extract::State,
5 response::{IntoResponse, Redirect},
6};
7use serde::Deserialize;
8use tuwunel_core::{
9 Err, Result, err, utils,
10 utils::{BoolExt, result::FlatOk},
11};
12use tuwunel_service::{
13 Services,
14 oauth::server::{AUTH_REQUEST_LIFETIME, AuthRequest},
15};
16use url::Url;
17
18use super::{OIDC_REQ_ID_LENGTH, url_encode};
19
20#[derive(Debug, Deserialize)]
21pub(crate) struct AuthorizeParams {
22 client_id: String,
23 redirect_uri: String,
24 response_type: String,
25 response_mode: Option<String>,
26 scope: String,
27 state: Option<String>,
28 nonce: Option<String>,
29 code_challenge: Option<String>,
30 code_challenge_method: Option<String>,
31 #[serde(default, rename = "prompt")]
32 _prompt: Option<String>,
33}
34
35pub(crate) async fn authorize_route(
36 State(services): State<crate::State>,
37 request: axum::extract::Request,
38) -> Result<impl IntoResponse> {
39 let oidc = services.oauth.get_server()?;
40
41 let query = request.uri().query().unwrap_or_default();
42 let params: AuthorizeParams = serde_html_form::from_str(query)?;
43
44 if params.response_type != "code" {
45 return Err!(Request(InvalidParam("Only response_type=code is supported")));
46 }
47
48 let response_mode = params.response_mode.as_deref().unwrap_or("query");
49 if !matches!(response_mode, "query" | "fragment") {
50 return Err!(Request(InvalidParam(
51 "Only response_mode=query or response_mode=fragment is supported"
52 )));
53 }
54
55 validate_redirect_uri(&services, ¶ms).await?;
56
57 let now = SystemTime::now();
58 let req_id = utils::random_string(OIDC_REQ_ID_LENGTH);
59 let idp_id = services
60 .oauth
61 .providers
62 .get_default_id()
63 .ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?;
64
65 let auth_req = AuthRequest {
66 client_id: params.client_id,
67 redirect_uri: params.redirect_uri,
68 scope: params.scope,
69 state: params.state,
70 nonce: params.nonce,
71 code_challenge: params.code_challenge,
72 code_challenge_method: params.code_challenge_method,
73 idp_id: Some(idp_id.clone()),
76 response_mode: params.response_mode,
77 created_at: now,
78 expires_at: now
79 .checked_add(AUTH_REQUEST_LIFETIME)
80 .unwrap_or(now),
81 };
82
83 let base = oidc.issuer_url()?;
84 let base = base.trim_end_matches('/');
85
86 let complete_url = Url::parse(&format!("{base}/_tuwunel/oidc/_complete"))
87 .map_err(|_| err!(error!("Failed to build complete URL")))
88 .map(|mut url| {
89 url.query_pairs_mut()
90 .append_pair("oidc_req_id", &req_id);
91 url
92 })?;
93
94 let idp_id_enc = url_encode(&idp_id);
95 let sso_url =
96 Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id_enc}"))
97 .map_err(|_| err!(error!("Failed to build SSO URL")))
98 .map(|mut url| {
99 url.query_pairs_mut()
100 .append_pair("redirectUrl", complete_url.as_str());
101 url
102 })?;
103
104 oidc.store_auth_request(&req_id, &auth_req);
105
106 Ok(Redirect::temporary(sso_url.as_str()))
107}
108
109async fn validate_redirect_uri(services: &Services, params: &AuthorizeParams) -> Result {
110 services
111 .oauth
112 .get_server()
113 .expect("OIDC already configured")
114 .get_client(¶ms.client_id)
115 .await?
116 .redirect_uris
117 .iter()
118 .any(|uri| redirect_uri_matches(uri, ¶ms.redirect_uri))
119 .ok_or_else(|| err!(Request(InvalidParam("redirect_uri not registered for this client"))))
120}
121
122fn redirect_uri_matches(registered: &str, requested: &str) -> bool {
123 match (Url::parse(registered), Url::parse(requested)) {
124 | (..) if registered == requested => true,
125 | (Ok(reg), Ok(req)) if is_loopback_redirect(®) && is_loopback_redirect(&req) =>
126 reg.scheme() == req.scheme()
127 && reg.host_str() == req.host_str()
128 && reg.path() == req.path()
129 && reg.query() == req.query()
130 && reg.fragment() == req.fragment(),
131
132 | _ => false,
133 }
134}
135
136fn is_loopback_redirect(uri: &Url) -> bool {
137 let addr = || uri.host_str().map(str::parse::<IpAddr>).flat_ok();
138
139 uri.scheme() == "http" && matches!(addr(), Some(ip) if ip.is_loopback())
140}