1pub mod association;
2
3use std::{
4 iter::once,
5 pin::pin,
6 sync::{Arc, Mutex},
7 time::SystemTime,
8};
9
10use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
11use ruma::{OwnedUserId, UserId};
12use serde::{Deserialize, Serialize};
13use tuwunel_core::{
14 Err, Result, at, implement,
15 utils::stream::{IterStream, ReadyExt, TryExpect},
16};
17use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
18use url::Url;
19
20use super::{Provider, Providers, UserInfo, unique_id};
21use crate::SelfServices;
22
23pub struct Sessions {
24 _services: SelfServices,
25 association_pending: Mutex<association::Pending>,
26 providers: Arc<Providers>,
27 db: Data,
28}
29
30struct Data {
31 oauthid_session: Arc<Map>,
32 oauthuniqid_oauthid: Arc<Map>,
33 userid_oauthid: Arc<Map>,
34}
35
36#[derive(Clone, Debug, Default, Deserialize, Serialize)]
41pub struct Session {
42 pub idp_id: Option<String>,
45
46 pub sess_id: Option<SessionId>,
48
49 pub token_type: Option<String>,
51
52 pub access_token: Option<String>,
54
55 pub id_token: Option<String>,
57
58 pub expires_in: Option<u64>,
60
61 pub expires_at: Option<SystemTime>,
63
64 pub refresh_token: Option<String>,
66
67 pub refresh_token_expires_in: Option<u64>,
69
70 pub refresh_token_expires_at: Option<SystemTime>,
72
73 pub scope: Option<String>,
75
76 pub redirect_url: Option<Url>,
78
79 pub code_verifier: Option<String>,
81
82 pub cookie_nonce: Option<String>,
84
85 pub query_nonce: Option<String>,
87
88 pub authorize_expires_at: Option<SystemTime>,
90
91 pub user_id: Option<OwnedUserId>,
93
94 pub user_info: Option<UserInfo>,
96}
97
/// Identifier of a stored [`Session`].
pub type SessionId = String;

/// Presumably the length of generated `code_verifier` values (TODO confirm
/// against the generator; it is not visible in this file).
pub const CODE_VERIFIER_LENGTH: usize = 64;

/// Presumably the length of generated session identifiers (TODO confirm
/// against the generator; it is not visible in this file).
pub const SESSION_ID_LENGTH: usize = 32;
107
108#[implement(Sessions)]
109pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
110 Self {
111 _services: args.services.clone(),
112 association_pending: Default::default(),
113 providers,
114 db: Data {
115 oauthid_session: args.db["oauthid_session"].clone(),
116 oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
117 userid_oauthid: args.db["userid_oauthid"].clone(),
118 },
119 }
120}
121
122#[implement(Sessions)]
124#[tracing::instrument(level = "debug", skip(self))]
125pub async fn delete(&self, sess_id: &str) {
126 let Ok(session) = self.get(sess_id).await else {
127 return;
128 };
129
130 if let Some(user_id) = session.user_id.as_deref() {
131 let sess_ids: Vec<_> = self
132 .get_sess_id_by_user(user_id)
133 .ready_filter_map(Result::ok)
134 .ready_filter(|assoc_id| assoc_id != sess_id)
135 .collect()
136 .await;
137
138 if !sess_ids.is_empty() {
139 self.db.userid_oauthid.raw_put(user_id, sess_ids);
140 } else {
141 self.db.userid_oauthid.remove(user_id);
142 }
143 }
144
145 if let Some(idp_id) = session.idp_id.as_ref()
148 && let Ok(provider) = self.providers.get(idp_id).await
149 && let Ok(unique_id) = unique_id((&provider, &session))
150 && let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await
151 && assoc_id == sess_id
152 {
153 self.db.oauthuniqid_oauthid.remove(&unique_id);
154 }
155
156 self.db.oauthid_session.remove(sess_id);
157}
158
159#[implement(Sessions)]
161#[tracing::instrument(level = "info", skip(self))]
162pub async fn put(&self, session: &Session) {
163 let sess_id = session
164 .sess_id
165 .as_deref()
166 .expect("Missing session.sess_id required for sessions.put()");
167
168 self.db
169 .oauthid_session
170 .raw_put(sess_id, Cbor(session));
171
172 if let Some(idp_id) = session.idp_id.as_ref()
173 && let Ok(provider) = self.providers.get(idp_id).await
174 && let Ok(unique_id) = unique_id((&provider, session))
175 {
176 self.db
177 .oauthuniqid_oauthid
178 .insert(&unique_id, sess_id);
179 }
180
181 if let Some(user_id) = session.user_id.as_deref() {
182 let sess_ids = self
183 .get_sess_id_by_user(user_id)
184 .ready_filter_map(Result::ok)
185 .chain(once(sess_id.to_owned()).stream())
186 .collect::<Vec<_>>()
187 .map(|mut ids| {
188 ids.sort_unstable();
189 ids.dedup();
190 ids
191 })
192 .await;
193
194 self.db.userid_oauthid.raw_put(user_id, sess_ids);
195 }
196}
197
198#[implement(Sessions)]
201#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
202pub async fn exists_for_user(&self, user_id: &UserId) -> bool {
203 pin!(self.get_by_user(user_id))
204 .next()
205 .await
206 .is_some()
207}
208
209#[implement(Sessions)]
212#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
213pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
214 self.get_sess_id_by_unique_id(unique_id)
215 .and_then(async |sess_id| self.get(&sess_id).await)
216 .await
217}
218
219#[implement(Sessions)]
222#[tracing::instrument(level = "debug", skip(self))]
223pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
224 self.get_sess_id_by_user(user_id)
225 .and_then(async |sess_id| self.get(&sess_id).await)
226}
227
228#[implement(Sessions)]
230#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
231pub async fn get(&self, sess_id: &str) -> Result<Session> {
232 self.db
233 .oauthid_session
234 .get(sess_id)
235 .await
236 .deserialized::<Cbor<_>>()
237 .map(at!(0))
238}
239
240#[implement(Sessions)]
242#[tracing::instrument(level = "debug", skip(self))]
243pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
244 self.db
245 .userid_oauthid
246 .get(user_id)
247 .map(Deserialized::deserialized)
248 .map_ok(Vec::into_iter)
249 .map_ok(IterStream::try_stream)
250 .try_flatten_stream()
251}
252
253#[implement(Sessions)]
255#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
256pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
257 self.db
258 .oauthuniqid_oauthid
259 .get(unique_id)
260 .await
261 .deserialized()
262}
263
264#[implement(Sessions)]
265pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
266 self.db
267 .userid_oauthid
268 .keys()
269 .expect_ok()
270 .map(UserId::to_owned)
271}
272
273#[implement(Sessions)]
274pub fn stream(&self) -> impl Stream<Item = Session> + Send {
275 self.db
276 .oauthid_session
277 .stream()
278 .expect_ok()
279 .map(|(_, session): (Ignore, Cbor<_>)| session.0)
280}
281
282#[implement(Sessions)]
283pub async fn provider(&self, session: &Session) -> Result<Provider> {
284 let Some(idp_id) = session.idp_id.as_deref() else {
285 return Err!(Request(NotFound("No provider for this session")));
286 };
287
288 self.providers.get(idp_id).await
289}