Skip to main content

tuwunel_service/oauth/
sessions.rs

1pub mod association;
2
3use std::{
4	iter::once,
5	pin::pin,
6	sync::{Arc, Mutex},
7	time::SystemTime,
8};
9
10use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
11use ruma::{OwnedUserId, UserId};
12use serde::{Deserialize, Serialize};
13use tuwunel_core::{
14	Err, Result, at, implement,
15	utils::stream::{IterStream, ReadyExt, TryExpect},
16};
17use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
18use url::Url;
19
20use super::{Provider, Providers, UserInfo, unique_id};
21use crate::SelfServices;
22
/// OAuth session storage service.
///
/// Persists [`Session`] records and maintains the secondary indexes used to
/// find them by provider unique ID and by Matrix user.
pub struct Sessions {
	/// Handle to the other services. Not read anywhere in this file
	/// (underscore-prefixed).
	_services: SelfServices,

	/// Pending-association state; the type is defined in the `association`
	/// submodule. NOTE(review): its use is not visible in this file.
	association_pending: Mutex<association::Pending>,

	/// Configured identity providers, used to resolve a session's `idp_id`.
	providers: Arc<Providers>,

	/// Database maps backing this service.
	db: Data,
}
29
/// Database maps backing [`Sessions`].
struct Data {
	/// `sess_id` => CBOR-encoded [`Session`].
	oauthid_session: Arc<Map>,

	/// Provider unique ID (derived by `unique_id`) => `sess_id` of the most
	/// recently stored session for that identity.
	oauthuniqid_oauthid: Arc<Map>,

	/// `user_id` => serialized list of every `sess_id` associated with that
	/// user.
	userid_oauthid: Arc<Map>,
}
35
/// Session ultimately represents an OAuth authorization session yielding an
/// associated matrix user registration. Maintains the state between
/// authorization steps and the association to the matrix user until
/// deactivation or revocation.
///
/// Persisted CBOR-encoded in the `oauthid_session` map, keyed by `sess_id`.
/// Every field is an `Option`, so a partially-populated session can be stored
/// between authorization steps.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Session {
	/// Identity Provider ID (the `client_id` in the configuration) associated
	/// with this session.
	pub idp_id: Option<String>,

	/// Session ID used as the index key for this session itself. Must be set
	/// before the session is stored with `Sessions::put`.
	pub sess_id: Option<SessionId>,

	/// Token type (bearer, mac, etc.).
	pub token_type: Option<String>,

	/// Access token to the provider.
	pub access_token: Option<String>,

	/// Duration in seconds the `access_token` is valid for.
	pub expires_in: Option<u64>,

	/// Point in time that the `access_token` expires.
	pub expires_at: Option<SystemTime>,

	/// Token used to refresh the `access_token`.
	pub refresh_token: Option<String>,

	/// Duration in seconds the `refresh_token` is valid for.
	pub refresh_token_expires_in: Option<u64>,

	/// Point in time that the `refresh_token` expires.
	pub refresh_token_expires_at: Option<SystemTime>,

	/// Access scope actually granted (if supported).
	pub scope: Option<String>,

	/// Redirect URL.
	pub redirect_url: Option<Url>,

	/// PKCE code verifier: the challenge preimage (see `CODE_VERIFIER_LENGTH`).
	pub code_verifier: Option<String>,

	/// Random string passed exclusively in the grant session cookie.
	pub cookie_nonce: Option<String>,

	/// Random single-use string passed in the provider redirect.
	pub query_nonce: Option<String>,

	/// Point in time the authorization grant session expires.
	pub authorize_expires_at: Option<SystemTime>,

	/// Associated Matrix user ID. When set, `Sessions::put` also records this
	/// session in the user's `userid_oauthid` list.
	pub user_id: Option<OwnedUserId>,

	/// Most recent userinfo response persisted here.
	pub user_info: Option<UserInfo>,
}
94
/// Session Identifier type.
pub type SessionId = String;

/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters (inclusive).
pub const CODE_VERIFIER_LENGTH: usize = 64;

// Enforce the documented code_verifier bounds at compile time, so a future
// edit of `CODE_VERIFIER_LENGTH` cannot silently produce verifiers a provider
// would reject.
const _: () = assert!(
	CODE_VERIFIER_LENGTH >= 43 && CODE_VERIFIER_LENGTH <= 128,
	"CODE_VERIFIER_LENGTH must be between 43 and 128 characters"
);

/// Number of characters we will generate for the Session ID.
pub const SESSION_ID_LENGTH: usize = 32;
104
105#[implement(Sessions)]
106pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
107	Self {
108		_services: args.services.clone(),
109		association_pending: Default::default(),
110		providers,
111		db: Data {
112			oauthid_session: args.db["oauthid_session"].clone(),
113			oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
114			userid_oauthid: args.db["userid_oauthid"].clone(),
115		},
116	}
117}
118
119/// Delete database state for the session.
120#[implement(Sessions)]
121#[tracing::instrument(level = "debug", skip(self))]
122pub async fn delete(&self, sess_id: &str) {
123	let Ok(session) = self.get(sess_id).await else {
124		return;
125	};
126
127	if let Some(user_id) = session.user_id.as_deref() {
128		let sess_ids: Vec<_> = self
129			.get_sess_id_by_user(user_id)
130			.ready_filter_map(Result::ok)
131			.ready_filter(|assoc_id| assoc_id != sess_id)
132			.collect()
133			.await;
134
135		if !sess_ids.is_empty() {
136			self.db.userid_oauthid.raw_put(user_id, sess_ids);
137		} else {
138			self.db.userid_oauthid.remove(user_id);
139		}
140	}
141
142	// Check the unique identity still points to this sess_id before deleting. If
143	// not, the association was updated to a newer session.
144	if let Some(idp_id) = session.idp_id.as_ref()
145		&& let Ok(provider) = self.providers.get(idp_id).await
146		&& let Ok(unique_id) = unique_id((&provider, &session))
147		&& let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await
148		&& assoc_id == sess_id
149	{
150		self.db.oauthuniqid_oauthid.remove(&unique_id);
151	}
152
153	self.db.oauthid_session.remove(sess_id);
154}
155
156/// Create or overwrite database state for the session.
157#[implement(Sessions)]
158#[tracing::instrument(level = "info", skip(self))]
159pub async fn put(&self, session: &Session) {
160	let sess_id = session
161		.sess_id
162		.as_deref()
163		.expect("Missing session.sess_id required for sessions.put()");
164
165	self.db
166		.oauthid_session
167		.raw_put(sess_id, Cbor(session));
168
169	if let Some(idp_id) = session.idp_id.as_ref()
170		&& let Ok(provider) = self.providers.get(idp_id).await
171		&& let Ok(unique_id) = unique_id((&provider, session))
172	{
173		self.db
174			.oauthuniqid_oauthid
175			.insert(&unique_id, sess_id);
176	}
177
178	if let Some(user_id) = session.user_id.as_deref() {
179		let sess_ids = self
180			.get_sess_id_by_user(user_id)
181			.ready_filter_map(Result::ok)
182			.chain(once(sess_id.to_owned()).stream())
183			.collect::<Vec<_>>()
184			.map(|mut ids| {
185				ids.sort_unstable();
186				ids.dedup();
187				ids
188			})
189			.await;
190
191		self.db.userid_oauthid.raw_put(user_id, sess_ids);
192	}
193}
194
195/// Check if database state exists for one or more sessions associated with
196/// `user_id`
197#[implement(Sessions)]
198#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
199pub async fn exists_for_user(&self, user_id: &UserId) -> bool {
200	pin!(self.get_by_user(user_id))
201		.next()
202		.await
203		.is_some()
204}
205
206/// Fetch database state for a session from its associated `(iss,sub)`, in case
207/// `sess_id` is not known.
208#[implement(Sessions)]
209#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
210pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
211	self.get_sess_id_by_unique_id(unique_id)
212		.and_then(async |sess_id| self.get(&sess_id).await)
213		.await
214}
215
216/// Fetch database state for one or more sessions from its associated `user_id`,
217/// in case `sess_id` is not known.
218#[implement(Sessions)]
219#[tracing::instrument(level = "debug", skip(self))]
220pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
221	self.get_sess_id_by_user(user_id)
222		.and_then(async |sess_id| self.get(&sess_id).await)
223}
224
225/// Fetch database state for a session from its `sess_id`.
226#[implement(Sessions)]
227#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
228pub async fn get(&self, sess_id: &str) -> Result<Session> {
229	self.db
230		.oauthid_session
231		.get(sess_id)
232		.await
233		.deserialized::<Cbor<_>>()
234		.map(at!(0))
235}
236
237/// Resolve the `sess_id` associations with a `user_id`.
238#[implement(Sessions)]
239#[tracing::instrument(level = "debug", skip(self))]
240pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
241	self.db
242		.userid_oauthid
243		.get(user_id)
244		.map(Deserialized::deserialized)
245		.map_ok(Vec::into_iter)
246		.map_ok(IterStream::try_stream)
247		.try_flatten_stream()
248}
249
250/// Resolve the `sess_id` from an associated provider issuer and subject hash.
251#[implement(Sessions)]
252#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
253pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
254	self.db
255		.oauthuniqid_oauthid
256		.get(unique_id)
257		.await
258		.deserialized()
259}
260
261#[implement(Sessions)]
262pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
263	self.db
264		.userid_oauthid
265		.keys()
266		.expect_ok()
267		.map(UserId::to_owned)
268}
269
270#[implement(Sessions)]
271pub fn stream(&self) -> impl Stream<Item = Session> + Send {
272	self.db
273		.oauthid_session
274		.stream()
275		.expect_ok()
276		.map(|(_, session): (Ignore, Cbor<_>)| session.0)
277}
278
279#[implement(Sessions)]
280pub async fn provider(&self, session: &Session) -> Result<Provider> {
281	let Some(idp_id) = session.idp_id.as_deref() else {
282		return Err!(Request(NotFound("No provider for this session")));
283	};
284
285	self.providers.get(idp_id).await
286}