tuwunel_service/oauth/server/
auth.rs1use std::time::{Duration, SystemTime};
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
4use ruma::OwnedUserId;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{Err, Result, err, implement, utils, utils::hash::sha256};
7use tuwunel_database::{Cbor, Deserialized};
8
9#[derive(Clone, Debug, Deserialize, Serialize)]
10pub struct AuthRequest {
11 pub client_id: String,
12
13 pub redirect_uri: String,
14
15 pub scope: String,
16
17 pub state: Option<String>,
18
19 pub nonce: Option<String>,
20
21 pub code_challenge: Option<String>,
22
23 pub code_challenge_method: Option<String>,
24
25 pub idp_id: Option<String>,
29
30 pub response_mode: Option<String>,
31
32 pub created_at: SystemTime,
33
34 pub expires_at: SystemTime,
35}
36
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct AuthCodeSession {
39 pub code: String,
40
41 pub client_id: String,
42
43 pub redirect_uri: String,
44
45 pub scope: String,
46
47 pub state: Option<String>,
48
49 pub nonce: Option<String>,
50
51 pub code_challenge: Option<String>,
52
53 pub code_challenge_method: Option<String>,
54
55 pub user_id: OwnedUserId,
56
57 pub idp_id: Option<String>,
60
61 pub created_at: SystemTime,
62
63 pub expires_at: SystemTime,
64}
65
/// Lifetime of a pending authorization request (ten minutes). Not referenced
/// in this module; presumably callers use it to set `AuthRequest::expires_at`
/// (verify).
pub const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// How long after minting an authorization code may still be exchanged (ten
/// minutes). Added to the creation time in `create_auth_code`.
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Length passed to `utils::random_string` when generating a code.
const AUTH_CODE_LENGTH: usize = 64;

70#[implement(super::Server)]
71#[must_use]
72pub fn create_auth_code(&self, auth_req: &AuthRequest, user_id: OwnedUserId) -> String {
73 let now = SystemTime::now();
74 let code = utils::random_string(AUTH_CODE_LENGTH);
75 let session = AuthCodeSession {
76 code: code.clone(),
77 client_id: auth_req.client_id.clone(),
78 redirect_uri: auth_req.redirect_uri.clone(),
79 scope: auth_req.scope.clone(),
80 state: auth_req.state.clone(),
81 nonce: auth_req.nonce.clone(),
82 code_challenge: auth_req.code_challenge.clone(),
83 code_challenge_method: auth_req.code_challenge_method.clone(),
84 user_id,
85 idp_id: auth_req.idp_id.clone(),
86 created_at: now,
87 expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
88 };
89
90 self.db
91 .oidccode_authsession
92 .raw_put(&*code, Cbor(&session));
93
94 code
95}
96
97#[implement(super::Server)]
98pub fn store_auth_request(&self, req_id: &str, request: &AuthRequest) {
99 self.db
100 .oidcreqid_authrequest
101 .raw_put(req_id, Cbor(request));
102}
103
104#[implement(super::Server)]
105pub async fn take_auth_request(&self, req_id: &str) -> Result<AuthRequest> {
106 let request: AuthRequest = self
107 .db
108 .oidcreqid_authrequest
109 .get(req_id)
110 .await
111 .deserialized::<Cbor<_>>()
112 .map(|cbor: Cbor<AuthRequest>| cbor.0)
113 .map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
114
115 self.db.oidcreqid_authrequest.remove(req_id);
116
117 if SystemTime::now() > request.expires_at {
118 return Err!(Request(NotFound("Authorization request has expired")));
119 }
120
121 Ok(request)
122}
123
124#[implement(super::Server)]
125pub async fn exchange_auth_code(
126 &self,
127 code: &str,
128 client_id: &str,
129 redirect_uri: &str,
130 code_verifier: Option<&str>,
131 require_pkce: bool,
132) -> Result<AuthCodeSession> {
133 let session: AuthCodeSession = self
134 .db
135 .oidccode_authsession
136 .get(code)
137 .await
138 .deserialized::<Cbor<_>>()
139 .map(|cbor: Cbor<AuthCodeSession>| cbor.0)
140 .map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
141
142 self.db.oidccode_authsession.remove(code);
143
144 if SystemTime::now() > session.expires_at {
145 return Err!(Request(Forbidden("Authorization code has expired")));
146 }
147 if session.client_id != client_id {
148 return Err!(Request(Forbidden("client_id mismatch")));
149 }
150 if session.redirect_uri != redirect_uri {
151 return Err!(Request(Forbidden("redirect_uri mismatch")));
152 }
153
154 let Some(challenge) = &session.code_challenge else {
155 if require_pkce {
158 return Err!(Request(Forbidden(
159 "the authorization request carried no PKCE code_challenge"
160 )));
161 }
162
163 return Ok(session);
164 };
165
166 let Some(verifier) = code_verifier else {
167 return Err!(Request(Forbidden("code_verifier required for PKCE")));
168 };
169
170 validate_code_verifier(verifier)?;
171
172 let method = session
173 .code_challenge_method
174 .as_deref()
175 .unwrap_or("S256");
176
177 let computed = match method {
180 | "S256" => b64.encode(sha256::hash(verifier.as_bytes())),
181 | _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
182 };
183
184 if computed != *challenge {
185 return Err!(Request(Forbidden("PKCE verification failed")));
186 }
187
188 Ok(session)
189}
190
191fn validate_code_verifier(verifier: &str) -> Result {
195 if !(43..=128).contains(&verifier.len()) {
196 return Err!(Request(InvalidParam("code_verifier must be 43-128 characters")));
197 }
198
199 if !verifier
200 .bytes()
201 .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~')
202 {
203 return Err!(Request(InvalidParam("code_verifier contains invalid characters")));
204 }
205
206 Ok(())
207}