tuwunel_service/oauth/server/
auth.rs1use std::time::{Duration, SystemTime};
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
4use ruma::OwnedUserId;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{Err, Result, err, implement, utils, utils::hash::sha256};
7use tuwunel_database::{Cbor, Deserialized};
8
9#[derive(Clone, Debug, Deserialize, Serialize)]
10pub struct AuthRequest {
11 pub client_id: String,
12
13 pub redirect_uri: String,
14
15 pub scope: String,
16
17 pub state: Option<String>,
18
19 pub nonce: Option<String>,
20
21 pub code_challenge: Option<String>,
22
23 pub code_challenge_method: Option<String>,
24
25 pub idp_id: Option<String>,
29
30 pub response_mode: Option<String>,
31
32 pub created_at: SystemTime,
33
34 pub expires_at: SystemTime,
35}
36
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct AuthCodeSession {
39 pub code: String,
40
41 pub client_id: String,
42
43 pub redirect_uri: String,
44
45 pub scope: String,
46
47 pub state: Option<String>,
48
49 pub nonce: Option<String>,
50
51 pub code_challenge: Option<String>,
52
53 pub code_challenge_method: Option<String>,
54
55 pub user_id: OwnedUserId,
56
57 pub idp_id: Option<String>,
60
61 pub created_at: SystemTime,
62
63 pub expires_at: SystemTime,
64}
65
/// Ten minutes: intended lifetime of a pending [`AuthRequest`]. Not referenced
/// in this file; presumably used by whoever fills in
/// `AuthRequest::expires_at` (TODO confirm).
pub const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Ten minutes: how long an issued authorization code stays redeemable.
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Length argument passed to `utils::random_string` when minting a code.
const AUTH_CODE_LENGTH: usize = 64;
69
70#[implement(super::Server)]
71#[must_use]
72pub fn create_auth_code(&self, auth_req: &AuthRequest, user_id: OwnedUserId) -> String {
73 let now = SystemTime::now();
74 let code = utils::random_string(AUTH_CODE_LENGTH);
75 let session = AuthCodeSession {
76 code: code.clone(),
77 client_id: auth_req.client_id.clone(),
78 redirect_uri: auth_req.redirect_uri.clone(),
79 scope: auth_req.scope.clone(),
80 state: auth_req.state.clone(),
81 nonce: auth_req.nonce.clone(),
82 code_challenge: auth_req.code_challenge.clone(),
83 code_challenge_method: auth_req.code_challenge_method.clone(),
84 user_id,
85 idp_id: auth_req.idp_id.clone(),
86 created_at: now,
87 expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
88 };
89
90 self.db
91 .oidccode_authsession
92 .raw_put(&*code, Cbor(&session));
93
94 code
95}
96
97#[implement(super::Server)]
98pub fn store_auth_request(&self, req_id: &str, request: &AuthRequest) {
99 self.db
100 .oidcreqid_authrequest
101 .raw_put(req_id, Cbor(request));
102}
103
104#[implement(super::Server)]
105pub async fn take_auth_request(&self, req_id: &str) -> Result<AuthRequest> {
106 let request: AuthRequest = self
107 .db
108 .oidcreqid_authrequest
109 .get(req_id)
110 .await
111 .deserialized::<Cbor<_>>()
112 .map(|cbor: Cbor<AuthRequest>| cbor.0)
113 .map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
114
115 self.db.oidcreqid_authrequest.remove(req_id);
116
117 if SystemTime::now() > request.expires_at {
118 return Err!(Request(NotFound("Authorization request has expired")));
119 }
120
121 Ok(request)
122}
123
124#[implement(super::Server)]
125pub async fn exchange_auth_code(
126 &self,
127 code: &str,
128 client_id: &str,
129 redirect_uri: &str,
130 code_verifier: Option<&str>,
131) -> Result<AuthCodeSession> {
132 let session: AuthCodeSession = self
133 .db
134 .oidccode_authsession
135 .get(code)
136 .await
137 .deserialized::<Cbor<_>>()
138 .map(|cbor: Cbor<AuthCodeSession>| cbor.0)
139 .map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
140
141 self.db.oidccode_authsession.remove(code);
142
143 if SystemTime::now() > session.expires_at {
144 return Err!(Request(Forbidden("Authorization code has expired")));
145 }
146 if session.client_id != client_id {
147 return Err!(Request(Forbidden("client_id mismatch")));
148 }
149 if session.redirect_uri != redirect_uri {
150 return Err!(Request(Forbidden("redirect_uri mismatch")));
151 }
152
153 let Some(challenge) = &session.code_challenge else {
154 return Ok(session);
155 };
156
157 let Some(verifier) = code_verifier else {
158 return Err!(Request(Forbidden("code_verifier required for PKCE")));
159 };
160
161 validate_code_verifier(verifier)?;
162
163 let method = session
164 .code_challenge_method
165 .as_deref()
166 .unwrap_or("S256");
167
168 let computed = match method {
171 | "S256" => b64.encode(sha256::hash(verifier.as_bytes())),
172 | _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
173 };
174
175 if computed != *challenge {
176 return Err!(Request(Forbidden("PKCE verification failed")));
177 }
178
179 Ok(session)
180}
181
182fn validate_code_verifier(verifier: &str) -> Result {
186 if !(43..=128).contains(&verifier.len()) {
187 return Err!(Request(InvalidParam("code_verifier must be 43-128 characters")));
188 }
189
190 if !verifier
191 .bytes()
192 .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~')
193 {
194 return Err!(Request(InvalidParam("code_verifier contains invalid characters")));
195 }
196
197 Ok(())
198}