1pub mod providers;
2pub mod server;
3pub mod sessions;
4pub mod token_response;
5pub mod user_info;
6
7use std::{
8 collections::HashMap,
9 net::IpAddr,
10 sync::{Arc, Mutex},
11 time::Instant,
12};
13
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
15use futures::{Stream, StreamExt, TryStreamExt};
16use http::StatusCode;
17use reqwest::{
18 Method,
19 header::{ACCEPT, CONTENT_TYPE},
20};
21use ruma::{
22 UserId,
23 api::error::{ErrorKind, LimitExceededErrorData},
24};
25use serde::Serialize;
26use serde_json::Value as JsonValue;
27use tuwunel_core::{
28 Err, Error, Result, err, implement,
29 utils::{hash::sha256, result::LogErr, stream::ReadyExt},
30 warn,
31};
32use url::Url;
33
34use self::{providers::Providers, sessions::Sessions};
35pub use self::{
36 providers::{Provider, ProviderId},
37 server::Server,
38 sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
39 token_response::TokenResponse,
40 user_info::UserInfo,
41};
42use crate::{SelfServices, client::read_response_capped};
43
/// Per-client token buckets: each client IP maps to the `Instant` its bucket
/// was last updated and the (fractional) token count it held at that moment.
type Ratelimiter = Mutex<HashMap<IpAddr, (Instant, f64)>>;

/// OIDC/OAuth service state: provider registry, session store, the optional
/// OIDC server, and the rate-limit tables used by `check_rate_limit` and
/// `check_device_rate_limit`.
pub struct Service {
	// Sibling service handles (configuration, HTTP clients) used by this service.
	services: SelfServices,
	/// Configured identity providers, shared with `sessions`.
	pub providers: Arc<Providers>,
	/// Stored OIDC sessions.
	pub sessions: Arc<Sessions>,
	/// `Some` only when `Server::build` produced a server; see `get_server`.
	pub server: Option<Arc<Server>>,
	// Buckets for the configurable limit enforced by `check_rate_limit`.
	ratelimiter: Ratelimiter,
	// Buckets for the fixed limit enforced by `check_device_rate_limit`.
	device_ratelimiter: Ratelimiter,
}
55
56impl crate::Service for Service {
57 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
58 let providers = Arc::new(Providers::build(args));
59 let sessions = Arc::new(Sessions::build(args, providers.clone()));
60 let server = Server::build(args)?.map(Arc::new);
61
62 Ok(Arc::new(Self {
63 services: args.services.clone(),
64 sessions,
65 providers,
66 server,
67 ratelimiter: Mutex::new(HashMap::new()),
68 device_ratelimiter: Mutex::new(HashMap::new()),
69 }))
70 }
71
72 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
73}
74
75#[implement(Service)]
76#[inline]
77pub fn get_server(&self) -> Result<&Server> {
78 self.server
79 .as_deref()
80 .ok_or_else(|| err!(Request(Unrecognized("OIDC server not configured"))))
81}
82
/// Table size (2^16 entries) at or above which `check_bucket` prunes buckets
/// that have fully refilled. This is a pruning threshold, not a hard limit.
const RATELIMIT_MAP_CAP: usize = 65_536;

/// Refill rate of the device rate limiter, in tokens per second.
const DEVICE_RC_PER_SECOND: f64 = 1.0;
/// Maximum tokens (burst size) a device rate-limit bucket can hold.
const DEVICE_RC_BURST: f64 = 60.0;
92
93#[implement(Service)]
96pub fn check_rate_limit(&self, client: IpAddr) -> Result {
97 let config = &self.services.config;
98 let rate = f64::from(config.oidc_rc_per_second);
99 let burst = f64::from(config.oidc_rc_burst_count);
100
101 if rate <= 0.0 || burst <= 0.0 {
102 return Ok(());
103 }
104
105 check_bucket(&self.ratelimiter, client, rate, burst)
106}
107
108#[implement(Service)]
111pub fn check_device_rate_limit(&self, client: IpAddr) -> Result {
112 check_bucket(&self.device_ratelimiter, client, DEVICE_RC_PER_SECOND, DEVICE_RC_BURST)
113}
114
115fn check_bucket(table: &Ratelimiter, client: IpAddr, rate: f64, burst: f64) -> Result {
116 let now = Instant::now();
117 let mut buckets = table.lock()?;
118
119 if buckets.len() >= RATELIMIT_MAP_CAP {
122 buckets.retain(|_, bucket| {
123 let (last, toks) = *bucket;
124 now.duration_since(last)
125 .as_secs_f64()
126 .mul_add(rate, toks)
127 < burst
128 });
129 }
130
131 let (last_time, tokens) = buckets
132 .entry(client)
133 .or_insert_with(|| (now, burst));
134
135 let new_tokens = now
136 .duration_since(*last_time)
137 .as_secs_f64()
138 .mul_add(rate, *tokens)
139 .min(burst);
140
141 if new_tokens < 1.0 {
142 return Err(Error::Request(
143 ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after: None }),
144 "Too many OIDC requests.".into(),
145 StatusCode::TOO_MANY_REQUESTS,
146 ));
147 }
148
149 *last_time = now;
150 *tokens = new_tokens - 1.0;
151
152 Ok(())
153}
154
155#[implement(Service)]
159#[tracing::instrument(level = "debug", skip(self))]
160pub async fn delete_user_sessions(&self, user_id: &UserId) {
161 self.user_sessions(user_id)
162 .ready_filter_map(Result::ok)
163 .ready_filter_map(|(_, session)| session.sess_id)
164 .for_each(async |sess_id| {
165 self.sessions.delete(&sess_id).await;
166 })
167 .await;
168}
169
170#[implement(Service)]
172#[tracing::instrument(level = "debug", skip(self))]
173pub async fn revoke_user_tokens(&self, user_id: &UserId) {
174 self.user_sessions(user_id)
175 .ready_filter_map(Result::ok)
176 .for_each(async |(provider, session)| {
177 self.revoke_token((&provider, &session))
178 .await
179 .log_err()
180 .ok();
181 })
182 .await;
183}
184
185#[implement(Service)]
187#[tracing::instrument(level = "debug", skip(self))]
188pub fn user_sessions(
189 &self,
190 user_id: &UserId,
191) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
192 self.sessions
193 .get_by_user(user_id)
194 .and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
195}
196
197#[implement(Service)]
200#[tracing::instrument(level = "debug", skip_all, ret)]
201pub async fn request_userinfo(
202 &self,
203 (provider, session): (&Provider, &Session),
204) -> Result<UserInfo> {
205 #[derive(Debug, Serialize)]
206 struct Query;
207
208 let url = provider
209 .userinfo_url
210 .clone()
211 .ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
212
213 self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
214 .await
215 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
216 .log_err()
217}
218
219#[implement(Service)]
222#[tracing::instrument(level = "debug", skip_all, ret)]
223pub async fn request_tokeninfo(
224 &self,
225 (provider, session): (&Provider, &Session),
226) -> Result<UserInfo> {
227 #[derive(Debug, Serialize)]
228 struct Query;
229
230 let url = provider
231 .introspection_url
232 .clone()
233 .ok_or_else(|| {
234 err!(Config("introspection_url", "Missing introspection URL in config"))
235 })?;
236
237 self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
238 .await
239 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
240 .log_err()
241}
242
243#[implement(Service)]
245#[tracing::instrument(level = "debug", skip_all, ret)]
246pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
247 #[derive(Debug, Serialize)]
248 struct RevokeQuery<'a> {
249 client_id: &'a str,
250 client_secret: &'a str,
251 }
252
253 let client_secret = provider.get_client_secret().await?;
254
255 let query = RevokeQuery {
256 client_id: &provider.client_id,
257 client_secret: &client_secret,
258 };
259
260 let url = provider
261 .revocation_url
262 .clone()
263 .ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
264
265 self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
266 .await
267 .log_err()
268 .map(|_| ())
269}
270
271#[implement(Service)]
274#[tracing::instrument(level = "debug", skip_all, ret)]
275pub async fn request_token(
276 &self,
277 (provider, session): (&Provider, &Session),
278 code: &str,
279) -> Result<TokenResponse> {
280 #[derive(Debug, Serialize)]
281 struct TokenQuery<'a> {
282 client_id: &'a str,
283 client_secret: &'a str,
284 grant_type: &'a str,
285 code: &'a str,
286 code_verifier: Option<&'a str>,
287 redirect_uri: Option<&'a str>,
288 }
289
290 let client_secret = provider.get_client_secret().await?;
291
292 let query = TokenQuery {
293 client_id: &provider.client_id,
294 client_secret: &client_secret,
295 grant_type: "authorization_code",
296 code,
297 code_verifier: session.code_verifier.as_deref(),
298 redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
299 };
300
301 let url = provider
302 .token_url
303 .clone()
304 .ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
305
306 self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
307 .await
308 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
309 .log_err()
310}
311
312#[implement(Service)]
317#[tracing::instrument(
318 name = "request",
319 level = "debug",
320 ret(level = "trace"),
321 skip(self, body)
322)]
323pub async fn request<Body>(
324 &self,
325 (provider, session): (Option<&Provider>, Option<&Session>),
326 method: Method,
327 url: Url,
328 body: Option<Body>,
329) -> Result<JsonValue>
330where
331 Body: Serialize,
332{
333 let mut request = self
334 .services
335 .client
336 .oauth
337 .request(method, url)
338 .header(ACCEPT, "application/json");
339
340 if let Some(body) = body.map(serde_html_form::to_string).transpose()? {
341 request = request
342 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
343 .body(body);
344 }
345
346 if let Some(session) = session
347 && let Some(access_token) = session.access_token.clone()
348 {
349 request = request.bearer_auth(access_token);
350 }
351
352 let limit = self.services.config.max_response_size;
353 let http_response = request.send().await?.error_for_status()?;
354
355 let body = read_response_capped(http_response, limit).await?;
356 let response: JsonValue = serde_json::from_slice(&body)?;
357
358 if let Some(response) = response.as_object().as_ref()
359 && let Some(error) = response.get("error").and_then(JsonValue::as_str)
360 {
361 let description = response
362 .get("error_description")
363 .and_then(JsonValue::as_str)
364 .unwrap_or("(no description)");
365
366 return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
367 }
368
369 Ok(response)
370}
371
372#[inline]
375pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
376 unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
377}
378
379#[inline]
382pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
383 unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
384}
385
386#[inline]
389pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
390 unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
391}
392
393pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
396 let hash = sha256::delimited([iss, sub].iter());
397 let b64 = b64encode.encode(hash);
398
399 Ok(b64)
400}
401
402fn unique_id_parts<'a>(
403 (provider, session): (&'a Provider, &'a Session),
404) -> Result<(&'a str, &'a str)> {
405 identity_issuer(provider)
406 .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
407 .and_then(|iss| unique_id_iss_parts((iss, session)))
408}
409
410fn unique_id_sub_parts<'a>(
411 (provider, sub): (&'a Provider, &'a str),
412) -> Result<(&'a str, &'a str)> {
413 identity_issuer(provider)
414 .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
415 .map(|iss| (iss, sub))
416}
417
418fn identity_issuer(provider: &Provider) -> Option<&str> {
422 match provider.brand.as_str() {
423 | "github" => Some("https://github.com/"),
424 | _ => provider.issuer_url.as_ref().map(Url::as_str),
425 }
426}
427
428fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
429 session
430 .user_info
431 .as_ref()
432 .map(|user_info| user_info.sub.as_str())
433 .ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
434 .map(|sub| (iss, sub))
435}