Skip to main content

tuwunel_api/oidc/
authorize.rs

1use std::{net::IpAddr, time::SystemTime};
2
3use axum::{
4	extract::State,
5	response::{IntoResponse, Redirect},
6};
7use serde::Deserialize;
8use tuwunel_core::{
9	Err, Result, err, utils,
10	utils::{BoolExt, result::FlatOk},
11};
12use tuwunel_service::{
13	Services,
14	oauth::server::{AUTH_REQUEST_LIFETIME, AuthRequest},
15};
16use url::Url;
17
18use super::{OIDC_REQ_ID_LENGTH, url_encode};
19
20#[derive(Debug, Deserialize)]
21pub(crate) struct AuthorizeParams {
22	client_id: String,
23	redirect_uri: String,
24	response_type: String,
25	response_mode: Option<String>,
26	scope: String,
27	state: Option<String>,
28	nonce: Option<String>,
29	code_challenge: Option<String>,
30	code_challenge_method: Option<String>,
31	#[serde(default, rename = "prompt")]
32	_prompt: Option<String>,
33}
34
35pub(crate) async fn authorize_route(
36	State(services): State<crate::State>,
37	request: axum::extract::Request,
38) -> Result<impl IntoResponse> {
39	let oidc = services.oauth.get_server()?;
40
41	let query = request.uri().query().unwrap_or_default();
42	let params: AuthorizeParams = serde_html_form::from_str(query)?;
43
44	if params.response_type != "code" {
45		return Err!(Request(InvalidParam("Only response_type=code is supported")));
46	}
47
48	let response_mode = params.response_mode.as_deref().unwrap_or("query");
49	if !matches!(response_mode, "query" | "fragment") {
50		return Err!(Request(InvalidParam(
51			"Only response_mode=query or response_mode=fragment is supported"
52		)));
53	}
54
55	validate_redirect_uri(&services, &params).await?;
56
57	let now = SystemTime::now();
58	let req_id = utils::random_string(OIDC_REQ_ID_LENGTH);
59	let idp_id = services
60		.oauth
61		.providers
62		.get_default_id()
63		.ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?;
64
65	let auth_req = AuthRequest {
66		client_id: params.client_id,
67		redirect_uri: params.redirect_uri,
68		scope: params.scope,
69		state: params.state,
70		nonce: params.nonce,
71		code_challenge: params.code_challenge,
72		code_challenge_method: params.code_challenge_method,
73		// Record which IdP authenticated the user so it can be tagged on the
74		// device at token exchange time and used for UIAA SSO provider binding.
75		idp_id: Some(idp_id.clone()),
76		response_mode: params.response_mode,
77		created_at: now,
78		expires_at: now
79			.checked_add(AUTH_REQUEST_LIFETIME)
80			.unwrap_or(now),
81	};
82
83	let base = oidc.issuer_url()?;
84	let base = base.trim_end_matches('/');
85
86	let complete_url = Url::parse(&format!("{base}/_tuwunel/oidc/_complete"))
87		.map_err(|_| err!(error!("Failed to build complete URL")))
88		.map(|mut url| {
89			url.query_pairs_mut()
90				.append_pair("oidc_req_id", &req_id);
91			url
92		})?;
93
94	let idp_id_enc = url_encode(&idp_id);
95	let sso_url =
96		Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id_enc}"))
97			.map_err(|_| err!(error!("Failed to build SSO URL")))
98			.map(|mut url| {
99				url.query_pairs_mut()
100					.append_pair("redirectUrl", complete_url.as_str());
101				url
102			})?;
103
104	oidc.store_auth_request(&req_id, &auth_req);
105
106	Ok(Redirect::temporary(sso_url.as_str()))
107}
108
109async fn validate_redirect_uri(services: &Services, params: &AuthorizeParams) -> Result {
110	services
111		.oauth
112		.get_server()
113		.expect("OIDC already configured")
114		.get_client(&params.client_id)
115		.await?
116		.redirect_uris
117		.iter()
118		.any(|uri| redirect_uri_matches(uri, &params.redirect_uri))
119		.ok_or_else(|| err!(Request(InvalidParam("redirect_uri not registered for this client"))))
120}
121
122fn redirect_uri_matches(registered: &str, requested: &str) -> bool {
123	match (Url::parse(registered), Url::parse(requested)) {
124		| (..) if registered == requested => true,
125		| (Ok(reg), Ok(req)) if is_loopback_redirect(&reg) && is_loopback_redirect(&req) =>
126			reg.scheme() == req.scheme()
127				&& reg.host_str() == req.host_str()
128				&& reg.path() == req.path()
129				&& reg.query() == req.query()
130				&& reg.fragment() == req.fragment(),
131
132		| _ => false,
133	}
134}
135
136fn is_loopback_redirect(uri: &Url) -> bool {
137	let addr = || uri.host_str().map(str::parse::<IpAddr>).flat_ok();
138
139	uri.scheme() == "http" && matches!(addr(), Some(ip) if ip.is_loopback())
140}