// Source file: tuwunel_service/sending/mod.rs

1mod data;
2mod dest;
3mod sender;
4
5use std::{
6	fmt::Debug,
7	hash::{DefaultHasher, Hash, Hasher},
8	io::Write,
9	iter::once,
10	pin::pin,
11	sync::Arc,
12};
13
14use async_trait::async_trait;
15use futures::{FutureExt, Stream, StreamExt};
16use ruma::{RoomId, ServerName, UserId};
17use tokio::{task, task::JoinSet};
18use tuwunel_core::{
19	Result, Server, debug, debug_warn, err, error,
20	smallvec::SmallVec,
21	utils::{
22		IterStream, ReadyExt, TryReadyExt, available_parallelism, future::BoolExt,
23		math::usize_from_u64_truncated, result::LogErr,
24	},
25	warn,
26};
27
28use self::data::Data;
29pub use self::{
30	dest::Destination,
31	sender::{EDU_LIMIT, PDU_LIMIT},
32};
33use crate::rooms::timeline::RawPduId;
34
35pub struct Service {
36	pub db: Data,
37	server: Arc<Server>,
38	services: Arc<crate::services::OnceServices>,
39	channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>,
40}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
43struct Msg {
44	dest: Destination,
45	event: SendingEvent,
46	queue_id: Vec<u8>,
47}
48
49#[expect(clippy::module_name_repetitions)]
50#[derive(Clone, Debug, PartialEq, Eq, Hash)]
51pub enum SendingEvent {
52	Pdu(RawPduId), // pduid
53	Edu(EduBuf),   // edu json
54	Flush,         // none
55}
56
57pub type EduBuf = SmallVec<[u8; EDU_BUF_CAP]>;
58pub type EduVec = SmallVec<[EduBuf; EDU_VEC_CAP]>;
59
60const EDU_BUF_CAP: usize = 128 - 16;
61const EDU_VEC_CAP: usize = 1;
62
63#[async_trait]
64impl crate::Service for Service {
65	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
66		let num_senders = num_senders(args);
67		Ok(Arc::new(Self {
68			db: Data::new(args),
69			server: args.server.clone(),
70			services: args.services.clone(),
71			channels: (0..num_senders)
72				.map(|_| loole::unbounded())
73				.collect(),
74		}))
75	}
76
77	async fn worker(self: Arc<Self>) -> Result {
78		let mut senders =
79			self.channels
80				.iter()
81				.enumerate()
82				.fold(JoinSet::new(), |mut joinset, (id, _)| {
83					let self_ = self.clone();
84					let worker = self_.sender(id);
85					let worker = if self.unconstrained() {
86						task::unconstrained(worker).boxed()
87					} else {
88						worker.boxed()
89					};
90
91					let runtime = self.server.runtime();
92					let _abort = joinset.spawn_on(worker, runtime);
93					joinset
94				});
95
96		while let Some(ret) = senders.join_next_with_id().await {
97			match ret {
98				| Ok((id, _)) => {
99					debug!(?id, "sender worker finished");
100				},
101				| Err(error) => {
102					error!(id = ?error.id(), ?error, "sender worker finished");
103				},
104			}
105		}
106
107		Ok(())
108	}
109
110	async fn interrupt(&self) {
111		for (sender, _) in &self.channels {
112			if !sender.is_closed() {
113				sender.close();
114			}
115		}
116	}
117
118	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
119
120	fn unconstrained(&self) -> bool { true }
121}
122
123impl Service {
124	#[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")]
125	pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result {
126		let dest = Destination::Push(user.to_owned(), pushkey);
127		let event = SendingEvent::Pdu(*pdu_id);
128		let _cork = self.db.db.cork();
129		let keys = self.db.queue_requests(once((&event, &dest)));
130
131		self.dispatch(Msg {
132			dest,
133			event,
134			queue_id: keys
135				.into_iter()
136				.next()
137				.expect("request queue key"),
138		})
139	}
140
141	#[tracing::instrument(skip(self), level = "debug")]
142	pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result {
143		let dest = Destination::Appservice(appservice_id);
144		let event = SendingEvent::Pdu(pdu_id);
145		let _cork = self.db.db.cork();
146		let keys = self.db.queue_requests(once((&event, &dest)));
147
148		self.dispatch(Msg {
149			dest,
150			event,
151			queue_id: keys
152				.into_iter()
153				.next()
154				.expect("request queue key"),
155		})
156	}
157
158	#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
159	pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result {
160		let servers = self
161			.services
162			.state_cache
163			.room_servers(room_id)
164			.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
165
166		self.send_pdu_servers(servers, pdu_id).await
167	}
168
169	#[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
170	pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result
171	where
172		S: Stream<Item = &'a ServerName> + Send + 'a,
173	{
174		let requests = servers
175			.map(|server| {
176				(Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
177			})
178			.collect::<Vec<_>>()
179			.await;
180
181		let _cork = self.db.db.cork();
182		let keys = self
183			.db
184			.queue_requests(requests.iter().map(|(o, e)| (e, o)));
185
186		for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
187			self.dispatch(Msg { dest, event, queue_id })?;
188		}
189
190		Ok(())
191	}
192
193	#[tracing::instrument(skip(self, server, serialized), level = "debug")]
194	pub fn send_edu_server(&self, server: &ServerName, serialized: EduBuf) -> Result {
195		let dest = Destination::Federation(server.to_owned());
196		let event = SendingEvent::Edu(serialized);
197		let _cork = self.db.db.cork();
198		let keys = self.db.queue_requests(once((&event, &dest)));
199
200		self.dispatch(Msg {
201			dest,
202			event,
203			queue_id: keys
204				.into_iter()
205				.next()
206				.expect("request queue key"),
207		})
208	}
209
210	#[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
211	pub async fn send_edu_room(&self, room_id: &RoomId, serialized: EduBuf) -> Result {
212		let servers = self
213			.services
214			.state_cache
215			.room_servers(room_id)
216			.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
217
218		self.send_edu_servers(servers, serialized).await
219	}
220
221	/// Queue an EDU for delivery to a specific appservice.
222	#[tracing::instrument(skip(self, serialized), level = "debug")]
223	pub fn send_edu_appservice(&self, appservice_id: String, serialized: EduBuf) -> Result {
224		let dest = Destination::Appservice(appservice_id);
225		let event = SendingEvent::Edu(serialized);
226		let _cork = self.db.db.cork();
227		let keys = self.db.queue_requests(once((&event, &dest)));
228
229		self.dispatch(Msg {
230			dest,
231			event,
232			queue_id: keys
233				.into_iter()
234				.next()
235				.expect("request queue key"),
236		})
237	}
238
239	/// Sends an EDU to all appservices interested in a room.
240	/// The `serialized` data must be in `EphemeralData` format, not federation
241	/// `Edu`.
242	#[tracing::instrument(skip(self, serializer), level = "debug")]
243	pub async fn send_edu_room_appservices<'a, F>(
244		&self,
245		room_id: &RoomId,
246		serializer: F,
247	) -> Result
248	where
249		F: Fn(&mut dyn Write) -> Result + Send + 'a,
250		&'a F: Send + Sync,
251	{
252		self.services
253			.appservice
254			.read()
255			.await
256			.values()
257			.stream()
258			.filter(|&appservice| async {
259				if !appservice.registration.receive_ephemeral {
260					return false;
261				}
262
263				if appservice.rooms.is_match(room_id.as_str()) {
264					return true;
265				}
266
267				let appservice_in_room = self
268					.services
269					.state_cache
270					.appservice_in_room(room_id, appservice);
271
272				let matching_aliases = self
273					.services
274					.alias
275					.local_aliases_for_room(room_id)
276					.ready_any(|room_alias| appservice.aliases.is_match(room_alias.as_str()));
277
278				pin!(appservice_in_room)
279					.or(pin!(matching_aliases))
280					.await
281			})
282			.map(Ok)
283			.ready_try_for_each(|appservice| {
284				let mut buf = EduBuf::new();
285
286				serializer(&mut buf)?;
287				self.send_edu_appservice(appservice.registration.id.clone(), buf)
288					.log_err()
289					.ok();
290
291				Ok(())
292			})
293			.await
294	}
295
296	#[tracing::instrument(skip(self, servers, serialized), level = "debug")]
297	pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: EduBuf) -> Result
298	where
299		S: Stream<Item = &'a ServerName> + Send + 'a,
300	{
301		let requests = servers
302			.map(|server| {
303				(
304					Destination::Federation(server.to_owned()),
305					SendingEvent::Edu(serialized.clone()),
306				)
307			})
308			.collect::<Vec<_>>()
309			.await;
310
311		let _cork = self.db.db.cork();
312		let keys = self
313			.db
314			.queue_requests(requests.iter().map(|(o, e)| (e, o)));
315
316		for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
317			self.dispatch(Msg { dest, event, queue_id })?;
318		}
319
320		Ok(())
321	}
322
323	#[tracing::instrument(skip(self, room_id), level = "debug")]
324	pub async fn flush_room(&self, room_id: &RoomId) -> Result {
325		let servers = self
326			.services
327			.state_cache
328			.room_servers(room_id)
329			.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
330
331		self.flush_servers(servers).await
332	}
333
334	#[tracing::instrument(skip(self, servers), level = "debug")]
335	pub async fn flush_servers<'a, S>(&self, servers: S) -> Result
336	where
337		S: Stream<Item = &'a ServerName> + Send + 'a,
338	{
339		servers
340			.map(ToOwned::to_owned)
341			.map(Destination::Federation)
342			.map(Ok)
343			.ready_try_for_each(|dest| {
344				self.dispatch(Msg {
345					dest,
346					event: SendingEvent::Flush,
347					queue_id: Vec::<u8>::new(),
348				})
349			})
350			.await
351	}
352
353	/// Clean up queued sending event data
354	///
355	/// Used after we remove an appservice registration or a user deletes a push
356	/// key
357	#[tracing::instrument(skip(self), level = "debug")]
358	pub async fn cleanup_events(
359		&self,
360		appservice_id: Option<&str>,
361		user_id: Option<&UserId>,
362		push_key: Option<&str>,
363	) -> Result {
364		match (appservice_id, user_id, push_key) {
365			| (None, Some(user_id), Some(push_key)) => {
366				self.db
367					.delete_all_requests_for(&Destination::Push(
368						user_id.to_owned(),
369						push_key.to_owned(),
370					))
371					.await;
372
373				Ok(())
374			},
375			| (Some(appservice_id), None, None) => {
376				self.db
377					.delete_all_requests_for(&Destination::Appservice(appservice_id.to_owned()))
378					.await;
379
380				Ok(())
381			},
382			| _ => {
383				debug_warn!("cleanup_events called with too many or too few arguments");
384				Ok(())
385			},
386		}
387	}
388
389	fn dispatch(&self, msg: Msg) -> Result {
390		let shard = self.shard_id(&msg.dest);
391		let sender = &self
392			.channels
393			.get(shard)
394			.expect("missing sender worker channels")
395			.0;
396
397		debug_assert!(!sender.is_full(), "channel full");
398		debug_assert!(!sender.is_closed(), "channel closed");
399		sender.send(msg).map_err(|e| err!("{e}"))
400	}
401
402	pub(super) fn shard_id(&self, dest: &Destination) -> usize {
403		if self.channels.len() <= 1 {
404			return 0;
405		}
406
407		let mut hash = DefaultHasher::default();
408		dest.hash(&mut hash);
409
410		let hash: u64 = hash.finish();
411		let hash = usize_from_u64_truncated(hash);
412
413		let chans = self.channels.len().max(1);
414		hash.overflowing_rem(chans).0
415	}
416}
417
418fn num_senders(args: &crate::Args<'_>) -> usize {
419	const MIN_SENDERS: usize = 1;
420	// Limit the number of senders to the number of workers threads or number of
421	// cores, conservatively.
422	let max_senders = args
423		.server
424		.metrics
425		.num_workers()
426		.min(available_parallelism());
427
428	// If the user doesn't override the default 0, this is intended to then default
429	// to 1 for now as multiple senders is experimental.
430	args.server
431		.config
432		.sender_workers
433		.clamp(MIN_SENDERS, max_senders)
434}