1mod watch;
2
3#[cfg(test)]
4mod tests;
5
6use std::{
7 collections::{BTreeMap, btree_map::Entry},
8 sync::Arc,
9};
10
11use futures::{FutureExt, Stream};
12use ruma::{
13 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
14 api::client::sync::sync_events::v5::{
15 ConnId as ConnectionId, ListId, Request, request,
16 request::{AccountData, E2EE, Receipts, ToDevice, Typing},
17 },
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::Mutex as TokioMutex;
21use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
22use tuwunel_database::{Cbor, Deserialized, Map};
23
/// Service managing sliding-sync (v5) connection state, kept both in an
/// in-memory cache and persisted in the `userdeviceconnid_conn` map.
pub struct Service {
	/// Handle to the other server services.
	services: Arc<crate::services::OnceServices>,
	/// In-memory cache of loaded connections, keyed by (user, device, conn id).
	connections: Connections,
	/// Database maps used by this service.
	db: Data,
}
29
/// Handles to the database maps used by the sync service.
// NOTE(review): only `userdeviceconnid_conn` is read/written in this file; the
// remaining maps are presumably consumed by the `watch` submodule — confirm.
struct Data {
	/// Persisted connection state (CBOR), keyed by `ConnectionKey`.
	userdeviceconnid_conn: Arc<Map>,
	todeviceid_events: Arc<Map>,
	userroomid_joined: Arc<Map>,
	userroomid_invitestate: Arc<Map>,
	userroomid_leftstate: Arc<Map>,
	userroomid_knockedstate: Arc<Map>,
	userroomid_notificationcount: Arc<Map>,
	userroomid_highlightcount: Arc<Map>,
	pduid_pdu: Arc<Map>,
	keychangeid_userid: Arc<Map>,
	roomuserdataid_accountdata: Arc<Map>,
	roomusertype_roomuserdataid: Arc<Map>,
	readreceiptid_readreceipt: Arc<Map>,
	userid_lastonetimekeyupdate: Arc<Map>,
	roomuserid_lastnotificationread: Arc<Map>,
}
47
/// Per-connection sliding-sync state, cached in memory by the service and
/// persisted CBOR-encoded by `Connection::store`.
// Field order is part of the serialized form; do not reorder casually.
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct Connection {
	/// Logged as `since` when the connection is stored.
	// NOTE(review): presumably the position the client last acknowledged — confirm.
	pub globalsince: u64,
	/// Copied into each windowed room's `roomsince` by `update_rooms_epilogue`.
	pub next_batch: u64,
	/// Sticky list configurations, merged from each request by `update_cache`.
	pub lists: Lists,
	/// Sticky extension configuration, merged from each request by `update_cache`.
	pub extensions: request::Extensions,
	/// Room subscriptions; replaced wholesale by each request in `update_cache`.
	pub subscriptions: Subscriptions,
	/// Per-room state, keyed by room ID.
	pub rooms: Rooms,
}
57
/// Per-room state tracked for a single connection.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
pub struct Room {
	/// Set to the connection's `next_batch` when the room appears in a
	/// response window (`update_rooms_epilogue`), and lowered to
	/// `retard_since` by `update_rooms_prologue`.
	pub roomsince: u64,
}
62
/// In-memory cache of loaded connections, guarded by a single async mutex.
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
/// Shared handle to one connection's state, lockable independently of the cache.
pub type ConnectionVal = Arc<TokioMutex<Connection>>;
/// Identifies a connection: (user, optional device, optional connection id).
pub type ConnectionKey = (OwnedUserId, Option<OwnedDeviceId>, Option<ConnectionId>);

/// Room subscriptions keyed by room ID.
pub type Subscriptions = BTreeMap<OwnedRoomId, request::ListConfig>;
/// List configurations keyed by list ID.
pub type Lists = BTreeMap<ListId, request::List>;
/// Per-room connection state keyed by room ID.
pub type Rooms = BTreeMap<OwnedRoomId, Room>;
70
71impl crate::Service for Service {
72 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
73 Ok(Arc::new(Self {
74 db: Data {
75 userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(),
76 todeviceid_events: args.db["todeviceid_events"].clone(),
77 userroomid_joined: args.db["userroomid_joined"].clone(),
78 userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
79 userroomid_leftstate: args.db["userroomid_leftstate"].clone(),
80 userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(),
81 userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
82 userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
83 pduid_pdu: args.db["pduid_pdu"].clone(),
84 keychangeid_userid: args.db["keychangeid_userid"].clone(),
85 roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(),
86 roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(),
87 readreceiptid_readreceipt: args.db["readreceiptid_readreceipt"].clone(),
88 userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
89 roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
90 .clone(),
91 },
92 services: args.services.clone(),
93 connections: Default::default(),
94 }))
95 }
96
97 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
98}
99
100#[implement(Service)]
101#[tracing::instrument(level = "debug", skip(self))]
102pub async fn clear_connections(
103 &self,
104 user_id: Option<&UserId>,
105 device_id: Option<&DeviceId>,
106 conn_id: Option<&ConnectionId>,
107) {
108 self.connections
109 .lock()
110 .await
111 .retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
112 let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
113 && (device_id.is_none() || device_id == conn_device_id.as_deref())
114 && (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
115
116 if !retain {
117 self.db
118 .userdeviceconnid_conn
119 .del((conn_user_id, conn_device_id, conn_conn_id));
120 }
121
122 retain
123 });
124}
125
126#[implement(Service)]
127#[tracing::instrument(level = "debug", skip(self))]
128pub async fn drop_connection(&self, key: &ConnectionKey) {
129 let mut cache = self.connections.lock().await;
130
131 self.db.userdeviceconnid_conn.del(key);
132 cache.remove(key);
133}
134
135#[implement(Service)]
136#[tracing::instrument(level = "debug", skip(self))]
137pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
138 let mut cache = self.connections.lock().await;
139
140 match cache.entry(key.clone()) {
141 | Entry::Occupied(val) => val.get().clone(),
142 | Entry::Vacant(val) => {
143 let conn = self
144 .db
145 .userdeviceconnid_conn
146 .qry(key)
147 .boxed()
148 .await
149 .deserialized::<Cbor<_>>()
150 .map(at!(0))
151 .map(TokioMutex::new)
152 .map(Arc::new)
153 .unwrap_or_default();
154
155 val.insert(conn).clone()
156 },
157 }
158}
159
160#[implement(Service)]
161#[tracing::instrument(level = "debug", skip(self))]
162pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
163 let mut cache = self.connections.lock().await;
164
165 match cache.entry(key.clone()) {
166 | Entry::Occupied(val) => Ok(val.get().clone()),
167 | Entry::Vacant(val) => self
168 .db
169 .userdeviceconnid_conn
170 .qry(key)
171 .await
172 .deserialized::<Cbor<_>>()
173 .map(at!(0))
174 .map(TokioMutex::new)
175 .map(Arc::new)
176 .map(|conn| val.insert(conn).clone()),
177 }
178}
179
180#[implement(Service)]
181#[tracing::instrument(level = "debug", skip(self))]
182pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
183 self.connections
184 .lock()
185 .await
186 .get(key)
187 .cloned()
188 .ok_or_else(|| err!(Request(NotFound("Connection not found."))))
189}
190
191#[implement(Service)]
192#[tracing::instrument(level = "trace", skip(self))]
193pub async fn list_loaded_connections(&self) -> Vec<ConnectionKey> {
194 self.connections
195 .lock()
196 .await
197 .keys()
198 .cloned()
199 .collect()
200}
201
202#[implement(Service)]
203#[tracing::instrument(level = "trace", skip(self))]
204pub fn list_stored_connections(&self) -> impl Stream<Item = ConnectionKey> {
205 self.db.userdeviceconnid_conn.keys().ignore_err()
206}
207
208#[implement(Service)]
209#[tracing::instrument(level = "trace", skip(self))]
210pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool {
211 self.connections.lock().await.contains_key(key)
212}
213
214#[implement(Service)]
215#[tracing::instrument(level = "trace", skip(self))]
216pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
217 self.db.userdeviceconnid_conn.contains(key).await
218}
219
220#[inline]
221pub fn into_connection_key<U, D, C>(
222 user_id: U,
223 device_id: Option<D>,
224 conn_id: Option<C>,
225) -> ConnectionKey
226where
227 U: Into<OwnedUserId>,
228 D: Into<OwnedDeviceId>,
229 C: Into<ConnectionId>,
230{
231 (user_id.into(), device_id.map(Into::into), conn_id.map(Into::into))
232}
233
234#[implement(Connection)]
235#[tracing::instrument(level = "debug", skip(self, service))]
236pub fn store(&self, service: &Service, key: &ConnectionKey) {
237 service
238 .db
239 .userdeviceconnid_conn
240 .put(key, Cbor(self));
241
242 debug!(
243 since = %self.globalsince,
244 next_batch = %self.next_batch,
245 "Persisted connection state"
246 );
247}
248
249#[implement(Connection)]
250#[tracing::instrument(level = "debug", skip(self))]
251pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
252 self.rooms.values_mut().for_each(|room| {
253 if let Some(retard_since) = retard_since
254 && room.roomsince > retard_since
255 {
256 room.roomsince = retard_since;
257 }
258 });
259}
260
261#[implement(Connection)]
262#[tracing::instrument(level = "debug", skip_all)]
263pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms)
264where
265 Rooms: Iterator<Item = &'a RoomId> + Send + 'a,
266{
267 window.for_each(|room_id| {
268 let room = self.rooms.entry(room_id.into()).or_default();
269
270 room.roomsince = self.next_batch;
271 });
272}
273
274#[implement(Connection)]
275#[tracing::instrument(level = "debug", skip_all)]
276pub fn update_cache(&mut self, request: &Request) {
277 Self::update_cache_lists(request, self);
278 Self::update_cache_subscriptions(request, self);
279 Self::update_cache_extensions(request, self);
280}
281
282#[implement(Connection)]
283fn update_cache_lists(request: &Request, cached: &mut Self) {
284 for (list_id, request_list) in &request.lists {
285 cached
286 .lists
287 .entry(list_id.clone())
288 .and_modify(|cached_list| {
289 Self::update_cache_list(request_list, cached_list);
290 })
291 .or_insert_with(|| request_list.clone());
292 }
293}
294
295#[implement(Connection)]
296fn update_cache_list(request: &request::List, cached: &mut request::List) {
297 cached.ranges.clone_from(&request.ranges);
298 list_or_sticky(&request.room_details.required_state, &mut cached.room_details.required_state);
299
300 match (&request.filters, &mut cached.filters) {
301 | (None, None) => {},
302 | (None, Some(_cached)) => {},
303 | (Some(request), None) => cached.filters = Some(request.clone()),
304 | (Some(request), Some(cached)) => {
305 some_or_sticky(request.is_dm.as_ref(), &mut cached.is_dm);
306 some_or_sticky(request.is_encrypted.as_ref(), &mut cached.is_encrypted);
307 some_or_sticky(request.is_invite.as_ref(), &mut cached.is_invite);
308 list_or_sticky(&request.room_types, &mut cached.room_types);
309 list_or_sticky(&request.not_room_types, &mut cached.not_room_types);
310 list_or_sticky(&request.tags, &mut cached.not_tags);
311 list_or_sticky(&request.spaces, &mut cached.spaces);
312 },
313 }
314}
315
316#[implement(Connection)]
317fn update_cache_subscriptions(request: &Request, cached: &mut Self) {
318 cached.subscriptions = request.room_subscriptions.clone();
319}
320
321#[implement(Connection)]
322fn update_cache_extensions(request: &Request, cached: &mut Self) {
323 let request = &request.extensions;
324 let cached = &mut cached.extensions;
325
326 Self::update_cache_account_data(&request.account_data, &mut cached.account_data);
327 Self::update_cache_receipts(&request.receipts, &mut cached.receipts);
328 Self::update_cache_typing(&request.typing, &mut cached.typing);
329 Self::update_cache_to_device(&request.to_device, &mut cached.to_device);
330 Self::update_cache_e2ee(&request.e2ee, &mut cached.e2ee);
331}
332
333#[implement(Connection)]
334fn update_cache_account_data(request: &AccountData, cached: &mut AccountData) {
335 some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
336 some_or_sticky(request.lists.as_ref(), &mut cached.lists);
337 some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
338}
339
340#[implement(Connection)]
341fn update_cache_receipts(request: &Receipts, cached: &mut Receipts) {
342 some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
343 some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
344 some_or_sticky(request.lists.as_ref(), &mut cached.lists);
345}
346
347#[implement(Connection)]
348fn update_cache_typing(request: &Typing, cached: &mut Typing) {
349 some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
350 some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
351 some_or_sticky(request.lists.as_ref(), &mut cached.lists);
352}
353
354#[implement(Connection)]
355fn update_cache_to_device(request: &ToDevice, cached: &mut ToDevice) {
356 some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
357 cached.since.clone_from(&request.since);
358}
359
360#[implement(Connection)]
361fn update_cache_e2ee(request: &E2EE, cached: &mut E2EE) {
362 some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
363}
364
/// Overwrite `cached` with `target`, unless `target` is empty. In that case
/// the previously cached list is kept ("sticky").
///
/// Takes a slice so any contiguous sequence can be passed; `&Vec<T>` arguments
/// still coerce. `clone_into` reuses `cached`'s existing allocation, as
/// `Vec::clone_from` did.
fn list_or_sticky<T: Clone>(target: &[T], cached: &mut Vec<T>) {
	if !target.is_empty() {
		target.clone_into(cached);
	}
}
370
/// Replace `cached` with a clone of `target` when one is given; a `None`
/// target leaves the previously cached value in place ("sticky").
fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
	let Some(value) = target.cloned() else {
		return;
	};

	*cached = Some(value);
}