1use axum::{
2 Json,
3 body::Body,
4 extract::{Form, State},
5 response::IntoResponse,
6};
7use http::{
8 Response, StatusCode,
9 header::{CACHE_CONTROL, PRAGMA},
10};
11use ruma::OwnedDeviceId;
12use serde::Deserialize;
13use serde_json::json;
14use tuwunel_core::{
15 Err, Error, Result, err, info,
16 utils::{
17 BoolExt,
18 future::OptionFutureExt,
19 time::{now, timepoint_has_passed},
20 },
21};
22use tuwunel_service::{
23 Services,
24 oauth::server::{IdTokenClaims, Server, extract_device_id},
25 users::device::generate_refresh_token,
26};
27
28use super::oauth_error;
29
/// Body of a token-endpoint request, extracted from the form-encoded request
/// body (see the `Form<TokenRequest>` extractor in `token_route`).
///
/// Apart from `grant_type`, every field is optional at the parsing level; the
/// per-grant handlers check that the fields they need are present.
#[derive(Debug, Deserialize)]
pub(crate) struct TokenRequest {
	/// Grant being requested; `token_route` accepts `authorization_code` and
	/// `refresh_token`.
	grant_type: String,
	/// Authorization code to redeem (required for `authorization_code`).
	code: Option<String>,
	/// Redirect URI the code was issued for (required for `authorization_code`).
	redirect_uri: Option<String>,
	/// Client identifier (required for `authorization_code`).
	client_id: Option<String>,
	/// PKCE code verifier, passed through to the code exchange when present.
	code_verifier: Option<String>,
	/// Refresh token to rotate (required for `refresh_token`).
	refresh_token: Option<String>,
	/// Requested scope; parsed from the `scope` form field but not read by the
	/// handlers in this module.
	#[serde(rename = "scope")]
	_scope: Option<String>,
}
41
42pub(crate) async fn token_route(
43 State(services): State<crate::State>,
44 Form(body): Form<TokenRequest>,
45) -> impl IntoResponse {
46 let inner = match body.grant_type.as_str() {
49 | "authorization_code" => token_authorization_code(&services, &body)
50 .await
51 .unwrap_or_else(token_error_response),
52
53 | "refresh_token" => token_refresh(&services, &body)
54 .await
55 .unwrap_or_else(token_error_response),
56
57 | _ => oauth_error(
58 StatusCode::BAD_REQUEST,
59 "unsupported_grant_type",
60 "Unsupported grant_type",
61 ),
62 };
63 let mut response = inner.into_response();
64 let headers = response.headers_mut();
65 headers.insert(CACHE_CONTROL, http::HeaderValue::from_static("no-store"));
66 headers.insert(PRAGMA, http::HeaderValue::from_static("no-cache"));
67 response
68}
69
70async fn token_authorization_code(
71 services: &Services,
72 body: &TokenRequest,
73) -> Result<Response<Body>> {
74 let code = body
75 .code
76 .as_deref()
77 .ok_or_else(|| err!(Request(InvalidParam("code is required"))))?;
78
79 let redirect_uri = body
80 .redirect_uri
81 .as_deref()
82 .ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?;
83
84 let client_id = body
85 .client_id
86 .as_deref()
87 .ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?;
88
89 let session = services
90 .oauth
91 .get_server()?
92 .exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref())
93 .await?;
94
95 let user_id = &session.user_id;
96 let (access_token, expires_in) = services.users.generate_access_token(true);
97 let refresh_token = generate_refresh_token();
98 let client_name = services
99 .oauth
100 .get_server()?
101 .get_client(client_id)
102 .await
103 .ok()
104 .and_then(|c| c.client_name);
105
106 let device_display_name = client_name.as_deref().unwrap_or("OIDC Client");
107 let device_id: Option<OwnedDeviceId> =
108 extract_device_id(&session.scope).map(OwnedDeviceId::from);
109
110 let iss = services.oauth.get_server()?.issuer_url()?;
111 let id_token = session
112 .scope
113 .contains("openid")
114 .then(|| {
115 let now = now().as_secs();
116 let claims = IdTokenClaims {
117 iss,
118 sub: user_id.to_string(),
119 aud: client_id.to_owned(),
120 exp: now.saturating_add(3600),
121 iat: now,
122 nonce: session.nonce,
123 at_hash: Some(Server::at_hash(&access_token)),
124 };
125
126 services
127 .oauth
128 .get_server()?
129 .sign_id_token(&claims)
130 })
131 .transpose()?;
132
133 let device_id = services
134 .users
135 .create_device(
136 user_id,
137 device_id.as_deref(),
138 (Some(&access_token), expires_in),
139 Some(&refresh_token),
140 Some(device_display_name),
141 None,
142 )
143 .await?;
144
145 let idp_id = session.idp_id.as_deref().unwrap_or("");
146 services
147 .users
148 .mark_oidc_device(user_id, &device_id, idp_id);
149
150 info!("{user_id} logged in via OIDC on {device_id} ({device_display_name})");
151
152 let mut response = json!({
153 "access_token": access_token,
154 "refresh_token": refresh_token,
155 "scope": session.scope,
156 "token_type": "Bearer",
157 });
158
159 if let Some(id_token) = id_token {
160 response["id_token"] = json!(id_token);
161 }
162
163 if let Some(expires_in) = expires_in {
164 response["expires_in"] = json!(expires_in.as_secs());
165 }
166
167 Ok(Json(response).into_response())
168}
169
170async fn token_refresh(services: &Services, body: &TokenRequest) -> Result<Response<Body>> {
171 let refresh_token = body
172 .refresh_token
173 .as_deref()
174 .ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?;
175
176 let (user_id, device_id, expires_at) = services
177 .users
178 .find_from_token(refresh_token)
179 .await
180 .map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?;
181
182 if expires_at.is_some_and(timepoint_has_passed) {
183 services
184 .server
185 .config
186 .refresh_token_hard_logout
187 .then_async(|| services.users.remove_device(&user_id, &device_id))
188 .unwrap_or_else_async(async || {
189 services
190 .users
191 .remove_refresh_token(&user_id, &device_id)
192 .await
193 .ok();
194 })
195 .await;
196
197 return Err!(Request(Forbidden("Refresh token has expired")));
198 }
199
200 let (new_access_token, expires_in) = services.users.generate_access_token(true);
201 let new_refresh_token = generate_refresh_token();
202
203 services
204 .users
205 .set_access_token(
206 &user_id,
207 &device_id,
208 &new_access_token,
209 expires_in,
210 Some(&new_refresh_token),
211 )
212 .await?;
213
214 let mut response = json!({
215 "access_token": new_access_token,
216 "refresh_token": new_refresh_token,
217 "token_type": "Bearer",
218 });
219
220 if let Some(expires_in) = expires_in {
221 response["expires_in"] = json!(expires_in.as_secs());
222 }
223
224 Ok(Json(response).into_response())
225}
226
227#[expect(clippy::needless_pass_by_value)]
231fn token_error_response(e: Error) -> Response<Body> {
232 if !e.status_code().is_client_error() {
233 return oauth_error(
234 StatusCode::INTERNAL_SERVER_ERROR,
235 "server_error",
236 "An internal error occurred",
237 );
238 }
239
240 oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &e.sanitized_message())
241}