Skip to main content

tuwunel_service/oauth/
sessions.rs

1pub mod association;
2
3use std::{
4	iter::once,
5	pin::pin,
6	sync::{Arc, Mutex},
7	time::SystemTime,
8};
9
10use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
11use ruma::{OwnedUserId, UserId};
12use serde::{Deserialize, Serialize};
13use tuwunel_core::{
14	Err, Result, at, implement,
15	utils::stream::{IterStream, ReadyExt, TryExpect},
16};
17use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
18use url::Url;
19
20use super::{Provider, Providers, UserInfo, unique_id};
21use crate::SelfServices;
22
/// OAuth session persistence service.
///
/// Stores [`Session`] records and maintains the secondary indexes used to look
/// them up by matrix user and by provider-unique identity.
pub struct Sessions {
	_services: SelfServices,

	/// Pending association state (see the `association` submodule).
	association_pending: Mutex<association::Pending>,

	/// Configured identity providers; used to resolve a session's `idp_id`.
	providers: Arc<Providers>,

	/// Database columns backing this service.
	db: Data,
}
29
/// Database columns used by [`Sessions`].
struct Data {
	/// `sess_id` => CBOR-encoded [`Session`].
	oauthid_session: Arc<Map>,

	/// Provider-unique identity (computed by `unique_id()`) => `sess_id`.
	oauthuniqid_oauthid: Arc<Map>,

	/// `user_id` => list of associated `sess_id`s (kept sorted and deduplicated
	/// by `put`).
	userid_oauthid: Arc<Map>,
}
35
36/// Session ultimately represents an OAuth authorization session yielding an
37/// associated matrix user registration. Maintains the state between
38/// authorization steps and the association to the matrix user until
39/// deactivation or revocation.
40#[derive(Clone, Debug, Default, Deserialize, Serialize)]
41pub struct Session {
42	/// Identity Provider ID (the `client_id` in the configuration) associated
43	/// with this session.
44	pub idp_id: Option<String>,
45
46	/// Session ID used as the index key for this session itself.
47	pub sess_id: Option<SessionId>,
48
49	/// Token type (bearer, mac, etc).
50	pub token_type: Option<String>,
51
52	/// Access token to the provider.
53	pub access_token: Option<String>,
54
55	/// OIDC ID token returned by the provider.
56	pub id_token: Option<String>,
57
58	/// Duration in seconds the access_token is valid for.
59	pub expires_in: Option<u64>,
60
61	/// Point in time that the access_token expires.
62	pub expires_at: Option<SystemTime>,
63
64	/// Token used to refresh the access_token.
65	pub refresh_token: Option<String>,
66
67	/// Duration in seconds the refresh_token is valid for
68	pub refresh_token_expires_in: Option<u64>,
69
70	/// Point in time that the refresh_token expires.
71	pub refresh_token_expires_at: Option<SystemTime>,
72
73	/// Access scope actually granted (if supported).
74	pub scope: Option<String>,
75
76	/// Redirect URL
77	pub redirect_url: Option<Url>,
78
79	/// Challenge preimage
80	pub code_verifier: Option<String>,
81
82	/// Random string passed exclusively in the grant session cookie.
83	pub cookie_nonce: Option<String>,
84
85	/// Random single-use string passed in the provider redirect.
86	pub query_nonce: Option<String>,
87
88	/// Point in time the authorization grant session expires.
89	pub authorize_expires_at: Option<SystemTime>,
90
91	/// Associated User Id registration.
92	pub user_id: Option<OwnedUserId>,
93
94	/// Last userinfo response persisted here.
95	pub user_info: Option<UserInfo>,
96}
97
/// Session Identifier type.
pub type SessionId = String;

/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters (RFC 7636 §4.1).
pub const CODE_VERIFIER_LENGTH: usize = 64;

// Enforce the RFC 7636 bounds at compile time so a future edit to the length
// cannot silently produce verifiers that providers will reject.
const _: () = assert!(
	CODE_VERIFIER_LENGTH >= 43 && CODE_VERIFIER_LENGTH <= 128,
	"CODE_VERIFIER_LENGTH must be between 43 and 128 (RFC 7636)"
);

/// Number of characters we will generate for the Session ID.
pub const SESSION_ID_LENGTH: usize = 32;
107
108#[implement(Sessions)]
109pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
110	Self {
111		_services: args.services.clone(),
112		association_pending: Default::default(),
113		providers,
114		db: Data {
115			oauthid_session: args.db["oauthid_session"].clone(),
116			oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
117			userid_oauthid: args.db["userid_oauthid"].clone(),
118		},
119	}
120}
121
122/// Delete database state for the session.
123#[implement(Sessions)]
124#[tracing::instrument(level = "debug", skip(self))]
125pub async fn delete(&self, sess_id: &str) {
126	let Ok(session) = self.get(sess_id).await else {
127		return;
128	};
129
130	if let Some(user_id) = session.user_id.as_deref() {
131		let sess_ids: Vec<_> = self
132			.get_sess_id_by_user(user_id)
133			.ready_filter_map(Result::ok)
134			.ready_filter(|assoc_id| assoc_id != sess_id)
135			.collect()
136			.await;
137
138		if !sess_ids.is_empty() {
139			self.db.userid_oauthid.raw_put(user_id, sess_ids);
140		} else {
141			self.db.userid_oauthid.remove(user_id);
142		}
143	}
144
145	// Check the unique identity still points to this sess_id before deleting. If
146	// not, the association was updated to a newer session.
147	if let Some(idp_id) = session.idp_id.as_ref()
148		&& let Ok(provider) = self.providers.get(idp_id).await
149		&& let Ok(unique_id) = unique_id((&provider, &session))
150		&& let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await
151		&& assoc_id == sess_id
152	{
153		self.db.oauthuniqid_oauthid.remove(&unique_id);
154	}
155
156	self.db.oauthid_session.remove(sess_id);
157}
158
159/// Create or overwrite database state for the session.
160#[implement(Sessions)]
161#[tracing::instrument(level = "info", skip(self))]
162pub async fn put(&self, session: &Session) {
163	let sess_id = session
164		.sess_id
165		.as_deref()
166		.expect("Missing session.sess_id required for sessions.put()");
167
168	self.db
169		.oauthid_session
170		.raw_put(sess_id, Cbor(session));
171
172	if let Some(idp_id) = session.idp_id.as_ref()
173		&& let Ok(provider) = self.providers.get(idp_id).await
174		&& let Ok(unique_id) = unique_id((&provider, session))
175	{
176		self.db
177			.oauthuniqid_oauthid
178			.insert(&unique_id, sess_id);
179	}
180
181	if let Some(user_id) = session.user_id.as_deref() {
182		let sess_ids = self
183			.get_sess_id_by_user(user_id)
184			.ready_filter_map(Result::ok)
185			.chain(once(sess_id.to_owned()).stream())
186			.collect::<Vec<_>>()
187			.map(|mut ids| {
188				ids.sort_unstable();
189				ids.dedup();
190				ids
191			})
192			.await;
193
194		self.db.userid_oauthid.raw_put(user_id, sess_ids);
195	}
196}
197
198/// Check if database state exists for one or more sessions associated with
199/// `user_id`
200#[implement(Sessions)]
201#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
202pub async fn exists_for_user(&self, user_id: &UserId) -> bool {
203	pin!(self.get_by_user(user_id))
204		.next()
205		.await
206		.is_some()
207}
208
209/// Fetch database state for a session from its associated `(iss,sub)`, in case
210/// `sess_id` is not known.
211#[implement(Sessions)]
212#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
213pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
214	self.get_sess_id_by_unique_id(unique_id)
215		.and_then(async |sess_id| self.get(&sess_id).await)
216		.await
217}
218
219/// Fetch database state for one or more sessions from its associated `user_id`,
220/// in case `sess_id` is not known.
221#[implement(Sessions)]
222#[tracing::instrument(level = "debug", skip(self))]
223pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
224	self.get_sess_id_by_user(user_id)
225		.and_then(async |sess_id| self.get(&sess_id).await)
226}
227
228/// Fetch database state for a session from its `sess_id`.
229#[implement(Sessions)]
230#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
231pub async fn get(&self, sess_id: &str) -> Result<Session> {
232	self.db
233		.oauthid_session
234		.get(sess_id)
235		.await
236		.deserialized::<Cbor<_>>()
237		.map(at!(0))
238}
239
240/// Resolve the `sess_id` associations with a `user_id`.
241#[implement(Sessions)]
242#[tracing::instrument(level = "debug", skip(self))]
243pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
244	self.db
245		.userid_oauthid
246		.get(user_id)
247		.map(Deserialized::deserialized)
248		.map_ok(Vec::into_iter)
249		.map_ok(IterStream::try_stream)
250		.try_flatten_stream()
251}
252
253/// Resolve the `sess_id` from an associated provider issuer and subject hash.
254#[implement(Sessions)]
255#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
256pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
257	self.db
258		.oauthuniqid_oauthid
259		.get(unique_id)
260		.await
261		.deserialized()
262}
263
264#[implement(Sessions)]
265pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
266	self.db
267		.userid_oauthid
268		.keys()
269		.expect_ok()
270		.map(UserId::to_owned)
271}
272
273#[implement(Sessions)]
274pub fn stream(&self) -> impl Stream<Item = Session> + Send {
275	self.db
276		.oauthid_session
277		.stream()
278		.expect_ok()
279		.map(|(_, session): (Ignore, Cbor<_>)| session.0)
280}
281
282#[implement(Sessions)]
283pub async fn provider(&self, session: &Session) -> Result<Provider> {
284	let Some(idp_id) = session.idp_id.as_deref() else {
285		return Err!(Request(NotFound("No provider for this session")));
286	};
287
288	self.providers.get(idp_id).await
289}