Skip to main content

tuwunel_service/oauth/
mod.rs

1pub mod providers;
2pub mod server;
3pub mod sessions;
4pub mod token_response;
5pub mod user_info;
6
7use std::{
8	collections::HashMap,
9	net::IpAddr,
10	sync::{Arc, Mutex},
11	time::Instant,
12};
13
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
15use futures::{Stream, StreamExt, TryStreamExt};
16use http::StatusCode;
17use reqwest::{
18	Method,
19	header::{ACCEPT, CONTENT_TYPE},
20};
21use ruma::{
22	UserId,
23	api::error::{ErrorKind, LimitExceededErrorData},
24};
25use serde::Serialize;
26use serde_json::Value as JsonValue;
27use tuwunel_core::{
28	Err, Error, Result, err, implement,
29	utils::{hash::sha256, result::LogErr, stream::ReadyExt},
30	warn,
31};
32use url::Url;
33
34use self::{providers::Providers, sessions::Sessions};
35pub use self::{
36	providers::{Provider, ProviderId},
37	server::Server,
38	sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
39	token_response::TokenResponse,
40	user_info::UserInfo,
41};
42use crate::{SelfServices, client::read_response_capped};
43
44/// Per-client-IP token-bucket table: last-refill instant and remaining tokens.
45type Ratelimiter = Mutex<HashMap<IpAddr, (Instant, f64)>>;
46
47pub struct Service {
48	services: SelfServices,
49	pub providers: Arc<Providers>,
50	pub sessions: Arc<Sessions>,
51	pub server: Option<Arc<Server>>,
52	ratelimiter: Ratelimiter,
53	device_ratelimiter: Ratelimiter,
54}
55
56impl crate::Service for Service {
57	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
58		let providers = Arc::new(Providers::build(args));
59		let sessions = Arc::new(Sessions::build(args, providers.clone()));
60		let server = Server::build(args)?.map(Arc::new);
61
62		Ok(Arc::new(Self {
63			services: args.services.clone(),
64			sessions,
65			providers,
66			server,
67			ratelimiter: Mutex::new(HashMap::new()),
68			device_ratelimiter: Mutex::new(HashMap::new()),
69		}))
70	}
71
72	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
73}
74
75#[implement(Service)]
76#[inline]
77pub fn get_server(&self) -> Result<&Server> {
78	self.server
79		.as_deref()
80		.ok_or_else(|| err!(Request(Unrecognized("OIDC server not configured"))))
81}
82
/// Entry count at which a rate-limit table starts pruning buckets that have
/// fully refilled (2^16 entries).
const RATELIMIT_MAP_CAP: usize = 65_536;

/// Always-on throttle for the RFC 8628 device user-code endpoints. The
/// `user_code` is low-entropy by design (§6.1), so §5.1 requires bounding
/// guesses regardless of the optional `oidc_rc_*` knobs; the burst stays
/// generous for the one code a real user enters.
///
/// Tokens regained per second.
const DEVICE_RC_PER_SECOND: f64 = 1.0;
/// Maximum tokens a device bucket can hold.
const DEVICE_RC_BURST: f64 = 60.0;
92
93/// Shared per-client-IP token-bucket throttle for the OIDC endpoints. A no-op
94/// unless both `oidc_rc_per_second` and `oidc_rc_burst_count` are configured.
95#[implement(Service)]
96pub fn check_rate_limit(&self, client: IpAddr) -> Result {
97	let config = &self.services.config;
98	let rate = f64::from(config.oidc_rc_per_second);
99	let burst = f64::from(config.oidc_rc_burst_count);
100
101	if rate <= 0.0 || burst <= 0.0 {
102		return Ok(());
103	}
104
105	check_bucket(&self.ratelimiter, client, rate, burst)
106}
107
108/// Always-on anti-brute-force throttle for the device user-code endpoints
109/// (RFC 8628 §5.1), independent of the optional `oidc_rc_*` knobs.
110#[implement(Service)]
111pub fn check_device_rate_limit(&self, client: IpAddr) -> Result {
112	check_bucket(&self.device_ratelimiter, client, DEVICE_RC_PER_SECOND, DEVICE_RC_BURST)
113}
114
115fn check_bucket(table: &Ratelimiter, client: IpAddr, rate: f64, burst: f64) -> Result {
116	let now = Instant::now();
117	let mut buckets = table.lock()?;
118
119	// A fully refilled bucket equals an absent one; prune those past the cap so
120	// a source-address spray cannot grow the table without bound.
121	if buckets.len() >= RATELIMIT_MAP_CAP {
122		buckets.retain(|_, bucket| {
123			let (last, toks) = *bucket;
124			now.duration_since(last)
125				.as_secs_f64()
126				.mul_add(rate, toks)
127				< burst
128		});
129	}
130
131	let (last_time, tokens) = buckets
132		.entry(client)
133		.or_insert_with(|| (now, burst));
134
135	let new_tokens = now
136		.duration_since(*last_time)
137		.as_secs_f64()
138		.mul_add(rate, *tokens)
139		.min(burst);
140
141	if new_tokens < 1.0 {
142		return Err(Error::Request(
143			ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after: None }),
144			"Too many OIDC requests.".into(),
145			StatusCode::TOO_MANY_REQUESTS,
146		));
147	}
148
149	*last_time = now;
150	*tokens = new_tokens - 1.0;
151
152	Ok(())
153}
154
155/// Remove all session state for a user. For debug and developer use only;
156/// deleting state can cause registration conflicts and unintended
157/// re-registrations.
158#[implement(Service)]
159#[tracing::instrument(level = "debug", skip(self))]
160pub async fn delete_user_sessions(&self, user_id: &UserId) {
161	self.user_sessions(user_id)
162		.ready_filter_map(Result::ok)
163		.ready_filter_map(|(_, session)| session.sess_id)
164		.for_each(async |sess_id| {
165			self.sessions.delete(&sess_id).await;
166		})
167		.await;
168}
169
170/// Revoke all session tokens for a user.
171#[implement(Service)]
172#[tracing::instrument(level = "debug", skip(self))]
173pub async fn revoke_user_tokens(&self, user_id: &UserId) {
174	self.user_sessions(user_id)
175		.ready_filter_map(Result::ok)
176		.for_each(async |(provider, session)| {
177			self.revoke_token((&provider, &session))
178				.await
179				.log_err()
180				.ok();
181		})
182		.await;
183}
184
185/// Get user's authorizations. Lists pairs of `(Provider, Session)` for a user.
186#[implement(Service)]
187#[tracing::instrument(level = "debug", skip(self))]
188pub fn user_sessions(
189	&self,
190	user_id: &UserId,
191) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
192	self.sessions
193		.get_by_user(user_id)
194		.and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
195}
196
197/// Network request to a Provider returning userinfo for a Session. The session
198/// must have a valid access token.
199#[implement(Service)]
200#[tracing::instrument(level = "debug", skip_all, ret)]
201pub async fn request_userinfo(
202	&self,
203	(provider, session): (&Provider, &Session),
204) -> Result<UserInfo> {
205	#[derive(Debug, Serialize)]
206	struct Query;
207
208	let url = provider
209		.userinfo_url
210		.clone()
211		.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
212
213	self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
214		.await
215		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
216		.log_err()
217}
218
219/// Network request to a Provider returning information for a Session based on
220/// its access token.
221#[implement(Service)]
222#[tracing::instrument(level = "debug", skip_all, ret)]
223pub async fn request_tokeninfo(
224	&self,
225	(provider, session): (&Provider, &Session),
226) -> Result<UserInfo> {
227	#[derive(Debug, Serialize)]
228	struct Query;
229
230	let url = provider
231		.introspection_url
232		.clone()
233		.ok_or_else(|| {
234			err!(Config("introspection_url", "Missing introspection URL in config"))
235		})?;
236
237	self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
238		.await
239		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
240		.log_err()
241}
242
243/// Network request to a Provider revoking a Session's token.
244#[implement(Service)]
245#[tracing::instrument(level = "debug", skip_all, ret)]
246pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
247	#[derive(Debug, Serialize)]
248	struct RevokeQuery<'a> {
249		client_id: &'a str,
250		client_secret: &'a str,
251	}
252
253	let client_secret = provider.get_client_secret().await?;
254
255	let query = RevokeQuery {
256		client_id: &provider.client_id,
257		client_secret: &client_secret,
258	};
259
260	let url = provider
261		.revocation_url
262		.clone()
263		.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
264
265	self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
266		.await
267		.log_err()
268		.map(|_| ())
269}
270
271/// Network request to a Provider to obtain an access token for a Session using
272/// a provided code.
273#[implement(Service)]
274#[tracing::instrument(level = "debug", skip_all, ret)]
275pub async fn request_token(
276	&self,
277	(provider, session): (&Provider, &Session),
278	code: &str,
279) -> Result<TokenResponse> {
280	#[derive(Debug, Serialize)]
281	struct TokenQuery<'a> {
282		client_id: &'a str,
283		client_secret: &'a str,
284		grant_type: &'a str,
285		code: &'a str,
286		code_verifier: Option<&'a str>,
287		redirect_uri: Option<&'a str>,
288	}
289
290	let client_secret = provider.get_client_secret().await?;
291
292	let query = TokenQuery {
293		client_id: &provider.client_id,
294		client_secret: &client_secret,
295		grant_type: "authorization_code",
296		code,
297		code_verifier: session.code_verifier.as_deref(),
298		redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
299	};
300
301	let url = provider
302		.token_url
303		.clone()
304		.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
305
306	self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
307		.await
308		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
309		.log_err()
310}
311
312/// Send a request to a provider; this is somewhat abstract since URL's are
313/// formed prior to this call and could point at anything, however this function
314/// uses the oauth-specific http client and is configured for JSON with special
315/// casing for an `error` property in the response.
316#[implement(Service)]
317#[tracing::instrument(
318	name = "request",
319	level = "debug",
320	ret(level = "trace"),
321	skip(self, body)
322)]
323pub async fn request<Body>(
324	&self,
325	(provider, session): (Option<&Provider>, Option<&Session>),
326	method: Method,
327	url: Url,
328	body: Option<Body>,
329) -> Result<JsonValue>
330where
331	Body: Serialize,
332{
333	let mut request = self
334		.services
335		.client
336		.oauth
337		.request(method, url)
338		.header(ACCEPT, "application/json");
339
340	if let Some(body) = body.map(serde_html_form::to_string).transpose()? {
341		request = request
342			.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
343			.body(body);
344	}
345
346	if let Some(session) = session
347		&& let Some(access_token) = session.access_token.clone()
348	{
349		request = request.bearer_auth(access_token);
350	}
351
352	let limit = self.services.config.max_response_size;
353	let http_response = request.send().await?.error_for_status()?;
354
355	let body = read_response_capped(http_response, limit).await?;
356	let response: JsonValue = serde_json::from_slice(&body)?;
357
358	if let Some(response) = response.as_object().as_ref()
359		&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
360	{
361		let description = response
362			.get("error_description")
363			.and_then(JsonValue::as_str)
364			.unwrap_or("(no description)");
365
366		return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
367	}
368
369	Ok(response)
370}
371
372/// Generate a unique-id string determined by the combination of `Provider` and
373/// `Session` instances.
374#[inline]
375pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
376	unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
377}
378
379/// Generate a unique-id string determined by the combination of `Provider`
380/// instance and `sub` string.
381#[inline]
382pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
383	unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
384}
385
386/// Generate a unique-id string determined by the combination of `issuer_url`
387/// and `Session` instance.
388#[inline]
389pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
390	unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
391}
392
393/// Generate a unique-id string determined by the `issuer_url` and the `sub`
394/// strings directly.
395pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
396	let hash = sha256::delimited([iss, sub].iter());
397	let b64 = b64encode.encode(hash);
398
399	Ok(b64)
400}
401
402fn unique_id_parts<'a>(
403	(provider, session): (&'a Provider, &'a Session),
404) -> Result<(&'a str, &'a str)> {
405	identity_issuer(provider)
406		.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
407		.and_then(|iss| unique_id_iss_parts((iss, session)))
408}
409
410fn unique_id_sub_parts<'a>(
411	(provider, sub): (&'a Provider, &'a str),
412) -> Result<(&'a str, &'a str)> {
413	identity_issuer(provider)
414		.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
415		.map(|iss| (iss, sub))
416}
417
418/// Issuer string used as input to the identity hash. Pinned per-brand for
419/// providers whose published issuer has changed under us, so existing account
420/// associations survive the change.
421fn identity_issuer(provider: &Provider) -> Option<&str> {
422	match provider.brand.as_str() {
423		| "github" => Some("https://github.com/"),
424		| _ => provider.issuer_url.as_ref().map(Url::as_str),
425	}
426}
427
428fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
429	session
430		.user_info
431		.as_ref()
432		.map(|user_info| user_info.sub.as_str())
433		.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
434		.map(|sub| (iss, sub))
435}