Skip to main content

tuwunel_service/sync/mod.rs

1mod watch;
2
3#[cfg(test)]
4mod tests;
5
6use std::{
7	collections::{BTreeMap, btree_map::Entry},
8	sync::Arc,
9};
10
11use futures::{FutureExt, Stream};
12use ruma::{
13	DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
14	api::client::sync::sync_events::v5::{
15		ConnId as ConnectionId, ListId, Request, request,
16		request::{AccountData, E2EE, Receipts, ToDevice, Typing},
17	},
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::Mutex as TokioMutex;
21use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
22use tuwunel_database::{Cbor, Deserialized, Map};
23
/// Sliding-sync (v5) connection-state service.
///
/// Keeps loaded connections in an in-memory cache and persists their state in
/// the `userdeviceconnid_conn` column (see `Connection::store`).
pub struct Service {
	services: Arc<crate::services::OnceServices>,
	/// In-memory cache of loaded connections, keyed by [`ConnectionKey`].
	connections: Connections,
	/// Database column handles resolved in `build`.
	db: Data,
}
29
/// Database column handles used by the sync service.
///
/// Within this file only `userdeviceconnid_conn` is read or written; the other
/// handles are presumably consumed by sibling modules (e.g. `watch`) — TODO
/// confirm.
struct Data {
	// Persisted `Connection` state, keyed by `ConnectionKey`.
	userdeviceconnid_conn: Arc<Map>,
	todeviceid_events: Arc<Map>,
	userroomid_joined: Arc<Map>,
	userroomid_invitestate: Arc<Map>,
	userroomid_leftstate: Arc<Map>,
	userroomid_knockedstate: Arc<Map>,
	userroomid_notificationcount: Arc<Map>,
	userroomid_highlightcount: Arc<Map>,
	pduid_pdu: Arc<Map>,
	keychangeid_userid: Arc<Map>,
	roomuserdataid_accountdata: Arc<Map>,
	roomusertype_roomuserdataid: Arc<Map>,
	readreceiptid_readreceipt: Arc<Map>,
	userid_lastonetimekeyupdate: Arc<Map>,
	roomuserid_lastnotificationread: Arc<Map>,
}
47
/// State of a single sliding-sync connection.
///
/// Written as CBOR to `userdeviceconnid_conn` by [`Connection::store`] and
/// read back by `Service::load_connection` / `Service::load_or_init_connection`.
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct Connection {
	/// Logged as `since` when persisted; presumably the connection's global
	/// position — TODO confirm.
	pub globalsince: u64,
	/// Value that `update_rooms_epilogue` stamps onto each windowed room's
	/// `roomsince`.
	pub next_batch: u64,
	/// Cached list configurations, merged per request by `update_cache_lists`.
	pub lists: Lists,
	/// Cached extension configuration, merged per request by
	/// `update_cache_extensions`.
	pub extensions: request::Extensions,
	/// Room subscriptions; replaced wholesale by each request.
	pub subscriptions: Subscriptions,
	/// Per-room progress for rooms that have appeared in a window.
	pub rooms: Rooms,
}
57
/// Per-room state tracked for a connection.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
pub struct Room {
	/// Set to the connection's `next_batch` by `update_rooms_epilogue`; may be
	/// lowered to a caller-supplied position by `update_rooms_prologue`.
	pub roomsince: u64,
}
62
63type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
64pub type ConnectionVal = Arc<TokioMutex<Connection>>;
65pub type ConnectionKey = (OwnedUserId, Option<OwnedDeviceId>, Option<ConnectionId>);
66
67pub type Subscriptions = BTreeMap<OwnedRoomId, request::ListConfig>;
68pub type Lists = BTreeMap<ListId, request::List>;
69pub type Rooms = BTreeMap<OwnedRoomId, Room>;
70
71impl crate::Service for Service {
72	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
73		Ok(Arc::new(Self {
74			db: Data {
75				userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(),
76				todeviceid_events: args.db["todeviceid_events"].clone(),
77				userroomid_joined: args.db["userroomid_joined"].clone(),
78				userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
79				userroomid_leftstate: args.db["userroomid_leftstate"].clone(),
80				userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(),
81				userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
82				userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
83				pduid_pdu: args.db["pduid_pdu"].clone(),
84				keychangeid_userid: args.db["keychangeid_userid"].clone(),
85				roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(),
86				roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(),
87				readreceiptid_readreceipt: args.db["readreceiptid_readreceipt"].clone(),
88				userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
89				roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
90					.clone(),
91			},
92			services: args.services.clone(),
93			connections: Default::default(),
94		}))
95	}
96
97	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
98}
99
100#[implement(Service)]
101#[tracing::instrument(level = "debug", skip(self))]
102pub async fn clear_connections(
103	&self,
104	user_id: Option<&UserId>,
105	device_id: Option<&DeviceId>,
106	conn_id: Option<&ConnectionId>,
107) {
108	self.connections
109		.lock()
110		.await
111		.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
112			let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
113				&& (device_id.is_none() || device_id == conn_device_id.as_deref())
114				&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
115
116			if !retain {
117				self.db
118					.userdeviceconnid_conn
119					.del((conn_user_id, conn_device_id, conn_conn_id));
120			}
121
122			retain
123		});
124}
125
126#[implement(Service)]
127#[tracing::instrument(level = "debug", skip(self))]
128pub async fn drop_connection(&self, key: &ConnectionKey) {
129	let mut cache = self.connections.lock().await;
130
131	self.db.userdeviceconnid_conn.del(key);
132	cache.remove(key);
133}
134
135#[implement(Service)]
136#[tracing::instrument(level = "debug", skip(self))]
137pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
138	let mut cache = self.connections.lock().await;
139
140	match cache.entry(key.clone()) {
141		| Entry::Occupied(val) => val.get().clone(),
142		| Entry::Vacant(val) => {
143			let conn = self
144				.db
145				.userdeviceconnid_conn
146				.qry(key)
147				.boxed()
148				.await
149				.deserialized::<Cbor<_>>()
150				.map(at!(0))
151				.map(TokioMutex::new)
152				.map(Arc::new)
153				.unwrap_or_default();
154
155			val.insert(conn).clone()
156		},
157	}
158}
159
160#[implement(Service)]
161#[tracing::instrument(level = "debug", skip(self))]
162pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
163	let mut cache = self.connections.lock().await;
164
165	match cache.entry(key.clone()) {
166		| Entry::Occupied(val) => Ok(val.get().clone()),
167		| Entry::Vacant(val) => self
168			.db
169			.userdeviceconnid_conn
170			.qry(key)
171			.await
172			.deserialized::<Cbor<_>>()
173			.map(at!(0))
174			.map(TokioMutex::new)
175			.map(Arc::new)
176			.map(|conn| val.insert(conn).clone()),
177	}
178}
179
180#[implement(Service)]
181#[tracing::instrument(level = "debug", skip(self))]
182pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
183	self.connections
184		.lock()
185		.await
186		.get(key)
187		.cloned()
188		.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
189}
190
191#[implement(Service)]
192#[tracing::instrument(level = "trace", skip(self))]
193pub async fn list_loaded_connections(&self) -> Vec<ConnectionKey> {
194	self.connections
195		.lock()
196		.await
197		.keys()
198		.cloned()
199		.collect()
200}
201
202#[implement(Service)]
203#[tracing::instrument(level = "trace", skip(self))]
204pub fn list_stored_connections(&self) -> impl Stream<Item = ConnectionKey> {
205	self.db.userdeviceconnid_conn.keys().ignore_err()
206}
207
208#[implement(Service)]
209#[tracing::instrument(level = "trace", skip(self))]
210pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool {
211	self.connections.lock().await.contains_key(key)
212}
213
214#[implement(Service)]
215#[tracing::instrument(level = "trace", skip(self))]
216pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
217	self.db.userdeviceconnid_conn.contains(key).await
218}
219
220#[inline]
221pub fn into_connection_key<U, D, C>(
222	user_id: U,
223	device_id: Option<D>,
224	conn_id: Option<C>,
225) -> ConnectionKey
226where
227	U: Into<OwnedUserId>,
228	D: Into<OwnedDeviceId>,
229	C: Into<ConnectionId>,
230{
231	(user_id.into(), device_id.map(Into::into), conn_id.map(Into::into))
232}
233
234#[implement(Connection)]
235#[tracing::instrument(level = "debug", skip(self, service))]
236pub fn store(&self, service: &Service, key: &ConnectionKey) {
237	service
238		.db
239		.userdeviceconnid_conn
240		.put(key, Cbor(self));
241
242	debug!(
243		since = %self.globalsince,
244		next_batch = %self.next_batch,
245		"Persisted connection state"
246	);
247}
248
249#[implement(Connection)]
250#[tracing::instrument(level = "debug", skip(self))]
251pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
252	self.rooms.values_mut().for_each(|room| {
253		if let Some(retard_since) = retard_since
254			&& room.roomsince > retard_since
255		{
256			room.roomsince = retard_since;
257		}
258	});
259}
260
261#[implement(Connection)]
262#[tracing::instrument(level = "debug", skip_all)]
263pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms)
264where
265	Rooms: Iterator<Item = &'a RoomId> + Send + 'a,
266{
267	window.for_each(|room_id| {
268		let room = self.rooms.entry(room_id.into()).or_default();
269
270		room.roomsince = self.next_batch;
271	});
272}
273
274#[implement(Connection)]
275#[tracing::instrument(level = "debug", skip_all)]
276pub fn update_cache(&mut self, request: &Request) {
277	Self::update_cache_lists(request, self);
278	Self::update_cache_subscriptions(request, self);
279	Self::update_cache_extensions(request, self);
280}
281
282#[implement(Connection)]
283fn update_cache_lists(request: &Request, cached: &mut Self) {
284	for (list_id, request_list) in &request.lists {
285		cached
286			.lists
287			.entry(list_id.clone())
288			.and_modify(|cached_list| {
289				Self::update_cache_list(request_list, cached_list);
290			})
291			.or_insert_with(|| request_list.clone());
292	}
293}
294
295#[implement(Connection)]
296fn update_cache_list(request: &request::List, cached: &mut request::List) {
297	cached.ranges.clone_from(&request.ranges);
298	list_or_sticky(&request.room_details.required_state, &mut cached.room_details.required_state);
299
300	match (&request.filters, &mut cached.filters) {
301		| (None, None) => {},
302		| (None, Some(_cached)) => {},
303		| (Some(request), None) => cached.filters = Some(request.clone()),
304		| (Some(request), Some(cached)) => {
305			some_or_sticky(request.is_dm.as_ref(), &mut cached.is_dm);
306			some_or_sticky(request.is_encrypted.as_ref(), &mut cached.is_encrypted);
307			some_or_sticky(request.is_invite.as_ref(), &mut cached.is_invite);
308			list_or_sticky(&request.room_types, &mut cached.room_types);
309			list_or_sticky(&request.not_room_types, &mut cached.not_room_types);
310			list_or_sticky(&request.tags, &mut cached.not_tags);
311			list_or_sticky(&request.spaces, &mut cached.spaces);
312		},
313	}
314}
315
316#[implement(Connection)]
317fn update_cache_subscriptions(request: &Request, cached: &mut Self) {
318	cached.subscriptions = request.room_subscriptions.clone();
319}
320
321#[implement(Connection)]
322fn update_cache_extensions(request: &Request, cached: &mut Self) {
323	let request = &request.extensions;
324	let cached = &mut cached.extensions;
325
326	Self::update_cache_account_data(&request.account_data, &mut cached.account_data);
327	Self::update_cache_receipts(&request.receipts, &mut cached.receipts);
328	Self::update_cache_typing(&request.typing, &mut cached.typing);
329	Self::update_cache_to_device(&request.to_device, &mut cached.to_device);
330	Self::update_cache_e2ee(&request.e2ee, &mut cached.e2ee);
331}
332
333#[implement(Connection)]
334fn update_cache_account_data(request: &AccountData, cached: &mut AccountData) {
335	some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
336	some_or_sticky(request.lists.as_ref(), &mut cached.lists);
337	some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
338}
339
340#[implement(Connection)]
341fn update_cache_receipts(request: &Receipts, cached: &mut Receipts) {
342	some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
343	some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
344	some_or_sticky(request.lists.as_ref(), &mut cached.lists);
345}
346
347#[implement(Connection)]
348fn update_cache_typing(request: &Typing, cached: &mut Typing) {
349	some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
350	some_or_sticky(request.rooms.as_ref(), &mut cached.rooms);
351	some_or_sticky(request.lists.as_ref(), &mut cached.lists);
352}
353
354#[implement(Connection)]
355fn update_cache_to_device(request: &ToDevice, cached: &mut ToDevice) {
356	some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
357	cached.since.clone_from(&request.since);
358}
359
360#[implement(Connection)]
361fn update_cache_e2ee(request: &E2EE, cached: &mut E2EE) {
362	some_or_sticky(request.enabled.as_ref(), &mut cached.enabled);
363}
364
/// Replaces `cached` with a copy of `target` unless `target` is empty.
///
/// An empty `target` is treated as "unchanged", so the cached list is kept
/// (sticky). A non-empty `target` fully replaces the cached contents.
///
/// Takes a slice rather than `&Vec<T>` so any contiguous sequence can be
/// passed. Existing `&Vec<T>` arguments still coerce.
fn list_or_sticky<T: Clone>(target: &[T], cached: &mut Vec<T>) {
	if target.is_empty() {
		return;
	}

	// Reuse `cached`'s allocation rather than building a fresh `Vec`.
	cached.clear();
	cached.extend_from_slice(target);
}
370
/// Overwrites `cached` with a clone of `target` when one is given. A `None`
/// target leaves the cached value untouched (sticky).
fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
	let Some(value) = target else {
		return;
	};

	*cached = Some(value.clone());
}