//! `tuwunel_service/oauth/server/auth.rs`: storage and redemption of OAuth
//! authorization requests and authorization codes.

1use std::time::{Duration, SystemTime};
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
4use ruma::OwnedUserId;
5use serde::{Deserialize, Serialize};
6use tuwunel_core::{Err, Result, err, implement, utils, utils::hash::sha256};
7use tuwunel_database::{Cbor, Deserialized};
8
9#[derive(Clone, Debug, Deserialize, Serialize)]
10pub struct AuthRequest {
11	pub client_id: String,
12
13	pub redirect_uri: String,
14
15	pub scope: String,
16
17	pub state: Option<String>,
18
19	pub nonce: Option<String>,
20
21	pub code_challenge: Option<String>,
22
23	pub code_challenge_method: Option<String>,
24
25	/// The identity provider ID used to authenticate the user for this
26	/// authorization request. Stored so it can be propagated to the device
27	/// at token exchange time and used for UIAA SSO provider binding.
28	pub idp_id: Option<String>,
29
30	pub response_mode: Option<String>,
31
32	pub created_at: SystemTime,
33
34	pub expires_at: SystemTime,
35}
36
37#[derive(Clone, Debug, Deserialize, Serialize)]
38pub struct AuthCodeSession {
39	pub code: String,
40
41	pub client_id: String,
42
43	pub redirect_uri: String,
44
45	pub scope: String,
46
47	pub state: Option<String>,
48
49	pub nonce: Option<String>,
50
51	pub code_challenge: Option<String>,
52
53	pub code_challenge_method: Option<String>,
54
55	pub user_id: OwnedUserId,
56
57	/// Propagated from the originating AuthRequest; identifies which IdP
58	/// authenticated the user so the device can be tagged at token exchange.
59	pub idp_id: Option<String>,
60
61	pub created_at: SystemTime,
62
63	pub expires_at: SystemTime,
64}
65
/// Ten-minute validity window for a pending [`AuthRequest`].
/// NOTE(review): not used in this file; presumably callers use it to set
/// `AuthRequest::expires_at`. Confirm.
pub const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// How long an issued authorization code remains redeemable (ten minutes).
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(10 * 60);

/// Length argument passed to `utils::random_string` when minting a code.
const AUTH_CODE_LENGTH: usize = 64;
69
70#[implement(super::Server)]
71#[must_use]
72pub fn create_auth_code(&self, auth_req: &AuthRequest, user_id: OwnedUserId) -> String {
73	let now = SystemTime::now();
74	let code = utils::random_string(AUTH_CODE_LENGTH);
75	let session = AuthCodeSession {
76		code: code.clone(),
77		client_id: auth_req.client_id.clone(),
78		redirect_uri: auth_req.redirect_uri.clone(),
79		scope: auth_req.scope.clone(),
80		state: auth_req.state.clone(),
81		nonce: auth_req.nonce.clone(),
82		code_challenge: auth_req.code_challenge.clone(),
83		code_challenge_method: auth_req.code_challenge_method.clone(),
84		user_id,
85		idp_id: auth_req.idp_id.clone(),
86		created_at: now,
87		expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
88	};
89
90	self.db
91		.oidccode_authsession
92		.raw_put(&*code, Cbor(&session));
93
94	code
95}
96
97#[implement(super::Server)]
98pub fn store_auth_request(&self, req_id: &str, request: &AuthRequest) {
99	self.db
100		.oidcreqid_authrequest
101		.raw_put(req_id, Cbor(request));
102}
103
104#[implement(super::Server)]
105pub async fn take_auth_request(&self, req_id: &str) -> Result<AuthRequest> {
106	let request: AuthRequest = self
107		.db
108		.oidcreqid_authrequest
109		.get(req_id)
110		.await
111		.deserialized::<Cbor<_>>()
112		.map(|cbor: Cbor<AuthRequest>| cbor.0)
113		.map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
114
115	self.db.oidcreqid_authrequest.remove(req_id);
116
117	if SystemTime::now() > request.expires_at {
118		return Err!(Request(NotFound("Authorization request has expired")));
119	}
120
121	Ok(request)
122}
123
124#[implement(super::Server)]
125pub async fn exchange_auth_code(
126	&self,
127	code: &str,
128	client_id: &str,
129	redirect_uri: &str,
130	code_verifier: Option<&str>,
131	require_pkce: bool,
132) -> Result<AuthCodeSession> {
133	let session: AuthCodeSession = self
134		.db
135		.oidccode_authsession
136		.get(code)
137		.await
138		.deserialized::<Cbor<_>>()
139		.map(|cbor: Cbor<AuthCodeSession>| cbor.0)
140		.map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
141
142	self.db.oidccode_authsession.remove(code);
143
144	if SystemTime::now() > session.expires_at {
145		return Err!(Request(Forbidden("Authorization code has expired")));
146	}
147	if session.client_id != client_id {
148		return Err!(Request(Forbidden("client_id mismatch")));
149	}
150	if session.redirect_uri != redirect_uri {
151		return Err!(Request(Forbidden("redirect_uri mismatch")));
152	}
153
154	let Some(challenge) = &session.code_challenge else {
155		// Reject a challenge-less code when PKCE is required: the knob is
156		// reloadable and codes outlive an off->on flip of it.
157		if require_pkce {
158			return Err!(Request(Forbidden(
159				"the authorization request carried no PKCE code_challenge"
160			)));
161		}
162
163		return Ok(session);
164	};
165
166	let Some(verifier) = code_verifier else {
167		return Err!(Request(Forbidden("code_verifier required for PKCE")));
168	};
169
170	validate_code_verifier(verifier)?;
171
172	let method = session
173		.code_challenge_method
174		.as_deref()
175		.unwrap_or("S256");
176
177	// Only S256 is advertised in discovery metadata; reject plain to avoid
178	// downgrade attacks (plain challenge == verifier, trivially intercepted).
179	let computed = match method {
180		| "S256" => b64.encode(sha256::hash(verifier.as_bytes())),
181		| _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
182	};
183
184	if computed != *challenge {
185		return Err!(Request(Forbidden("PKCE verification failed")));
186	}
187
188	Ok(session)
189}
190
191/// Validate code_verifier per RFC 7636 Section 4.1: must be 43-128
192/// characters using only unreserved characters [A-Z] / [a-z] / [0-9] /
193/// "-" / "." / "_" / "~".
194fn validate_code_verifier(verifier: &str) -> Result {
195	if !(43..=128).contains(&verifier.len()) {
196		return Err!(Request(InvalidParam("code_verifier must be 43-128 characters")));
197	}
198
199	if !verifier
200		.bytes()
201		.all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~')
202	{
203		return Err!(Request(InvalidParam("code_verifier contains invalid characters")));
204	}
205
206	Ok(())
207}