Skip to main content

tuwunel_service/sending/
data.rs

1use std::{fmt::Debug, sync::Arc};
2
3use futures::{Stream, StreamExt};
4use ruma::{OwnedServerName, ServerName, UserId};
5use tuwunel_core::{
6	Error, Result, at, utils,
7	utils::{ReadyExt, stream::TryIgnore},
8};
9use tuwunel_database::{Database, Deserialized, Map};
10
11use super::{Destination, SendingEvent};
12
13pub(super) type OutgoingItem = (Key, SendingEvent, Destination);
14pub(super) type SendingItem = (Key, SendingEvent);
15pub(super) type QueueItem = (Key, SendingEvent);
16pub(super) type Key = Vec<u8>;
17
18pub struct Data {
19	servercurrentevent_data: Arc<Map>,
20	servernameevent_data: Arc<Map>,
21	servername_educount: Arc<Map>,
22	pub(super) db: Arc<Database>,
23	services: Arc<crate::services::OnceServices>,
24}
25
26impl Data {
27	pub(super) fn new(args: &crate::Args<'_>) -> Self {
28		let db = &args.db;
29		Self {
30			servercurrentevent_data: db["servercurrentevent_data"].clone(),
31			servernameevent_data: db["servernameevent_data"].clone(),
32			servername_educount: db["servername_educount"].clone(),
33			db: args.db.clone(),
34			services: args.services.clone(),
35		}
36	}
37
38	#[inline]
39	pub(super) fn delete_active_request(&self, key: &[u8]) {
40		self.servercurrentevent_data.remove(key);
41	}
42
43	pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) {
44		let prefix = destination.get_prefix();
45		self.servercurrentevent_data
46			.raw_keys_prefix(&prefix)
47			.ignore_err()
48			.ready_for_each(|key| self.servercurrentevent_data.remove(key))
49			.await;
50	}
51
52	pub(super) async fn delete_all_requests_for(&self, destination: &Destination) {
53		let prefix = destination.get_prefix();
54		self.servercurrentevent_data
55			.raw_keys_prefix(&prefix)
56			.ignore_err()
57			.ready_for_each(|key| self.servercurrentevent_data.remove(key))
58			.await;
59
60		self.servernameevent_data
61			.raw_keys_prefix(&prefix)
62			.ignore_err()
63			.ready_for_each(|key| self.servernameevent_data.remove(key))
64			.await;
65	}
66
67	pub(super) fn mark_as_active<'a, I>(&self, events: I)
68	where
69		I: Iterator<Item = &'a QueueItem>,
70	{
71		events
72			.filter(|(key, _)| !key.is_empty())
73			.for_each(|(key, val)| {
74				let val = if let SendingEvent::Edu(val) = &val { &**val } else { &[] };
75
76				self.servercurrentevent_data.insert(key, val);
77				self.servernameevent_data.remove(key);
78			});
79	}
80
81	#[inline]
82	pub fn active_requests(&self) -> impl Stream<Item = OutgoingItem> + Send + '_ {
83		self.servercurrentevent_data
84			.raw_stream()
85			.ignore_err()
86			.map(|(key, val)| {
87				let (dest, event) =
88					parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
89
90				(key.to_vec(), event, dest)
91			})
92	}
93
94	#[inline]
95	pub fn active_requests_for(
96		&self,
97		destination: &Destination,
98	) -> impl Stream<Item = SendingItem> + Send + '_ + use<'_> {
99		let prefix = destination.get_prefix();
100		self.servercurrentevent_data
101			.raw_stream_from(&prefix)
102			.ignore_err()
103			.ready_take_while(move |(key, _)| key.starts_with(&prefix))
104			.map(|(key, val)| {
105				let (_, event) =
106					parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
107
108				(key.to_vec(), event)
109			})
110	}
111
112	pub(super) fn queue_requests<'a, I>(&self, requests: I) -> Vec<Vec<u8>>
113	where
114		I: Iterator<Item = (&'a SendingEvent, &'a Destination)> + Clone + Debug + Send,
115	{
116		let keys: Vec<_> = requests
117			.clone()
118			.map(|(event, dest)| {
119				let mut key = dest.get_prefix();
120				if let SendingEvent::Pdu(value) = event {
121					key.extend(value.as_ref());
122				} else {
123					let count = self.services.globals.next_count();
124					let count = count.to_be_bytes();
125					key.extend(&count);
126				}
127
128				key
129			})
130			.collect();
131
132		self.servernameevent_data.insert_batch(
133			keys.iter()
134				.map(Vec::as_slice)
135				.zip(requests.map(at!(0)))
136				.map(|(key, event)| {
137					let value = if let SendingEvent::Edu(value) = &event {
138						&**value
139					} else {
140						&[]
141					};
142
143					(key, value)
144				}),
145		);
146
147		keys
148	}
149
150	pub fn queued_requests(
151		&self,
152		destination: &Destination,
153	) -> impl Stream<Item = QueueItem> + Send + '_ + use<'_> {
154		let prefix = destination.get_prefix();
155		self.servernameevent_data
156			.raw_stream_from(&prefix)
157			.ignore_err()
158			.ready_take_while(move |(key, _)| key.starts_with(&prefix))
159			.map(|(key, val)| {
160				let (_, event) =
161					parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
162
163				(key.to_vec(), event)
164			})
165	}
166
167	pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) {
168		self.servername_educount
169			.raw_put(server_name, last_count);
170	}
171
172	pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 {
173		self.servername_educount
174			.get(server_name)
175			.await
176			.deserialized()
177			.unwrap_or(0)
178	}
179}
180
181fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> {
182	// Appservices start with a plus
183	Ok::<_, Error>(if key.starts_with(b"+") {
184		let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
185
186		let server = parts
187			.next()
188			.expect("splitn always returns one element");
189		let event = parts
190			.next()
191			.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
192
193		let server = utils::string_from_bytes(server).map_err(|_| {
194			Error::bad_database("Invalid server bytes in server_currenttransaction")
195		})?;
196
197		(
198			Destination::Appservice(server),
199			if value.is_empty() {
200				SendingEvent::Pdu(event.into())
201			} else {
202				SendingEvent::Edu(value.into())
203			},
204		)
205	} else if key.starts_with(b"$") {
206		let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
207
208		let user = parts
209			.next()
210			.expect("splitn always returns one element");
211		let user_string = utils::str_from_bytes(user)
212			.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
213		let user_id = UserId::parse(user_string)
214			.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
215
216		let pushkey = parts
217			.next()
218			.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
219		let pushkey_string = utils::string_from_bytes(pushkey)
220			.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
221
222		let event = parts
223			.next()
224			.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
225
226		(
227			Destination::Push(user_id, pushkey_string),
228			if value.is_empty() {
229				SendingEvent::Pdu(event.into())
230			} else {
231				// I'm pretty sure this should never be called
232				SendingEvent::Edu(value.into())
233			},
234		)
235	} else {
236		let mut parts = key.splitn(2, |&b| b == 0xFF);
237
238		let server = parts
239			.next()
240			.expect("splitn always returns one element");
241		let event = parts
242			.next()
243			.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
244
245		let server = utils::string_from_bytes(server).map_err(|_| {
246			Error::bad_database("Invalid server bytes in server_currenttransaction")
247		})?;
248
249		(
250			Destination::Federation(OwnedServerName::parse(&server).map_err(|_| {
251				Error::bad_database("Invalid server string in server_currenttransaction")
252			})?),
253			if value.is_empty() {
254				SendingEvent::Pdu(event.into())
255			} else {
256				SendingEvent::Edu(value.into())
257			},
258		)
259	})
260}