Skip to main content

tuwunel_service/oauth/server/
auth.rs

1use std::time::{Duration, SystemTime};
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
4use ruma::OwnedUserId;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{Err, Result, err, implement, utils, utils::hash::sha256};
7use tuwunel_database::{Cbor, Deserialized};
8
9#[derive(Clone, Debug, Deserialize, Serialize)]
10pub struct AuthRequest {
11	pub client_id: String,
12
13	pub redirect_uri: String,
14
15	pub scope: String,
16
17	pub state: Option<String>,
18
19	pub nonce: Option<String>,
20
21	pub code_challenge: Option<String>,
22
23	pub code_challenge_method: Option<String>,
24
25	/// The identity provider ID used to authenticate the user for this
26	/// authorization request. Stored so it can be propagated to the device
27	/// at token exchange time and used for UIAA SSO provider binding.
28	pub idp_id: Option<String>,
29
30	pub response_mode: Option<String>,
31
32	pub created_at: SystemTime,
33
34	pub expires_at: SystemTime,
35}
36
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct AuthCodeSession {
39	pub code: String,
40
41	pub client_id: String,
42
43	pub redirect_uri: String,
44
45	pub scope: String,
46
47	pub state: Option<String>,
48
49	pub nonce: Option<String>,
50
51	pub code_challenge: Option<String>,
52
53	pub code_challenge_method: Option<String>,
54
55	pub user_id: OwnedUserId,
56
57	/// Propagated from the originating AuthRequest; identifies which IdP
58	/// authenticated the user so the device can be tagged at token exchange.
59	pub idp_id: Option<String>,
60
61	pub created_at: SystemTime,
62
63	pub expires_at: SystemTime,
64}
65
/// Intended validity window (ten minutes) of a pending [`AuthRequest`].
/// Presumably used by callers to compute `expires_at`, since nothing in this
/// file reads it. TODO confirm against the authorize endpoint.
pub const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Validity window (ten minutes) of an issued authorization code.
/// `create_auth_code` adds this to the issue time to get `expires_at`.
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Length argument passed to `utils::random_string` when minting a code.
const AUTH_CODE_LENGTH: usize = 64;
69
70#[implement(super::Server)]
71#[must_use]
72pub fn create_auth_code(&self, auth_req: &AuthRequest, user_id: OwnedUserId) -> String {
73	let now = SystemTime::now();
74	let code = utils::random_string(AUTH_CODE_LENGTH);
75	let session = AuthCodeSession {
76		code: code.clone(),
77		client_id: auth_req.client_id.clone(),
78		redirect_uri: auth_req.redirect_uri.clone(),
79		scope: auth_req.scope.clone(),
80		state: auth_req.state.clone(),
81		nonce: auth_req.nonce.clone(),
82		code_challenge: auth_req.code_challenge.clone(),
83		code_challenge_method: auth_req.code_challenge_method.clone(),
84		user_id,
85		idp_id: auth_req.idp_id.clone(),
86		created_at: now,
87		expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
88	};
89
90	self.db
91		.oidccode_authsession
92		.raw_put(&*code, Cbor(&session));
93
94	code
95}
96
97#[implement(super::Server)]
98pub fn store_auth_request(&self, req_id: &str, request: &AuthRequest) {
99	self.db
100		.oidcreqid_authrequest
101		.raw_put(req_id, Cbor(request));
102}
103
104#[implement(super::Server)]
105pub async fn take_auth_request(&self, req_id: &str) -> Result<AuthRequest> {
106	let request: AuthRequest = self
107		.db
108		.oidcreqid_authrequest
109		.get(req_id)
110		.await
111		.deserialized::<Cbor<_>>()
112		.map(|cbor: Cbor<AuthRequest>| cbor.0)
113		.map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
114
115	self.db.oidcreqid_authrequest.remove(req_id);
116
117	if SystemTime::now() > request.expires_at {
118		return Err!(Request(NotFound("Authorization request has expired")));
119	}
120
121	Ok(request)
122}
123
124#[implement(super::Server)]
125pub async fn exchange_auth_code(
126	&self,
127	code: &str,
128	client_id: &str,
129	redirect_uri: &str,
130	code_verifier: Option<&str>,
131) -> Result<AuthCodeSession> {
132	let session: AuthCodeSession = self
133		.db
134		.oidccode_authsession
135		.get(code)
136		.await
137		.deserialized::<Cbor<_>>()
138		.map(|cbor: Cbor<AuthCodeSession>| cbor.0)
139		.map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
140
141	self.db.oidccode_authsession.remove(code);
142
143	if SystemTime::now() > session.expires_at {
144		return Err!(Request(Forbidden("Authorization code has expired")));
145	}
146	if session.client_id != client_id {
147		return Err!(Request(Forbidden("client_id mismatch")));
148	}
149	if session.redirect_uri != redirect_uri {
150		return Err!(Request(Forbidden("redirect_uri mismatch")));
151	}
152
153	let Some(challenge) = &session.code_challenge else {
154		return Ok(session);
155	};
156
157	let Some(verifier) = code_verifier else {
158		return Err!(Request(Forbidden("code_verifier required for PKCE")));
159	};
160
161	validate_code_verifier(verifier)?;
162
163	let method = session
164		.code_challenge_method
165		.as_deref()
166		.unwrap_or("S256");
167
168	// Only S256 is advertised in discovery metadata; reject plain to avoid
169	// downgrade attacks (plain challenge == verifier, trivially intercepted).
170	let computed = match method {
171		| "S256" => b64.encode(sha256::hash(verifier.as_bytes())),
172		| _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
173	};
174
175	if computed != *challenge {
176		return Err!(Request(Forbidden("PKCE verification failed")));
177	}
178
179	Ok(session)
180}
181
182/// Validate code_verifier per RFC 7636 Section 4.1: must be 43-128
183/// characters using only unreserved characters [A-Z] / [a-z] / [0-9] /
184/// "-" / "." / "_" / "~".
185fn validate_code_verifier(verifier: &str) -> Result {
186	if !(43..=128).contains(&verifier.len()) {
187		return Err!(Request(InvalidParam("code_verifier must be 43-128 characters")));
188	}
189
190	if !verifier
191		.bytes()
192		.all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~')
193	{
194		return Err!(Request(InvalidParam("code_verifier contains invalid characters")));
195	}
196
197	Ok(())
198}