Skip to main content

tuwunel_service/oauth/
mod.rs

1pub mod providers;
2pub mod server;
3pub mod sessions;
4pub mod token_response;
5pub mod user_info;
6
7use std::sync::Arc;
8
9use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
10use futures::{Stream, StreamExt, TryStreamExt};
11use reqwest::{
12	Method,
13	header::{ACCEPT, CONTENT_TYPE},
14};
15use ruma::UserId;
16use serde::Serialize;
17use serde_json::Value as JsonValue;
18use tuwunel_core::{
19	Err, Result, err, implement,
20	utils::{hash::sha256, result::LogErr, stream::ReadyExt},
21	warn,
22};
23use url::Url;
24
25use self::{providers::Providers, sessions::Sessions};
26pub use self::{
27	providers::{Provider, ProviderId},
28	server::Server,
29	sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
30	token_response::TokenResponse,
31	user_info::UserInfo,
32};
33use crate::SelfServices;
34
35pub struct Service {
36	services: SelfServices,
37	pub providers: Arc<Providers>,
38	pub sessions: Arc<Sessions>,
39	pub server: Option<Arc<Server>>,
40}
41
42impl crate::Service for Service {
43	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
44		let providers = Arc::new(Providers::build(args));
45		let sessions = Arc::new(Sessions::build(args, providers.clone()));
46		let server = Server::build(args)?.map(Arc::new);
47
48		Ok(Arc::new(Self {
49			services: args.services.clone(),
50			sessions,
51			providers,
52			server,
53		}))
54	}
55
56	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
57}
58
59#[implement(Service)]
60#[inline]
61pub fn get_server(&self) -> Result<&Server> {
62	self.server
63		.as_deref()
64		.ok_or_else(|| err!(Request(Unrecognized("OIDC server not configured"))))
65}
66
67/// Remove all session state for a user. For debug and developer use only;
68/// deleting state can cause registration conflicts and unintended
69/// re-registrations.
70#[implement(Service)]
71#[tracing::instrument(level = "debug", skip(self))]
72pub async fn delete_user_sessions(&self, user_id: &UserId) {
73	self.user_sessions(user_id)
74		.ready_filter_map(Result::ok)
75		.ready_filter_map(|(_, session)| session.sess_id)
76		.for_each(async |sess_id| {
77			self.sessions.delete(&sess_id).await;
78		})
79		.await;
80}
81
82/// Revoke all session tokens for a user.
83#[implement(Service)]
84#[tracing::instrument(level = "debug", skip(self))]
85pub async fn revoke_user_tokens(&self, user_id: &UserId) {
86	self.user_sessions(user_id)
87		.ready_filter_map(Result::ok)
88		.for_each(async |(provider, session)| {
89			self.revoke_token((&provider, &session))
90				.await
91				.log_err()
92				.ok();
93		})
94		.await;
95}
96
97/// Get user's authorizations. Lists pairs of `(Provider, Session)` for a user.
98#[implement(Service)]
99#[tracing::instrument(level = "debug", skip(self))]
100pub fn user_sessions(
101	&self,
102	user_id: &UserId,
103) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
104	self.sessions
105		.get_by_user(user_id)
106		.and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
107}
108
109/// Network request to a Provider returning userinfo for a Session. The session
110/// must have a valid access token.
111#[implement(Service)]
112#[tracing::instrument(level = "debug", skip_all, ret)]
113pub async fn request_userinfo(
114	&self,
115	(provider, session): (&Provider, &Session),
116) -> Result<UserInfo> {
117	#[derive(Debug, Serialize)]
118	struct Query;
119
120	let url = provider
121		.userinfo_url
122		.clone()
123		.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
124
125	self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
126		.await
127		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
128		.log_err()
129}
130
131/// Network request to a Provider returning information for a Session based on
132/// its access token.
133#[implement(Service)]
134#[tracing::instrument(level = "debug", skip_all, ret)]
135pub async fn request_tokeninfo(
136	&self,
137	(provider, session): (&Provider, &Session),
138) -> Result<UserInfo> {
139	#[derive(Debug, Serialize)]
140	struct Query;
141
142	let url = provider
143		.introspection_url
144		.clone()
145		.ok_or_else(|| {
146			err!(Config("introspection_url", "Missing introspection URL in config"))
147		})?;
148
149	self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
150		.await
151		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
152		.log_err()
153}
154
155/// Network request to a Provider revoking a Session's token.
156#[implement(Service)]
157#[tracing::instrument(level = "debug", skip_all, ret)]
158pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
159	#[derive(Debug, Serialize)]
160	struct RevokeQuery<'a> {
161		client_id: &'a str,
162		client_secret: &'a str,
163	}
164
165	let client_secret = provider.get_client_secret().await?;
166
167	let query = RevokeQuery {
168		client_id: &provider.client_id,
169		client_secret: &client_secret,
170	};
171
172	let url = provider
173		.revocation_url
174		.clone()
175		.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
176
177	self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
178		.await
179		.log_err()
180		.map(|_| ())
181}
182
183/// Network request to a Provider to obtain an access token for a Session using
184/// a provided code.
185#[implement(Service)]
186#[tracing::instrument(level = "debug", skip_all, ret)]
187pub async fn request_token(
188	&self,
189	(provider, session): (&Provider, &Session),
190	code: &str,
191) -> Result<TokenResponse> {
192	#[derive(Debug, Serialize)]
193	struct TokenQuery<'a> {
194		client_id: &'a str,
195		client_secret: &'a str,
196		grant_type: &'a str,
197		code: &'a str,
198		code_verifier: Option<&'a str>,
199		redirect_uri: Option<&'a str>,
200	}
201
202	let client_secret = provider.get_client_secret().await?;
203
204	let query = TokenQuery {
205		client_id: &provider.client_id,
206		client_secret: &client_secret,
207		grant_type: "authorization_code",
208		code,
209		code_verifier: session.code_verifier.as_deref(),
210		redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
211	};
212
213	let url = provider
214		.token_url
215		.clone()
216		.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
217
218	self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
219		.await
220		.and_then(|value| serde_json::from_value(value).map_err(Into::into))
221		.log_err()
222}
223
224/// Send a request to a provider; this is somewhat abstract since URL's are
225/// formed prior to this call and could point at anything, however this function
226/// uses the oauth-specific http client and is configured for JSON with special
227/// casing for an `error` property in the response.
228#[implement(Service)]
229#[tracing::instrument(
230	name = "request",
231	level = "debug",
232	ret(level = "trace"),
233	skip(self, body)
234)]
235pub async fn request<Body>(
236	&self,
237	(provider, session): (Option<&Provider>, Option<&Session>),
238	method: Method,
239	url: Url,
240	body: Option<Body>,
241) -> Result<JsonValue>
242where
243	Body: Serialize,
244{
245	let mut request = self
246		.services
247		.client
248		.oauth
249		.request(method, url)
250		.header(ACCEPT, "application/json");
251
252	if let Some(body) = body.map(serde_html_form::to_string).transpose()? {
253		request = request
254			.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
255			.body(body);
256	}
257
258	if let Some(session) = session
259		&& let Some(access_token) = session.access_token.clone()
260	{
261		request = request.bearer_auth(access_token);
262	}
263
264	let response: JsonValue = request
265		.send()
266		.await?
267		.error_for_status()?
268		.json()
269		.await?;
270
271	if let Some(response) = response.as_object().as_ref()
272		&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
273	{
274		let description = response
275			.get("error_description")
276			.and_then(JsonValue::as_str)
277			.unwrap_or("(no description)");
278
279		return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
280	}
281
282	Ok(response)
283}
284
285/// Generate a unique-id string determined by the combination of `Provider` and
286/// `Session` instances.
287#[inline]
288pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
289	unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
290}
291
292/// Generate a unique-id string determined by the combination of `Provider`
293/// instance and `sub` string.
294#[inline]
295pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
296	unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
297}
298
299/// Generate a unique-id string determined by the combination of `issuer_url`
300/// and `Session` instance.
301#[inline]
302pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
303	unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
304}
305
306/// Generate a unique-id string determined by the `issuer_url` and the `sub`
307/// strings directly.
308pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
309	let hash = sha256::delimited([iss, sub].iter());
310	let b64 = b64encode.encode(hash);
311
312	Ok(b64)
313}
314
315fn unique_id_parts<'a>(
316	(provider, session): (&'a Provider, &'a Session),
317) -> Result<(&'a str, &'a str)> {
318	provider
319		.issuer_url
320		.as_ref()
321		.map(Url::as_str)
322		.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
323		.and_then(|iss| unique_id_iss_parts((iss, session)))
324}
325
326fn unique_id_sub_parts<'a>(
327	(provider, sub): (&'a Provider, &'a str),
328) -> Result<(&'a str, &'a str)> {
329	provider
330		.issuer_url
331		.as_ref()
332		.map(Url::as_str)
333		.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
334		.map(|iss| (iss, sub))
335}
336
337fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
338	session
339		.user_info
340		.as_ref()
341		.map(|user_info| user_info.sub.as_str())
342		.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
343		.map(|sub| (iss, sub))
344}