1pub mod providers;
2pub mod server;
3pub mod sessions;
4pub mod token_response;
5pub mod user_info;
6
7use std::sync::Arc;
8
9use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
10use futures::{Stream, StreamExt, TryStreamExt};
11use reqwest::{
12 Method,
13 header::{ACCEPT, CONTENT_TYPE},
14};
15use ruma::UserId;
16use serde::Serialize;
17use serde_json::Value as JsonValue;
18use tuwunel_core::{
19 Err, Result, err, implement,
20 utils::{hash::sha256, result::LogErr, stream::ReadyExt},
21 warn,
22};
23use url::Url;
24
25use self::{providers::Providers, sessions::Sessions};
26pub use self::{
27 providers::{Provider, ProviderId},
28 server::Server,
29 sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
30 token_response::TokenResponse,
31 user_info::UserInfo,
32};
33use crate::SelfServices;
34
/// OAuth/OIDC service: owns the configured identity providers, the stored
/// authorization sessions, and the optional built-in OIDC server.
pub struct Service {
	services: SelfServices,
	/// Identity providers built from the service arguments at startup.
	pub providers: Arc<Providers>,
	/// Session store; built with a handle to `providers` so sessions can be
	/// mapped back to their provider.
	pub sessions: Arc<Sessions>,
	/// Built-in OIDC server; `None` when `Server::build` yields no server
	/// (presumably because it is not configured — confirm in `Server::build`).
	pub server: Option<Arc<Server>>,
}
41
42impl crate::Service for Service {
43 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
44 let providers = Arc::new(Providers::build(args));
45 let sessions = Arc::new(Sessions::build(args, providers.clone()));
46 let server = Server::build(args)?.map(Arc::new);
47
48 Ok(Arc::new(Self {
49 services: args.services.clone(),
50 sessions,
51 providers,
52 server,
53 }))
54 }
55
56 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
57}
58
59#[implement(Service)]
60#[inline]
61pub fn get_server(&self) -> Result<&Server> {
62 self.server
63 .as_deref()
64 .ok_or_else(|| err!(Request(Unrecognized("OIDC server not configured"))))
65}
66
67#[implement(Service)]
71#[tracing::instrument(level = "debug", skip(self))]
72pub async fn delete_user_sessions(&self, user_id: &UserId) {
73 self.user_sessions(user_id)
74 .ready_filter_map(Result::ok)
75 .ready_filter_map(|(_, session)| session.sess_id)
76 .for_each(async |sess_id| {
77 self.sessions.delete(&sess_id).await;
78 })
79 .await;
80}
81
82#[implement(Service)]
84#[tracing::instrument(level = "debug", skip(self))]
85pub async fn revoke_user_tokens(&self, user_id: &UserId) {
86 self.user_sessions(user_id)
87 .ready_filter_map(Result::ok)
88 .for_each(async |(provider, session)| {
89 self.revoke_token((&provider, &session))
90 .await
91 .log_err()
92 .ok();
93 })
94 .await;
95}
96
97#[implement(Service)]
99#[tracing::instrument(level = "debug", skip(self))]
100pub fn user_sessions(
101 &self,
102 user_id: &UserId,
103) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
104 self.sessions
105 .get_by_user(user_id)
106 .and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
107}
108
109#[implement(Service)]
112#[tracing::instrument(level = "debug", skip_all, ret)]
113pub async fn request_userinfo(
114 &self,
115 (provider, session): (&Provider, &Session),
116) -> Result<UserInfo> {
117 #[derive(Debug, Serialize)]
118 struct Query;
119
120 let url = provider
121 .userinfo_url
122 .clone()
123 .ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
124
125 self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
126 .await
127 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
128 .log_err()
129}
130
131#[implement(Service)]
134#[tracing::instrument(level = "debug", skip_all, ret)]
135pub async fn request_tokeninfo(
136 &self,
137 (provider, session): (&Provider, &Session),
138) -> Result<UserInfo> {
139 #[derive(Debug, Serialize)]
140 struct Query;
141
142 let url = provider
143 .introspection_url
144 .clone()
145 .ok_or_else(|| {
146 err!(Config("introspection_url", "Missing introspection URL in config"))
147 })?;
148
149 self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
150 .await
151 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
152 .log_err()
153}
154
155#[implement(Service)]
157#[tracing::instrument(level = "debug", skip_all, ret)]
158pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
159 #[derive(Debug, Serialize)]
160 struct RevokeQuery<'a> {
161 client_id: &'a str,
162 client_secret: &'a str,
163 }
164
165 let client_secret = provider.get_client_secret().await?;
166
167 let query = RevokeQuery {
168 client_id: &provider.client_id,
169 client_secret: &client_secret,
170 };
171
172 let url = provider
173 .revocation_url
174 .clone()
175 .ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
176
177 self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
178 .await
179 .log_err()
180 .map(|_| ())
181}
182
183#[implement(Service)]
186#[tracing::instrument(level = "debug", skip_all, ret)]
187pub async fn request_token(
188 &self,
189 (provider, session): (&Provider, &Session),
190 code: &str,
191) -> Result<TokenResponse> {
192 #[derive(Debug, Serialize)]
193 struct TokenQuery<'a> {
194 client_id: &'a str,
195 client_secret: &'a str,
196 grant_type: &'a str,
197 code: &'a str,
198 code_verifier: Option<&'a str>,
199 redirect_uri: Option<&'a str>,
200 }
201
202 let client_secret = provider.get_client_secret().await?;
203
204 let query = TokenQuery {
205 client_id: &provider.client_id,
206 client_secret: &client_secret,
207 grant_type: "authorization_code",
208 code,
209 code_verifier: session.code_verifier.as_deref(),
210 redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
211 };
212
213 let url = provider
214 .token_url
215 .clone()
216 .ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
217
218 self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
219 .await
220 .and_then(|value| serde_json::from_value(value).map_err(Into::into))
221 .log_err()
222}
223
224#[implement(Service)]
229#[tracing::instrument(
230 name = "request",
231 level = "debug",
232 ret(level = "trace"),
233 skip(self, body)
234)]
235pub async fn request<Body>(
236 &self,
237 (provider, session): (Option<&Provider>, Option<&Session>),
238 method: Method,
239 url: Url,
240 body: Option<Body>,
241) -> Result<JsonValue>
242where
243 Body: Serialize,
244{
245 let mut request = self
246 .services
247 .client
248 .oauth
249 .request(method, url)
250 .header(ACCEPT, "application/json");
251
252 if let Some(body) = body.map(serde_html_form::to_string).transpose()? {
253 request = request
254 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
255 .body(body);
256 }
257
258 if let Some(session) = session
259 && let Some(access_token) = session.access_token.clone()
260 {
261 request = request.bearer_auth(access_token);
262 }
263
264 let response: JsonValue = request
265 .send()
266 .await?
267 .error_for_status()?
268 .json()
269 .await?;
270
271 if let Some(response) = response.as_object().as_ref()
272 && let Some(error) = response.get("error").and_then(JsonValue::as_str)
273 {
274 let description = response
275 .get("error_description")
276 .and_then(JsonValue::as_str)
277 .unwrap_or("(no description)");
278
279 return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
280 }
281
282 Ok(response)
283}
284
285#[inline]
288pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
289 unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
290}
291
292#[inline]
295pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
296 unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
297}
298
299#[inline]
302pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
303 unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
304}
305
306pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
309 let hash = sha256::delimited([iss, sub].iter());
310 let b64 = b64encode.encode(hash);
311
312 Ok(b64)
313}
314
315fn unique_id_parts<'a>(
316 (provider, session): (&'a Provider, &'a Session),
317) -> Result<(&'a str, &'a str)> {
318 provider
319 .issuer_url
320 .as_ref()
321 .map(Url::as_str)
322 .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
323 .and_then(|iss| unique_id_iss_parts((iss, session)))
324}
325
326fn unique_id_sub_parts<'a>(
327 (provider, sub): (&'a Provider, &'a str),
328) -> Result<(&'a str, &'a str)> {
329 provider
330 .issuer_url
331 .as_ref()
332 .map(Url::as_str)
333 .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
334 .map(|iss| (iss, sub))
335}
336
337fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
338 session
339 .user_info
340 .as_ref()
341 .map(|user_info| user_info.sub.as_str())
342 .ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
343 .map(|sub| (iss, sub))
344}