1mod data;
2mod dest;
3mod sender;
4
5use std::{
6 fmt::Debug,
7 hash::{DefaultHasher, Hash, Hasher},
8 io::Write,
9 iter::once,
10 pin::pin,
11 sync::Arc,
12};
13
14use async_trait::async_trait;
15use futures::{FutureExt, Stream, StreamExt};
16use ruma::{RoomId, ServerName, UserId};
17use tokio::{task, task::JoinSet};
18use tuwunel_core::{
19 Result, Server, debug, debug_warn, err, error,
20 smallvec::SmallVec,
21 utils::{
22 IterStream, ReadyExt, TryReadyExt, available_parallelism, future::BoolExt,
23 math::usize_from_u64_truncated, result::LogErr,
24 },
25 warn,
26};
27
28use self::data::Data;
29pub use self::{
30 dest::Destination,
31 sender::{EDU_LIMIT, PDU_LIMIT},
32};
33use crate::rooms::timeline::RawPduId;
34
/// Outbound sending service.
///
/// Requests (PDUs, EDUs, flushes) are recorded through `db` and then handed to
/// one of several sender workers over `channels`.
pub struct Service {
	/// Storage for queued sending requests.
	pub db: Data,
	server: Arc<Server>,
	services: Arc<crate::services::OnceServices>,
	/// One channel pair per sender worker. A destination is always routed to the
	/// same channel by `shard_id`, so its traffic is handled by a single worker.
	channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>,
}
41
/// A unit of work handed from the service to a sender worker.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Msg {
	/// Where the event is to be delivered.
	dest: Destination,
	/// What is to be delivered.
	event: SendingEvent,
	/// Key returned by `Data::queue_requests` for this request. Flush messages
	/// dispatched by `flush_servers` are not queued and carry an empty key.
	queue_id: Vec<u8>,
}
48
/// An item that can be sent to a [`Destination`].
#[expect(clippy::module_name_repetitions)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SendingEvent {
	/// A PDU, identified by its raw timeline id.
	Pdu(RawPduId),
	/// An already-serialized EDU payload.
	Edu(EduBuf),
	/// Carries no payload and is never queued in the database. Presumably it
	/// tells the worker to send whatever is pending for the destination.
	/// NOTE(review): confirm against sender.rs.
	Flush,
}

/// Serialized EDU bytes. Payloads up to `EDU_BUF_CAP` bytes are stored inline
/// without a heap allocation.
pub type EduBuf = SmallVec<[u8; EDU_BUF_CAP]>;

/// A small collection of serialized EDUs with room for `EDU_VEC_CAP` inline.
pub type EduVec = SmallVec<[EduBuf; EDU_VEC_CAP]>;

// NOTE(review): presumably sized so that an `EduBuf`, including SmallVec's own
// bookkeeping, stays within 128 bytes. Confirm before changing.
const EDU_BUF_CAP: usize = 128 - 16;
const EDU_VEC_CAP: usize = 1;
62
63#[async_trait]
64impl crate::Service for Service {
65 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
66 let num_senders = num_senders(args);
67 Ok(Arc::new(Self {
68 db: Data::new(args),
69 server: args.server.clone(),
70 services: args.services.clone(),
71 channels: (0..num_senders)
72 .map(|_| loole::unbounded())
73 .collect(),
74 }))
75 }
76
77 async fn worker(self: Arc<Self>) -> Result {
78 let mut senders =
79 self.channels
80 .iter()
81 .enumerate()
82 .fold(JoinSet::new(), |mut joinset, (id, _)| {
83 let self_ = self.clone();
84 let worker = self_.sender(id);
85 let worker = if self.unconstrained() {
86 task::unconstrained(worker).boxed()
87 } else {
88 worker.boxed()
89 };
90
91 let runtime = self.server.runtime();
92 let _abort = joinset.spawn_on(worker, runtime);
93 joinset
94 });
95
96 while let Some(ret) = senders.join_next_with_id().await {
97 match ret {
98 | Ok((id, _)) => {
99 debug!(?id, "sender worker finished");
100 },
101 | Err(error) => {
102 error!(id = ?error.id(), ?error, "sender worker finished");
103 },
104 }
105 }
106
107 Ok(())
108 }
109
110 async fn interrupt(&self) {
111 for (sender, _) in &self.channels {
112 if !sender.is_closed() {
113 sender.close();
114 }
115 }
116 }
117
118 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
119
120 fn unconstrained(&self) -> bool { true }
121}
122
123impl Service {
124 #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")]
125 pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result {
126 let dest = Destination::Push(user.to_owned(), pushkey);
127 let event = SendingEvent::Pdu(*pdu_id);
128 let _cork = self.db.db.cork();
129 let keys = self.db.queue_requests(once((&event, &dest)));
130
131 self.dispatch(Msg {
132 dest,
133 event,
134 queue_id: keys
135 .into_iter()
136 .next()
137 .expect("request queue key"),
138 })
139 }
140
141 #[tracing::instrument(skip(self), level = "debug")]
142 pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result {
143 let dest = Destination::Appservice(appservice_id);
144 let event = SendingEvent::Pdu(pdu_id);
145 let _cork = self.db.db.cork();
146 let keys = self.db.queue_requests(once((&event, &dest)));
147
148 self.dispatch(Msg {
149 dest,
150 event,
151 queue_id: keys
152 .into_iter()
153 .next()
154 .expect("request queue key"),
155 })
156 }
157
158 #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
159 pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result {
160 let servers = self
161 .services
162 .state_cache
163 .room_servers(room_id)
164 .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
165
166 self.send_pdu_servers(servers, pdu_id).await
167 }
168
169 #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
170 pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result
171 where
172 S: Stream<Item = &'a ServerName> + Send + 'a,
173 {
174 let requests = servers
175 .map(|server| {
176 (Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
177 })
178 .collect::<Vec<_>>()
179 .await;
180
181 let _cork = self.db.db.cork();
182 let keys = self
183 .db
184 .queue_requests(requests.iter().map(|(o, e)| (e, o)));
185
186 for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
187 self.dispatch(Msg { dest, event, queue_id })?;
188 }
189
190 Ok(())
191 }
192
193 #[tracing::instrument(skip(self, server, serialized), level = "debug")]
194 pub fn send_edu_server(&self, server: &ServerName, serialized: EduBuf) -> Result {
195 let dest = Destination::Federation(server.to_owned());
196 let event = SendingEvent::Edu(serialized);
197 let _cork = self.db.db.cork();
198 let keys = self.db.queue_requests(once((&event, &dest)));
199
200 self.dispatch(Msg {
201 dest,
202 event,
203 queue_id: keys
204 .into_iter()
205 .next()
206 .expect("request queue key"),
207 })
208 }
209
210 #[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
211 pub async fn send_edu_room(&self, room_id: &RoomId, serialized: EduBuf) -> Result {
212 let servers = self
213 .services
214 .state_cache
215 .room_servers(room_id)
216 .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
217
218 self.send_edu_servers(servers, serialized).await
219 }
220
221 #[tracing::instrument(skip(self, serialized), level = "debug")]
223 pub fn send_edu_appservice(&self, appservice_id: String, serialized: EduBuf) -> Result {
224 let dest = Destination::Appservice(appservice_id);
225 let event = SendingEvent::Edu(serialized);
226 let _cork = self.db.db.cork();
227 let keys = self.db.queue_requests(once((&event, &dest)));
228
229 self.dispatch(Msg {
230 dest,
231 event,
232 queue_id: keys
233 .into_iter()
234 .next()
235 .expect("request queue key"),
236 })
237 }
238
239 #[tracing::instrument(skip(self, serializer), level = "debug")]
243 pub async fn send_edu_room_appservices<'a, F>(
244 &self,
245 room_id: &RoomId,
246 serializer: F,
247 ) -> Result
248 where
249 F: Fn(&mut dyn Write) -> Result + Send + 'a,
250 &'a F: Send + Sync,
251 {
252 self.services
253 .appservice
254 .read()
255 .await
256 .values()
257 .stream()
258 .filter(|&appservice| async {
259 if !appservice.registration.receive_ephemeral {
260 return false;
261 }
262
263 if appservice.rooms.is_match(room_id.as_str()) {
264 return true;
265 }
266
267 let appservice_in_room = self
268 .services
269 .state_cache
270 .appservice_in_room(room_id, appservice);
271
272 let matching_aliases = self
273 .services
274 .alias
275 .local_aliases_for_room(room_id)
276 .ready_any(|room_alias| appservice.aliases.is_match(room_alias.as_str()));
277
278 pin!(appservice_in_room)
279 .or(pin!(matching_aliases))
280 .await
281 })
282 .map(Ok)
283 .ready_try_for_each(|appservice| {
284 let mut buf = EduBuf::new();
285
286 serializer(&mut buf)?;
287 self.send_edu_appservice(appservice.registration.id.clone(), buf)
288 .log_err()
289 .ok();
290
291 Ok(())
292 })
293 .await
294 }
295
296 #[tracing::instrument(skip(self, servers, serialized), level = "debug")]
297 pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: EduBuf) -> Result
298 where
299 S: Stream<Item = &'a ServerName> + Send + 'a,
300 {
301 let requests = servers
302 .map(|server| {
303 (
304 Destination::Federation(server.to_owned()),
305 SendingEvent::Edu(serialized.clone()),
306 )
307 })
308 .collect::<Vec<_>>()
309 .await;
310
311 let _cork = self.db.db.cork();
312 let keys = self
313 .db
314 .queue_requests(requests.iter().map(|(o, e)| (e, o)));
315
316 for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
317 self.dispatch(Msg { dest, event, queue_id })?;
318 }
319
320 Ok(())
321 }
322
323 #[tracing::instrument(skip(self, room_id), level = "debug")]
324 pub async fn flush_room(&self, room_id: &RoomId) -> Result {
325 let servers = self
326 .services
327 .state_cache
328 .room_servers(room_id)
329 .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
330
331 self.flush_servers(servers).await
332 }
333
334 #[tracing::instrument(skip(self, servers), level = "debug")]
335 pub async fn flush_servers<'a, S>(&self, servers: S) -> Result
336 where
337 S: Stream<Item = &'a ServerName> + Send + 'a,
338 {
339 servers
340 .map(ToOwned::to_owned)
341 .map(Destination::Federation)
342 .map(Ok)
343 .ready_try_for_each(|dest| {
344 self.dispatch(Msg {
345 dest,
346 event: SendingEvent::Flush,
347 queue_id: Vec::<u8>::new(),
348 })
349 })
350 .await
351 }
352
353 #[tracing::instrument(skip(self), level = "debug")]
358 pub async fn cleanup_events(
359 &self,
360 appservice_id: Option<&str>,
361 user_id: Option<&UserId>,
362 push_key: Option<&str>,
363 ) -> Result {
364 match (appservice_id, user_id, push_key) {
365 | (None, Some(user_id), Some(push_key)) => {
366 self.db
367 .delete_all_requests_for(&Destination::Push(
368 user_id.to_owned(),
369 push_key.to_owned(),
370 ))
371 .await;
372
373 Ok(())
374 },
375 | (Some(appservice_id), None, None) => {
376 self.db
377 .delete_all_requests_for(&Destination::Appservice(appservice_id.to_owned()))
378 .await;
379
380 Ok(())
381 },
382 | _ => {
383 debug_warn!("cleanup_events called with too many or too few arguments");
384 Ok(())
385 },
386 }
387 }
388
389 fn dispatch(&self, msg: Msg) -> Result {
390 let shard = self.shard_id(&msg.dest);
391 let sender = &self
392 .channels
393 .get(shard)
394 .expect("missing sender worker channels")
395 .0;
396
397 debug_assert!(!sender.is_full(), "channel full");
398 debug_assert!(!sender.is_closed(), "channel closed");
399 sender.send(msg).map_err(|e| err!("{e}"))
400 }
401
402 pub(super) fn shard_id(&self, dest: &Destination) -> usize {
403 if self.channels.len() <= 1 {
404 return 0;
405 }
406
407 let mut hash = DefaultHasher::default();
408 dest.hash(&mut hash);
409
410 let hash: u64 = hash.finish();
411 let hash = usize_from_u64_truncated(hash);
412
413 let chans = self.channels.len().max(1);
414 hash.overflowing_rem(chans).0
415 }
416}
417
418fn num_senders(args: &crate::Args<'_>) -> usize {
419 const MIN_SENDERS: usize = 1;
420 let max_senders = args
423 .server
424 .metrics
425 .num_workers()
426 .min(available_parallelism());
427
428 args.server
431 .config
432 .sender_workers
433 .clamp(MIN_SENDERS, max_senders)
434}