1pub mod association;
2
3use std::{
4 iter::once,
5 pin::pin,
6 sync::{Arc, Mutex},
7 time::SystemTime,
8};
9
10use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
11use ruma::{OwnedUserId, UserId};
12use serde::{Deserialize, Serialize};
13use tuwunel_core::{
14 Err, Result, at, implement,
15 utils::stream::{IterStream, ReadyExt, TryExpect},
16};
17use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
18use url::Url;
19
20use super::{Provider, Providers, UserInfo, unique_id};
21use crate::SelfServices;
22
/// OAuth session storage service.
///
/// Holds session records plus the indexes that map provider-derived unique
/// ids and Matrix users back to session ids.
pub struct Sessions {
	/// Handle to sibling services; held but not read by the methods below.
	_services: SelfServices,
	/// State for the `association` submodule.
	association_pending: Mutex<association::Pending>,
	/// Configured identity providers, used to resolve `Session::idp_id`.
	providers: Arc<Providers>,
	/// Database maps backing this service.
	db: Data,
}
29
/// Database maps used by `Sessions`.
struct Data {
	/// `sess_id` => CBOR-encoded `Session` record.
	oauthid_session: Arc<Map>,
	/// Provider-derived unique id => `sess_id` (last writer wins in `put`).
	oauthuniqid_oauthid: Arc<Map>,
	/// `user_id` => list of that user's `sess_id`s (written sorted and
	/// de-duplicated by `put`; entry removed by `delete` once empty).
	userid_oauthid: Arc<Map>,
}
35
/// One OAuth session record.
///
/// Stored CBOR-encoded in the `oauthid_session` map under its `sess_id` (see
/// `Sessions::put` / `Sessions::get`). Every field is optional.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Session {
	/// Identity-provider id; resolved against `Providers` to obtain this
	/// session's `Provider`.
	pub idp_id: Option<String>,

	/// Session id and key of this record in `oauthid_session`. Must be set
	/// before calling `Sessions::put`, which panics otherwise.
	pub sess_id: Option<SessionId>,

	/// NOTE(review): presumably the OAuth `token_type` from the provider's
	/// token response — confirm.
	pub token_type: Option<String>,

	/// NOTE(review): presumably the provider-issued access token — confirm.
	pub access_token: Option<String>,

	/// NOTE(review): presumably the access-token lifetime reported by the
	/// provider (units not established here; likely seconds) — confirm.
	pub expires_in: Option<u64>,

	/// NOTE(review): presumably the absolute access-token expiry — confirm.
	pub expires_at: Option<SystemTime>,

	/// NOTE(review): presumably the provider-issued refresh token — confirm.
	pub refresh_token: Option<String>,

	/// NOTE(review): presumably the refresh-token lifetime reported by the
	/// provider — confirm units.
	pub refresh_token_expires_in: Option<u64>,

	/// NOTE(review): presumably the absolute refresh-token expiry — confirm.
	pub refresh_token_expires_at: Option<SystemTime>,

	/// NOTE(review): presumably the granted OAuth scope string — confirm.
	pub scope: Option<String>,

	/// NOTE(review): presumably the redirect URL used for this authorization
	/// flow — confirm.
	pub redirect_url: Option<Url>,

	/// NOTE(review): presumably the PKCE code verifier (cf.
	/// `CODE_VERIFIER_LENGTH`) — confirm.
	pub code_verifier: Option<String>,

	/// NOTE(review): nonce presumably bound to a browser cookie — confirm.
	pub cookie_nonce: Option<String>,

	/// NOTE(review): nonce presumably carried in a query parameter — confirm.
	pub query_nonce: Option<String>,

	/// NOTE(review): presumably the deadline for completing authorization —
	/// confirm.
	pub authorize_expires_at: Option<SystemTime>,

	/// Matrix user bound to this session. When set, `Sessions::put` adds
	/// `sess_id` to this user's entry in `userid_oauthid`.
	pub user_id: Option<OwnedUserId>,

	/// User information for this session. NOTE(review): presumably fetched
	/// from the provider and possibly consumed by `unique_id` — confirm.
	pub user_info: Option<UserInfo>,
}
94
/// Session identifier; the key of a record in `oauthid_session`.
pub type SessionId = String;
97
/// Length for generated code verifiers. NOTE(review): presumably the PKCE
/// `code_verifier` length; the generator is not shown here — confirm.
pub const CODE_VERIFIER_LENGTH: usize = 64;
101
/// Length for generated session ids. NOTE(review): the generator is not shown
/// here — confirm how this is applied.
pub const SESSION_ID_LENGTH: usize = 32;
104
105#[implement(Sessions)]
106pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
107 Self {
108 _services: args.services.clone(),
109 association_pending: Default::default(),
110 providers,
111 db: Data {
112 oauthid_session: args.db["oauthid_session"].clone(),
113 oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
114 userid_oauthid: args.db["userid_oauthid"].clone(),
115 },
116 }
117}
118
119#[implement(Sessions)]
121#[tracing::instrument(level = "debug", skip(self))]
122pub async fn delete(&self, sess_id: &str) {
123 let Ok(session) = self.get(sess_id).await else {
124 return;
125 };
126
127 if let Some(user_id) = session.user_id.as_deref() {
128 let sess_ids: Vec<_> = self
129 .get_sess_id_by_user(user_id)
130 .ready_filter_map(Result::ok)
131 .ready_filter(|assoc_id| assoc_id != sess_id)
132 .collect()
133 .await;
134
135 if !sess_ids.is_empty() {
136 self.db.userid_oauthid.raw_put(user_id, sess_ids);
137 } else {
138 self.db.userid_oauthid.remove(user_id);
139 }
140 }
141
142 if let Some(idp_id) = session.idp_id.as_ref()
145 && let Ok(provider) = self.providers.get(idp_id).await
146 && let Ok(unique_id) = unique_id((&provider, &session))
147 && let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await
148 && assoc_id == sess_id
149 {
150 self.db.oauthuniqid_oauthid.remove(&unique_id);
151 }
152
153 self.db.oauthid_session.remove(sess_id);
154}
155
156#[implement(Sessions)]
158#[tracing::instrument(level = "info", skip(self))]
159pub async fn put(&self, session: &Session) {
160 let sess_id = session
161 .sess_id
162 .as_deref()
163 .expect("Missing session.sess_id required for sessions.put()");
164
165 self.db
166 .oauthid_session
167 .raw_put(sess_id, Cbor(session));
168
169 if let Some(idp_id) = session.idp_id.as_ref()
170 && let Ok(provider) = self.providers.get(idp_id).await
171 && let Ok(unique_id) = unique_id((&provider, session))
172 {
173 self.db
174 .oauthuniqid_oauthid
175 .insert(&unique_id, sess_id);
176 }
177
178 if let Some(user_id) = session.user_id.as_deref() {
179 let sess_ids = self
180 .get_sess_id_by_user(user_id)
181 .ready_filter_map(Result::ok)
182 .chain(once(sess_id.to_owned()).stream())
183 .collect::<Vec<_>>()
184 .map(|mut ids| {
185 ids.sort_unstable();
186 ids.dedup();
187 ids
188 })
189 .await;
190
191 self.db.userid_oauthid.raw_put(user_id, sess_ids);
192 }
193}
194
195#[implement(Sessions)]
198#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
199pub async fn exists_for_user(&self, user_id: &UserId) -> bool {
200 pin!(self.get_by_user(user_id))
201 .next()
202 .await
203 .is_some()
204}
205
206#[implement(Sessions)]
209#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
210pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
211 self.get_sess_id_by_unique_id(unique_id)
212 .and_then(async |sess_id| self.get(&sess_id).await)
213 .await
214}
215
216#[implement(Sessions)]
219#[tracing::instrument(level = "debug", skip(self))]
220pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
221 self.get_sess_id_by_user(user_id)
222 .and_then(async |sess_id| self.get(&sess_id).await)
223}
224
225#[implement(Sessions)]
227#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
228pub async fn get(&self, sess_id: &str) -> Result<Session> {
229 self.db
230 .oauthid_session
231 .get(sess_id)
232 .await
233 .deserialized::<Cbor<_>>()
234 .map(at!(0))
235}
236
237#[implement(Sessions)]
239#[tracing::instrument(level = "debug", skip(self))]
240pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
241 self.db
242 .userid_oauthid
243 .get(user_id)
244 .map(Deserialized::deserialized)
245 .map_ok(Vec::into_iter)
246 .map_ok(IterStream::try_stream)
247 .try_flatten_stream()
248}
249
250#[implement(Sessions)]
252#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
253pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
254 self.db
255 .oauthuniqid_oauthid
256 .get(unique_id)
257 .await
258 .deserialized()
259}
260
261#[implement(Sessions)]
262pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
263 self.db
264 .userid_oauthid
265 .keys()
266 .expect_ok()
267 .map(UserId::to_owned)
268}
269
270#[implement(Sessions)]
271pub fn stream(&self) -> impl Stream<Item = Session> + Send {
272 self.db
273 .oauthid_session
274 .stream()
275 .expect_ok()
276 .map(|(_, session): (Ignore, Cbor<_>)| session.0)
277}
278
279#[implement(Sessions)]
280pub async fn provider(&self, session: &Session) -> Result<Provider> {
281 let Some(idp_id) = session.idp_id.as_deref() else {
282 return Err!(Request(NotFound("No provider for this session")));
283 };
284
285 self.providers.get(idp_id).await
286}