// tuwunel_api/oidc/token.rs — OAuth 2.0 token endpoint for the OIDC server.

1use axum::{
2	Json,
3	body::Body,
4	extract::{Form, State},
5	response::IntoResponse,
6};
7use http::{
8	Response, StatusCode,
9	header::{CACHE_CONTROL, PRAGMA},
10};
11use ruma::OwnedDeviceId;
12use serde::Deserialize;
13use serde_json::json;
14use tuwunel_core::{
15	Err, Error, Result, err, info,
16	utils::{
17		BoolExt,
18		future::OptionFutureExt,
19		time::{now, timepoint_has_passed},
20	},
21};
22use tuwunel_service::{
23	Services,
24	oauth::server::{IdTokenClaims, Server, extract_device_id},
25	users::device::generate_refresh_token,
26};
27
28use super::oauth_error;
29
30#[derive(Debug, Deserialize)]
31pub(crate) struct TokenRequest {
32	grant_type: String,
33	code: Option<String>,
34	redirect_uri: Option<String>,
35	client_id: Option<String>,
36	code_verifier: Option<String>,
37	refresh_token: Option<String>,
38	#[serde(rename = "scope")]
39	_scope: Option<String>,
40}
41
42pub(crate) async fn token_route(
43	State(services): State<crate::State>,
44	Form(body): Form<TokenRequest>,
45) -> impl IntoResponse {
46	// RFC 6749 §5.1 and §5.2 require Cache-Control: no-store and Pragma: no-cache
47	// on all token endpoint responses (success and error).
48	let inner = match body.grant_type.as_str() {
49		| "authorization_code" => token_authorization_code(&services, &body)
50			.await
51			.unwrap_or_else(token_error_response),
52
53		| "refresh_token" => token_refresh(&services, &body)
54			.await
55			.unwrap_or_else(token_error_response),
56
57		| _ => oauth_error(
58			StatusCode::BAD_REQUEST,
59			"unsupported_grant_type",
60			"Unsupported grant_type",
61		),
62	};
63	let mut response = inner.into_response();
64	let headers = response.headers_mut();
65	headers.insert(CACHE_CONTROL, http::HeaderValue::from_static("no-store"));
66	headers.insert(PRAGMA, http::HeaderValue::from_static("no-cache"));
67	response
68}
69
70async fn token_authorization_code(
71	services: &Services,
72	body: &TokenRequest,
73) -> Result<Response<Body>> {
74	let code = body
75		.code
76		.as_deref()
77		.ok_or_else(|| err!(Request(InvalidParam("code is required"))))?;
78
79	let redirect_uri = body
80		.redirect_uri
81		.as_deref()
82		.ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?;
83
84	let client_id = body
85		.client_id
86		.as_deref()
87		.ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?;
88
89	let session = services
90		.oauth
91		.get_server()?
92		.exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref())
93		.await?;
94
95	let user_id = &session.user_id;
96	let (access_token, expires_in) = services.users.generate_access_token(true);
97	let refresh_token = generate_refresh_token();
98	let client_name = services
99		.oauth
100		.get_server()?
101		.get_client(client_id)
102		.await
103		.ok()
104		.and_then(|c| c.client_name);
105
106	let device_display_name = client_name.as_deref().unwrap_or("OIDC Client");
107	let device_id: Option<OwnedDeviceId> =
108		extract_device_id(&session.scope).map(OwnedDeviceId::from);
109
110	let iss = services.oauth.get_server()?.issuer_url()?;
111	let id_token = session
112		.scope
113		.contains("openid")
114		.then(|| {
115			let now = now().as_secs();
116			let claims = IdTokenClaims {
117				iss,
118				sub: user_id.to_string(),
119				aud: client_id.to_owned(),
120				exp: now.saturating_add(3600),
121				iat: now,
122				nonce: session.nonce,
123				at_hash: Some(Server::at_hash(&access_token)),
124			};
125
126			services
127				.oauth
128				.get_server()?
129				.sign_id_token(&claims)
130		})
131		.transpose()?;
132
133	let device_id = services
134		.users
135		.create_device(
136			user_id,
137			device_id.as_deref(),
138			(Some(&access_token), expires_in),
139			Some(&refresh_token),
140			Some(device_display_name),
141			None,
142		)
143		.await?;
144
145	let idp_id = session.idp_id.as_deref().unwrap_or("");
146	services
147		.users
148		.mark_oidc_device(user_id, &device_id, idp_id);
149
150	info!("{user_id} logged in via OIDC on {device_id} ({device_display_name})");
151
152	let mut response = json!({
153		"access_token": access_token,
154		"refresh_token": refresh_token,
155		"scope": session.scope,
156		"token_type": "Bearer",
157	});
158
159	if let Some(id_token) = id_token {
160		response["id_token"] = json!(id_token);
161	}
162
163	if let Some(expires_in) = expires_in {
164		response["expires_in"] = json!(expires_in.as_secs());
165	}
166
167	Ok(Json(response).into_response())
168}
169
170async fn token_refresh(services: &Services, body: &TokenRequest) -> Result<Response<Body>> {
171	let refresh_token = body
172		.refresh_token
173		.as_deref()
174		.ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?;
175
176	let (user_id, device_id, expires_at) = services
177		.users
178		.find_from_token(refresh_token)
179		.await
180		.map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?;
181
182	if expires_at.is_some_and(timepoint_has_passed) {
183		services
184			.server
185			.config
186			.refresh_token_hard_logout
187			.then_async(|| services.users.remove_device(&user_id, &device_id))
188			.unwrap_or_else_async(async || {
189				services
190					.users
191					.remove_refresh_token(&user_id, &device_id)
192					.await
193					.ok();
194			})
195			.await;
196
197		return Err!(Request(Forbidden("Refresh token has expired")));
198	}
199
200	let (new_access_token, expires_in) = services.users.generate_access_token(true);
201	let new_refresh_token = generate_refresh_token();
202
203	services
204		.users
205		.set_access_token(
206			&user_id,
207			&device_id,
208			&new_access_token,
209			expires_in,
210			Some(&new_refresh_token),
211		)
212		.await?;
213
214	let mut response = json!({
215		"access_token": new_access_token,
216		"refresh_token": new_refresh_token,
217		"token_type": "Bearer",
218	});
219
220	if let Some(expires_in) = expires_in {
221		response["expires_in"] = json!(expires_in.as_secs());
222	}
223
224	Ok(Json(response).into_response())
225}
226
227/// RFC 6749 §5.2: map error to correct HTTP status and OAuth2 error code.
228/// Client-side errors (invalid grant, bad params) → 400 invalid_grant.
229/// Server-side errors → 500 server_error with sanitized message.
230#[expect(clippy::needless_pass_by_value)]
231fn token_error_response(e: Error) -> Response<Body> {
232	if !e.status_code().is_client_error() {
233		return oauth_error(
234			StatusCode::INTERNAL_SERVER_ERROR,
235			"server_error",
236			"An internal error occurred",
237		);
238	}
239
240	oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &e.sanitized_message())
241}