Skip to main content

tuwunel_api/oidc/
complete.rs

1use std::iter::once;
2
3use axum::{
4	extract::State,
5	response::{IntoResponse, Redirect},
6};
7use serde::Deserialize;
8use tuwunel_core::{Result, err};
9use url::{Url, form_urlencoded};
10
11#[derive(Debug, Deserialize)]
12pub(crate) struct CompleteParams {
13	oidc_req_id: String,
14	#[serde(rename = "loginToken")]
15	login_token: String,
16}
17
18pub(crate) async fn complete_route(
19	State(services): State<crate::State>,
20	request: axum::extract::Request,
21) -> Result<impl IntoResponse> {
22	let query = request.uri().query().unwrap_or_default();
23	let params: CompleteParams = serde_html_form::from_str(query)?;
24
25	let oidc = services.oauth.get_server()?;
26
27	// Validate the auth request first (before consuming the login_token) so that
28	// a crafted request with an invalid oidc_req_id cannot burn a valid token.
29	let auth_req = oidc
30		.take_auth_request(&params.oidc_req_id)
31		.await?;
32
33	let user_id = services
34		.users
35		.find_from_login_token(&params.login_token)
36		.await
37		.map_err(|_| err!(Request(Forbidden("Invalid or expired login token"))))?;
38
39	let code = oidc.create_auth_code(&auth_req, user_id);
40	let redirect_url = Url::parse(&auth_req.redirect_uri)
41		.map_err(|_| err!(Request(InvalidParam("Invalid redirect_uri"))))
42		.map(|mut url| {
43			let pairs = once(("code", code.as_str()))
44				.chain(auth_req.state.as_deref().map(|s| ("state", s)));
45
46			match auth_req.response_mode.as_deref() {
47				| Some("fragment") => {
48					let body = form_urlencoded::Serializer::new(String::new())
49						.extend_pairs(pairs)
50						.finish();
51
52					url.set_fragment(Some(&body));
53				},
54				| _ => {
55					url.query_pairs_mut().extend_pairs(pairs);
56				},
57			}
58
59			url
60		})?;
61
62	Ok(Redirect::temporary(redirect_url.as_str()))
63}