1use std::{
2 collections::{BTreeMap, HashMap, HashSet, btree_map::Entry},
3 fmt::Debug,
4 sync::{
5 Arc,
6 atomic::{AtomicU64, AtomicUsize, Ordering},
7 },
8 time::{Duration, Instant},
9};
10
11use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
12use futures::{
13 FutureExt, StreamExt, TryFutureExt,
14 future::{BoxFuture, join3, try_join3},
15 pin_mut,
16 stream::FuturesUnordered,
17};
18use ruma::{
19 MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName,
20 UserId,
21 api::{
22 appservice::event::push_events::v1::EphemeralData,
23 federation::transactions::{
24 edu::{
25 DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent,
26 ReceiptData, ReceiptMap,
27 },
28 send_transaction_message,
29 },
30 },
31 device_id,
32 events::{
33 AnySyncEphemeralRoomEvent, GlobalAccountDataEventType, push_rules::PushRulesEvent,
34 receipt::ReceiptType,
35 },
36 presence::PresenceState,
37 push,
38 serde::Raw,
39 uint,
40};
41use tuwunel_core::{
42 Error, Event, Result, debug, err, error, extract_variant,
43 result::LogErr,
44 smallvec::SmallVec,
45 trace,
46 utils::{
47 BoolExt, ReadyExt, calculate_hash, continue_exponential_backoff_secs,
48 future::TryExtExt,
49 stream::{BroadbandExt, IterStream, WidebandExt},
50 },
51 warn,
52};
53
54use super::{Destination, EduBuf, EduVec, Msg, SendingEvent, Service, data::QueueItem};
55use crate::rooms::timeline::RawPduId;
56
/// State of the transaction a sender worker currently tracks for one
/// destination.
#[derive(Debug)]
enum TransactionStatus {
    /// A transaction to the destination is in flight.
    Running,

    /// The most recent attempt failed. Holds the number of failed attempts so
    /// far and the instant the latest failure was recorded; both feed the
    /// exponential backoff check.
    Failed(u32, Instant),

    /// A previously failed transaction is being retried. Holds the number of
    /// failed attempts made before this retry.
    Retrying(u32),
}
63
/// Failure of one transaction: the destination it was addressed to and the
/// error that occurred.
type SendingError = (Destination, Error);
/// Outcome of one transaction; both arms carry the destination so the worker
/// can update that destination's status.
type SendingResult = Result<Destination, SendingError>;
/// Boxed future resolving to the outcome of one transaction.
type SendingFuture<'a> = BoxFuture<'a, SendingResult>;
/// Set of in-flight transactions polled by a sender worker.
type SendingFutures<'a> = FuturesUnordered<SendingFuture<'a>>;
/// Transaction status per destination, owned by a single sender worker.
type CurTransactionStatus = HashMap<Destination, TransactionStatus>;

/// Receipts selected for a single user within one room.
type UserReceipts = SmallVec<[ReceiptData; 1]>;

/// Receipt maps for one room indexed by rank: entry `i` holds the `i`-th
/// selected receipt of every user that has at least `i + 1` receipts.
type RankedReceipts = SmallVec<[ReceiptMap; 1]>;
80
/// NOTE(review): presumably the cap on PDUs per transaction; it is not
/// referenced in this section — confirm against the callers.
pub const PDU_LIMIT: usize = 50;

/// Upper bound on EDUs in one transaction; `select_events` debug-asserts the
/// selected EDUs stay within it.
pub const EDU_LIMIT: usize = 100;

/// Number of device-list updates after which device-change selection stops.
/// NOTE(review): the two slots of headroom below `EDU_LIMIT` are presumably
/// reserved for the presence and receipt EDUs — confirm.
const SELECT_EDU_LIMIT: usize = EDU_LIMIT - 2;

/// Number of presence updates after which presence selection stops.
const SELECT_PRESENCE_LIMIT: usize = 256;

/// Threshold on the shared counter of users with receipts at which receipt
/// selection for a room stops.
const SELECT_RECEIPT_LIMIT: usize = 256;

/// Maximum number of queued requests promoted to active after a successful
/// response.
const DEQUEUE_LIMIT: usize = 48;
88
89impl Service {
90 #[tracing::instrument(skip(self), level = "debug")]
91 pub(super) async fn sender(self: Arc<Self>, id: usize) -> Result {
92 let mut statuses: CurTransactionStatus = CurTransactionStatus::new();
93 let mut futures: SendingFutures<'_> = FuturesUnordered::new();
94
95 self.startup_netburst(id, &mut futures, &mut statuses)
96 .boxed()
97 .await;
98
99 self.work_loop(id, &mut futures, &mut statuses)
100 .await;
101
102 if !futures.is_empty() {
103 self.finish_responses(&mut futures).boxed().await;
104 }
105
106 Ok(())
107 }
108
109 #[tracing::instrument(
110 name = "work",
111 level = "trace"
112 skip_all,
113 fields(
114 futures = %futures.len(),
115 statuses = %statuses.len(),
116 ),
117 )]
118 async fn work_loop<'a>(
119 &'a self,
120 id: usize,
121 futures: &mut SendingFutures<'a>,
122 statuses: &mut CurTransactionStatus,
123 ) {
124 let receiver = self
125 .channels
126 .get(id)
127 .map(|(_, receiver)| receiver.clone())
128 .expect("Missing channel for sender worker");
129
130 while !receiver.is_closed() {
131 tokio::select! {
132 Some(response) = futures.next() => {
133 self.handle_response(response, futures, statuses).await;
134 },
135 request = receiver.recv_async() => match request {
136 Ok(request) => self.handle_request(request, futures, statuses).await,
137 Err(_) => return,
138 },
139 }
140 }
141 }
142
143 #[tracing::instrument(name = "response", level = "debug", skip_all)]
144 async fn handle_response<'a>(
145 &'a self,
146 response: SendingResult,
147 futures: &mut SendingFutures<'a>,
148 statuses: &mut CurTransactionStatus,
149 ) {
150 match response {
151 | Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
152 | Ok(dest) =>
153 self.handle_response_ok(&dest, futures, statuses)
154 .await,
155 }
156 }
157
158 fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) {
159 debug!(dest = ?dest, "{e:?}");
160 statuses.entry(dest).and_modify(|e| {
161 *e = match e {
162 | TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
163
164 | &mut TransactionStatus::Retrying(ref n) =>
165 TransactionStatus::Failed(n.saturating_add(1), Instant::now()),
166
167 | TransactionStatus::Failed(..) => {
168 panic!("Request that was not even running failed?!")
169 },
170 }
171 });
172 }
173
174 #[expect(clippy::needless_pass_by_ref_mut)]
175 async fn handle_response_ok<'a>(
176 &'a self,
177 dest: &Destination,
178 futures: &mut SendingFutures<'a>,
179 statuses: &mut CurTransactionStatus,
180 ) {
181 let _cork = self.db.db.cork();
182 self.db.delete_all_active_requests_for(dest).await;
183
184 let new_events = self
186 .db
187 .queued_requests(dest)
188 .take(DEQUEUE_LIMIT)
189 .collect::<Vec<_>>()
190 .await;
191
192 if !new_events.is_empty() {
194 self.db.mark_as_active(new_events.iter());
195
196 let new_events_vec = new_events
197 .into_iter()
198 .map(|(_, event)| event)
199 .collect();
200
201 futures.push(self.send_events(dest.clone(), new_events_vec));
202 } else {
203 statuses.remove(dest);
204 }
205 }
206
207 #[expect(clippy::needless_pass_by_ref_mut)]
208 #[tracing::instrument(name = "request", level = "debug", skip_all)]
209 async fn handle_request<'a>(
210 &'a self,
211 msg: Msg,
212 futures: &mut SendingFutures<'a>,
213 statuses: &mut CurTransactionStatus,
214 ) {
215 let iv = vec![(msg.queue_id, msg.event)];
216 if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
217 if !events.is_empty() {
218 futures.push(self.send_events(msg.dest, events));
219 } else {
220 statuses.remove(&msg.dest);
221 }
222 }
223 }
224
225 #[tracing::instrument(
226 name = "finish",
227 level = "info",
228 skip_all,
229 fields(futures = %futures.len()),
230 )]
231 async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>) {
232 use tokio::{
233 select,
234 time::{Instant, sleep_until},
235 };
236
237 let timeout = self.server.config.sender_shutdown_timeout;
238 let timeout = Duration::from_secs(timeout);
239 let now = Instant::now();
240 let deadline = now.checked_add(timeout).unwrap_or(now);
241 loop {
242 trace!("Waiting for {} requests to complete...", futures.len());
243 select! {
244 () = sleep_until(deadline) => return,
245 response = futures.next() => match response {
246 Some(Ok(dest)) => self.db.delete_all_active_requests_for(&dest).await,
247 Some(_) => continue,
248 None => return,
249 },
250 }
251 }
252 }
253
254 #[tracing::instrument(
255 name = "netburst",
256 level = "debug",
257 skip_all,
258 fields(futures = %futures.len()),
259 )]
260 #[expect(clippy::needless_pass_by_ref_mut)]
261 async fn startup_netburst<'a>(
262 &'a self,
263 id: usize,
264 futures: &mut SendingFutures<'a>,
265 statuses: &mut CurTransactionStatus,
266 ) {
267 let keep =
268 usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
269
270 let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
271 let active = self.db.active_requests();
272
273 pin_mut!(active);
274 while let Some((key, event, dest)) = active.next().await {
275 if self.shard_id(&dest) != id {
276 continue;
277 }
278
279 let entry = txns.entry(dest.clone()).or_default();
280 if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep {
281 warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key));
282 self.db.delete_active_request(&key);
283 } else {
284 entry.push(event);
285 }
286 }
287
288 for (dest, events) in txns {
289 if self.server.config.startup_netburst && !events.is_empty() {
290 statuses.insert(dest.clone(), TransactionStatus::Running);
291 futures.push(self.send_events(dest.clone(), events));
292 }
293 }
294 }
295
296 #[tracing::instrument(
297 name = "select",,
298 level = "debug",
299 skip_all,
300 fields(
301 ?dest,
302 new_events = %new_events.len(),
303 )
304 )]
305 async fn select_events(
306 &self,
307 dest: &Destination,
308 new_events: Vec<QueueItem>, statuses: &mut CurTransactionStatus,
310 ) -> Result<Option<Vec<SendingEvent>>> {
311 let (allow, retry) = self.select_events_current(dest, statuses)?;
312
313 if !allow {
315 return Ok(None);
316 }
317
318 let _cork = self.db.db.cork();
319 let mut events = Vec::new();
320
321 if retry {
323 self.db
324 .active_requests_for(dest)
325 .ready_for_each(|(_, e)| events.push(e))
326 .await;
327
328 return Ok(Some(events));
329 }
330
331 let _cork = self.db.db.cork();
333 if !new_events.is_empty() {
334 self.db.mark_as_active(new_events.iter());
335 for (_, e) in new_events {
336 events.push(e);
337 }
338 }
339
340 if let Destination::Federation(server_name) = dest
342 && let Ok((select_edus, last_count)) = self.select_edus(server_name).await
343 {
344 debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit");
345 let select_edus = select_edus.into_iter().map(SendingEvent::Edu);
346
347 events.extend(select_edus);
348 self.db
349 .set_latest_educount(server_name, last_count);
350 }
351
352 Ok(Some(events))
353 }
354
355 fn select_events_current(
356 &self,
357 dest: &Destination,
358 statuses: &mut CurTransactionStatus,
359 ) -> Result<(bool, bool)> {
360 let (mut allow, mut retry) = (true, false);
361 statuses
362 .entry(dest.clone()) .and_modify(|e| match e {
364 TransactionStatus::Failed(tries, time) => {
365 let min = self.server.config.sender_timeout;
367 let max = self.server.config.sender_retry_backoff_limit;
368 if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries)
369 && !matches!(dest, Destination::Appservice(_))
370 {
371 allow = false;
372 } else {
373 retry = true;
374 *e = TransactionStatus::Retrying(*tries);
375 }
376 },
377 TransactionStatus::Running | TransactionStatus::Retrying(_) => {
378 allow = false; },
380 })
381 .or_insert(TransactionStatus::Running);
382
383 Ok((allow, retry))
384 }
385
386 #[tracing::instrument(
387 name = "edus",,
388 level = "debug",
389 skip_all,
390 )]
391 async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> {
392 let since = self.db.get_latest_educount(server_name).await;
394 let since_upper = self.services.globals.current_count();
395 let batch = (since, since_upper);
396 debug_assert!(batch.0 <= batch.1, "since range must not be negative");
397
398 let events_len = AtomicUsize::default();
399 let max_edu_count = AtomicU64::new(since);
400
401 let device_changes =
402 self.select_edus_device_changes(server_name, batch, &max_edu_count, &events_len);
403
404 let receipts = self
405 .server
406 .config
407 .allow_outgoing_read_receipts
408 .then_async(|| self.select_edus_receipts(server_name, batch, &max_edu_count));
409
410 let presence = self
411 .server
412 .config
413 .allow_outgoing_presence
414 .then_async(|| self.select_edus_presence(server_name, batch, &max_edu_count));
415
416 let (device_changes, receipts, presence) =
417 join3(device_changes, receipts, presence).await;
418
419 let mut events = device_changes;
420 events.extend(presence.into_iter().flatten());
421 events.extend(receipts.into_iter().flatten());
422
423 Ok((events, max_edu_count.load(Ordering::Acquire)))
424 }
425
426 #[tracing::instrument(
428 name = "device_changes",
429 level = "trace",
430 skip(self, server_name, max_edu_count)
431 )]
432 async fn select_edus_device_changes(
433 &self,
434 server_name: &ServerName,
435 since: (u64, u64),
436 max_edu_count: &AtomicU64,
437 events_len: &AtomicUsize,
438 ) -> EduVec {
439 let mut events = EduVec::new();
440 let server_rooms = self
441 .services
442 .state_cache
443 .server_rooms(server_name);
444
445 pin_mut!(server_rooms);
446 let mut device_list_changes = HashSet::<OwnedUserId>::new();
447 while let Some(room_id) = server_rooms.next().await {
448 let keys_changed = self
449 .services
450 .users
451 .room_keys_changed(room_id, since.0, Some(since.1))
452 .ready_filter(|(user_id, _)| self.services.globals.user_is_local(user_id));
453
454 pin_mut!(keys_changed);
455 while let Some((user_id, count)) = keys_changed.next().await {
456 debug_assert!(count <= since.1, "exceeds upper-bound");
457
458 max_edu_count.fetch_max(count, Ordering::Relaxed);
459 if !device_list_changes.insert(user_id.into()) {
460 continue;
461 }
462
463 let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
466 user_id: user_id.into(),
467 device_id: device_id!("placeholder").to_owned(),
468 device_display_name: Some("Placeholder".to_owned()),
469 stream_id: uint!(1),
470 prev_id: Vec::new(),
471 deleted: None,
472 keys: None,
473 });
474
475 let mut buf = EduBuf::new();
476 serde_json::to_writer(&mut buf, &edu)
477 .expect("failed to serialize device list update to JSON");
478
479 events.push(buf);
480 if events_len.fetch_add(1, Ordering::Relaxed) >= SELECT_EDU_LIMIT - 1 {
481 return events;
482 }
483 }
484 }
485
486 events
487 }
488
489 #[tracing::instrument(
499 name = "receipts",
500 level = "trace",
501 skip(self, server_name, max_edu_count)
502 )]
503 async fn select_edus_receipts(
504 &self,
505 server_name: &ServerName,
506 since: (u64, u64),
507 max_edu_count: &AtomicU64,
508 ) -> EduVec {
509 let num = AtomicUsize::new(0);
510 let by_room: SmallVec<[(OwnedRoomId, RankedReceipts); 1]> = self
511 .services
512 .state_cache
513 .server_rooms(server_name)
514 .map(ToOwned::to_owned)
515 .broad_filter_map(async |room_id| {
516 let ranked = self
517 .select_edus_receipts_room(&room_id, since, max_edu_count, &num)
518 .await;
519
520 ranked
521 .is_empty()
522 .is_false()
523 .then_some((room_id, ranked))
524 })
525 .collect()
526 .boxed()
527 .await;
528
529 let max_rank = by_room
530 .iter()
531 .map(|(_, maps)| maps.len())
532 .max()
533 .unwrap_or(0);
534
535 let pivot_rank = |rank: usize| -> Option<BTreeMap<OwnedRoomId, ReceiptMap>> {
536 let receipts: BTreeMap<_, _> = by_room
537 .iter()
538 .filter_map(|(room_id, maps)| {
539 maps.get(rank)
540 .cloned()
541 .map(|map| (room_id.clone(), map))
542 })
543 .collect();
544
545 receipts.is_empty().is_false().then_some(receipts)
546 };
547
548 let serialize_edu = |receipts: BTreeMap<OwnedRoomId, ReceiptMap>| -> EduBuf {
549 let mut buf = EduBuf::new();
550 serde_json::to_writer(&mut buf, &Edu::Receipt(ReceiptContent { receipts }))
551 .expect("Failed to serialize Receipt EDU to JSON vec");
552 buf
553 };
554
555 (0..max_rank)
556 .filter_map(pivot_rank)
557 .map(serialize_edu)
558 .collect()
559 }
560
561 #[tracing::instrument(
570 name = "receipts",
571 level = "trace",
572 skip(self, since, max_edu_count)
573 )]
574 async fn select_edus_receipts_room(
575 &self,
576 room_id: &RoomId,
577 since: (u64, u64),
578 max_edu_count: &AtomicU64,
579 num: &AtomicUsize,
580 ) -> RankedReceipts {
581 let receipts =
582 self.services
583 .read_receipt
584 .readreceipts_since(room_id, since.0, Some(since.1));
585
586 pin_mut!(receipts);
587 let mut by_user = BTreeMap::<OwnedUserId, UserReceipts>::new();
588 while let Some((user_id, count, read_receipt)) = receipts.next().await {
589 debug_assert!(count <= since.1, "exceeds upper-bound");
590
591 max_edu_count.fetch_max(count, Ordering::Relaxed);
592 if !self.services.globals.user_is_local(user_id) {
593 continue;
594 }
595
596 let Ok(event) = serde_json::from_str(read_receipt.json().get()) else {
597 error!(?user_id, ?count, ?read_receipt, "Invalid edu event in read_receipts.");
598 continue;
599 };
600
601 let AnySyncEphemeralRoomEvent::Receipt(r) = event else {
602 error!(?user_id, ?count, ?event, "Invalid event type in read_receipts");
603 continue;
604 };
605
606 let (event_id, mut receipt) = r
607 .content
608 .0
609 .into_iter()
610 .next()
611 .expect("we only use one event per read receipt");
612
613 let receipt = receipt
614 .remove(&ReceiptType::Read)
615 .expect("our read receipts always set this")
616 .remove(user_id)
617 .expect("our read receipts always have the user here");
618
619 let receipt_data = ReceiptData { data: receipt, event_ids: vec![event_id] };
620
621 match by_user.entry(user_id.to_owned()) {
622 | Entry::Vacant(slot) => {
623 slot.insert(SmallVec::from_buf([receipt_data]));
624 let num = num.fetch_add(1, Ordering::Relaxed);
625 if num >= SELECT_RECEIPT_LIMIT {
626 break;
627 }
628 },
629 | Entry::Occupied(mut slot) => {
630 slot.get_mut().push(receipt_data);
631 },
632 }
633 }
634
635 by_user
639 .into_iter()
640 .fold(RankedReceipts::new(), |mut acc, (user_id, receipts)| {
641 for (rank, receipt_data) in receipts.into_iter().enumerate() {
642 if rank >= acc.len() {
643 acc.push(ReceiptMap { read: BTreeMap::new() });
644 }
645
646 acc[rank]
647 .read
648 .insert(user_id.clone(), receipt_data);
649 }
650
651 acc
652 })
653 }
654
655 #[tracing::instrument(
657 name = "presence",
658 level = "trace",
659 skip(self, server_name, max_edu_count)
660 )]
661 async fn select_edus_presence(
662 &self,
663 server_name: &ServerName,
664 since: (u64, u64),
665 max_edu_count: &AtomicU64,
666 ) -> Option<EduBuf> {
667 let presence_since = self
668 .services
669 .presence
670 .presence_since(since.0, Some(since.1));
671
672 pin_mut!(presence_since);
673 let mut presence_updates = HashMap::<OwnedUserId, PresenceUpdate>::new();
674 while let Some((user_id, count, presence_bytes)) = presence_since.next().await {
675 debug_assert!(count <= since.1, "exceeded upper-bound");
676
677 max_edu_count.fetch_max(count, Ordering::Relaxed);
678 if !self.services.globals.user_is_local(user_id) {
679 continue;
680 }
681
682 if !self
683 .services
684 .state_cache
685 .server_sees_user(server_name, user_id)
686 .await
687 {
688 continue;
689 }
690
691 let Ok(presence_event) = self
692 .services
693 .presence
694 .from_json_bytes_to_event(presence_bytes, user_id)
695 .await
696 .log_err()
697 else {
698 continue;
699 };
700
701 let update = PresenceUpdate {
702 user_id: user_id.into(),
703 presence: presence_event.content.presence,
704 currently_active: presence_event
705 .content
706 .currently_active
707 .unwrap_or(false),
708 status_msg: presence_event.content.status_msg,
709 last_active_ago: presence_event
710 .content
711 .last_active_ago
712 .unwrap_or_else(|| uint!(0)),
713 };
714
715 presence_updates.insert(user_id.into(), update);
716 if presence_updates.len() >= SELECT_PRESENCE_LIMIT {
717 break;
718 }
719 }
720
721 if presence_updates.is_empty() {
722 return None;
723 }
724
725 let presence_content = Edu::Presence(PresenceContent {
726 push: presence_updates.into_values().collect(),
727 });
728
729 let mut buf = EduBuf::new();
730 serde_json::to_writer(&mut buf, &presence_content)
731 .expect("failed to serialize Presence EDU to JSON");
732
733 Some(buf)
734 }
735
736 fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingFuture<'_> {
737 debug_assert!(!events.is_empty(), "sending empty transaction");
738 match dest {
739 | Destination::Federation(server) => self
740 .send_events_dest_federation(server, events)
741 .boxed(),
742 | Destination::Appservice(id) => self
743 .send_events_dest_appservice(id, events)
744 .boxed(),
745 | Destination::Push(user_id, pushkey) => self
746 .send_events_dest_push(user_id, pushkey, events)
747 .boxed(),
748 }
749 }
750
751 #[tracing::instrument(
752 name = "appservice",
753 level = "debug",
754 skip(self, events),
755 fields(
756 events = %events.len(),
757 ),
758 )]
759 async fn send_events_dest_appservice(
760 &self,
761 id: String,
762 events: Vec<SendingEvent>,
763 ) -> SendingResult {
764 let Some(appservice) = self
765 .services
766 .appservice
767 .get_registration(&id)
768 .await
769 else {
770 return Err((
771 Destination::Appservice(id.clone()),
772 err!(Database(warn!(?id, "Missing appservice registration"))),
773 ));
774 };
775
776 let mut pdu_jsons = Vec::with_capacity(
777 events
778 .iter()
779 .filter(|event| matches!(event, SendingEvent::Pdu(_)))
780 .count(),
781 );
782 let mut edu_jsons: Vec<Raw<EphemeralData>> = Vec::with_capacity(
783 events
784 .iter()
785 .filter(|event| matches!(event, SendingEvent::Edu(_)))
786 .count(),
787 );
788 for event in &events {
789 match event {
790 | SendingEvent::Pdu(pdu_id) => {
791 if let Ok(pdu) = self
792 .services
793 .timeline
794 .get_pdu_from_id(pdu_id)
795 .await
796 {
797 pdu_jsons.push(pdu.to_format());
798 }
799 },
800 | SendingEvent::Edu(edu) => {
801 if appservice.receive_ephemeral
802 && let Ok(edu) =
803 serde_json::from_slice(edu).and_then(|edu| Raw::new(&edu))
804 {
805 edu_jsons.push(edu);
806 }
807 },
808 | SendingEvent::Flush => {}, }
810 }
811
812 let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
813 | SendingEvent::Edu(b) => Some(b.as_ref()),
814 | SendingEvent::Pdu(b) => Some(b.as_ref()),
815 | SendingEvent::Flush => None,
816 }));
817
818 let txn_id = &*URL_SAFE_NO_PAD.encode(txn_hash);
819
820 match self
823 .services
824 .appservice
825 .send_request(appservice, ruma::api::appservice::event::push_events::v1::Request {
826 txn_id: txn_id.into(),
827 events: pdu_jsons,
828 ephemeral: edu_jsons,
829 to_device: Vec::new(), })
831 .await
832 {
833 | Ok(_) => Ok(Destination::Appservice(id)),
834 | Err(e) => Err((Destination::Appservice(id), e)),
835 }
836 }
837
838 #[tracing::instrument(
839 name = "push",
840 level = "info",
841 skip(self, events),
842 fields(
843 events = %events.len(),
844 ),
845 )]
846 async fn send_events_dest_push(
847 &self,
848 user_id: OwnedUserId,
849 pushkey: String,
850 events: Vec<SendingEvent>,
851 ) -> SendingResult {
852 let suppressed = self.pushing_suppressed(&user_id).map(Ok);
853
854 let pusher = self
855 .services
856 .pusher
857 .get_pusher(&user_id, &pushkey)
858 .map_err(|_| {
859 (
860 Destination::Push(user_id.clone(), pushkey.clone()),
861 err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))),
862 )
863 });
864
865 let rules_for_user = self
866 .services
867 .account_data
868 .get_global::<PushRulesEvent>(&user_id, GlobalAccountDataEventType::PushRules)
869 .map(|ev| {
870 ev.map_or_else(
871 |_| push::Ruleset::server_default(&user_id),
872 |ev| ev.content.global,
873 )
874 })
875 .map(Ok);
876
877 let (pusher, rules_for_user, suppressed) =
878 try_join3(pusher, rules_for_user, suppressed).await?;
879
880 if suppressed {
881 let queued = self
882 .enqueue_suppressed_push_events(&user_id, &pushkey, &events)
883 .await;
884 debug!(
885 ?user_id,
886 pushkey,
887 queued,
888 events = events.len(),
889 "Push suppressed; queued events"
890 );
891 return Ok(Destination::Push(user_id, pushkey));
892 }
893
894 self.schedule_flush_suppressed_for_pushkey(
895 user_id.clone(),
896 pushkey.clone(),
897 "non-suppressed push",
898 );
899
900 let _sent = events
901 .iter()
902 .stream()
903 .ready_filter_map(|event| extract_variant!(event, SendingEvent::Pdu))
904 .wide_filter_map(|pdu_id| {
905 self.services
906 .timeline
907 .get_pdu_from_id(pdu_id)
908 .ok()
909 })
910 .ready_filter(|pdu| !pdu.is_redacted())
911 .wide_filter_map(async |pdu| {
912 self.services
913 .pusher
914 .send_push_notice(&user_id, &pusher, &rules_for_user, &pdu)
915 .await
916 .map_err(|e| (Destination::Push(user_id.clone(), pushkey.clone()), e))
917 .ok()
918 })
919 .count()
920 .await;
921
922 Ok(Destination::Push(user_id, pushkey))
923 }
924
925 pub fn schedule_flush_suppressed_for_pushkey(
926 &self,
927 user_id: OwnedUserId,
928 pushkey: String,
929 reason: &'static str,
930 ) {
931 let sending = self.services.sending.clone();
932 let runtime = self.server.runtime();
933 runtime.spawn(async move {
934 sending
935 .flush_suppressed_for_pushkey(user_id, pushkey, reason)
936 .await;
937 });
938 }
939
940 pub fn schedule_flush_suppressed_for_user(&self, user_id: OwnedUserId, reason: &'static str) {
941 let sending = self.services.sending.clone();
942 let runtime = self.server.runtime();
943 runtime.spawn(async move {
944 sending
945 .flush_suppressed_for_user(user_id, reason)
946 .await;
947 });
948 }
949
950 async fn enqueue_suppressed_push_events(
951 &self,
952 user_id: &UserId,
953 pushkey: &str,
954 events: &[SendingEvent],
955 ) -> usize {
956 let mut queued = 0_usize;
957 for event in events {
958 let SendingEvent::Pdu(pdu_id) = event else {
959 continue;
960 };
961
962 let Ok(pdu) = self
963 .services
964 .timeline
965 .get_pdu_from_id(pdu_id)
966 .await
967 else {
968 debug!(?user_id, ?pdu_id, "Suppressing push but PDU is missing");
969 continue;
970 };
971
972 if pdu.is_redacted() {
973 trace!(?user_id, ?pdu_id, "Suppressing push for redacted PDU");
974 continue;
975 }
976
977 if self.services.pusher.queue_suppressed_push(
978 user_id,
979 pushkey,
980 pdu.room_id(),
981 *pdu_id,
982 ) {
983 queued = queued.saturating_add(1);
984 }
985 }
986
987 queued
988 }
989
990 async fn flush_suppressed_rooms(
991 &self,
992 user_id: &UserId,
993 pushkey: &str,
994 pusher: &ruma::api::client::push::Pusher,
995 rules_for_user: &push::Ruleset,
996 rooms: Vec<(OwnedRoomId, Vec<RawPduId>)>,
997 reason: &'static str,
998 ) {
999 if rooms.is_empty() {
1000 return;
1001 }
1002
1003 let mut sent = 0_usize;
1004 debug!(?user_id, pushkey, rooms = rooms.len(), "Flushing suppressed pushes ({reason})");
1005
1006 for (room_id, pdu_ids) in rooms {
1007 let unread = self
1008 .services
1009 .pusher
1010 .notification_count(user_id, &room_id)
1011 .await;
1012 if unread == 0 {
1013 trace!(?user_id, ?room_id, "Skipping suppressed push flush: no unread");
1014 continue;
1015 }
1016
1017 for pdu_id in pdu_ids {
1018 let Ok(pdu) = self
1019 .services
1020 .timeline
1021 .get_pdu_from_id(&pdu_id)
1022 .await
1023 else {
1024 debug!(?user_id, ?pdu_id, "Suppressed PDU missing during flush");
1025 continue;
1026 };
1027
1028 if pdu.is_redacted() {
1029 trace!(?user_id, ?pdu_id, "Suppressed PDU redacted during flush");
1030 continue;
1031 }
1032
1033 if let Err(error) = self
1034 .services
1035 .pusher
1036 .send_push_notice(user_id, pusher, rules_for_user, &pdu)
1037 .await
1038 {
1039 let requeued = self
1040 .services
1041 .pusher
1042 .queue_suppressed_push(user_id, pushkey, &room_id, pdu_id);
1043 warn!(
1044 ?user_id,
1045 ?room_id,
1046 ?error,
1047 requeued,
1048 "Failed to send suppressed push notification"
1049 );
1050 } else {
1051 sent = sent.saturating_add(1);
1052 }
1053 }
1054 }
1055
1056 debug!(?user_id, pushkey, sent, "Flushed suppressed push notifications");
1057 }
1058
1059 async fn flush_suppressed_for_pushkey(
1060 &self,
1061 user_id: OwnedUserId,
1062 pushkey: String,
1063 reason: &'static str,
1064 ) {
1065 let suppressed = self
1066 .services
1067 .pusher
1068 .take_suppressed_for_pushkey(&user_id, &pushkey);
1069 if suppressed.is_empty() {
1070 return;
1071 }
1072
1073 let pusher = match self
1074 .services
1075 .pusher
1076 .get_pusher(&user_id, &pushkey)
1077 .await
1078 {
1079 | Ok(pusher) => pusher,
1080 | Err(error) => {
1081 warn!(?user_id, pushkey, ?error, "Missing pusher for suppressed flush");
1082 return;
1083 },
1084 };
1085
1086 let rules_for_user = match self
1087 .services
1088 .account_data
1089 .get_global::<PushRulesEvent>(&user_id, GlobalAccountDataEventType::PushRules)
1090 .await
1091 {
1092 | Ok(ev) => ev.content.global,
1093 | Err(_) => push::Ruleset::server_default(&user_id),
1094 };
1095
1096 self.flush_suppressed_rooms(
1097 &user_id,
1098 &pushkey,
1099 &pusher,
1100 &rules_for_user,
1101 suppressed,
1102 reason,
1103 )
1104 .await;
1105 }
1106
1107 pub async fn flush_suppressed_for_user(&self, user_id: OwnedUserId, reason: &'static str) {
1108 let suppressed = self
1109 .services
1110 .pusher
1111 .take_suppressed_for_user(&user_id);
1112 if suppressed.is_empty() {
1113 return;
1114 }
1115
1116 let rules_for_user = match self
1117 .services
1118 .account_data
1119 .get_global::<PushRulesEvent>(&user_id, GlobalAccountDataEventType::PushRules)
1120 .await
1121 {
1122 | Ok(ev) => ev.content.global,
1123 | Err(_) => push::Ruleset::server_default(&user_id),
1124 };
1125
1126 for (pushkey, rooms) in suppressed {
1127 let pusher = match self
1128 .services
1129 .pusher
1130 .get_pusher(&user_id, &pushkey)
1131 .await
1132 {
1133 | Ok(pusher) => pusher,
1134 | Err(error) => {
1135 warn!(?user_id, pushkey, ?error, "Missing pusher for suppressed flush");
1136 continue;
1137 },
1138 };
1139
1140 self.flush_suppressed_rooms(
1141 &user_id,
1142 &pushkey,
1143 &pusher,
1144 &rules_for_user,
1145 rooms,
1146 reason,
1147 )
1148 .await;
1149 }
1150 }
1151
1152 async fn pushing_suppressed(&self, user_id: &UserId) -> bool {
1155 if !self.services.config.suppress_push_when_active {
1156 debug!(?user_id, "push not suppressed: suppress_push_when_active disabled");
1157 return false;
1158 }
1159
1160 let Ok(presence) = self.services.presence.get_presence(user_id).await else {
1161 debug!(?user_id, "push not suppressed: presence unavailable");
1162 return false;
1163 };
1164
1165 if presence.content.presence != PresenceState::Online {
1166 debug!(
1167 ?user_id,
1168 presence = ?presence.content.presence,
1169 "push not suppressed: presence not online"
1170 );
1171 return false;
1172 }
1173
1174 let presence_age_ms = presence
1175 .content
1176 .last_active_ago
1177 .map(u64::from)
1178 .unwrap_or(u64::MAX);
1179
1180 if presence_age_ms >= 65_000 {
1181 debug!(?user_id, presence_age_ms, "push not suppressed: presence too old");
1182 return false;
1183 }
1184
1185 let sync_gap_ms = self
1186 .services
1187 .presence
1188 .last_sync_gap_ms(user_id)
1189 .await;
1190
1191 let considered_active = sync_gap_ms.is_some_and(|gap| gap < 32_000);
1192
1193 match sync_gap_ms {
1194 | Some(gap) if gap < 32_000 => debug!(
1195 ?user_id,
1196 presence_age_ms,
1197 sync_gap_ms = gap,
1198 "suppressing push: active heuristic"
1199 ),
1200 | Some(gap) => debug!(
1201 ?user_id,
1202 presence_age_ms,
1203 sync_gap_ms = gap,
1204 "push not suppressed: sync gap too large"
1205 ),
1206 | None => debug!(?user_id, presence_age_ms, "push not suppressed: no recent sync"),
1207 }
1208
1209 considered_active
1210 }
1211
1212 async fn send_events_dest_federation(
1213 &self,
1214 server: OwnedServerName,
1215 events: Vec<SendingEvent>,
1216 ) -> SendingResult {
1217 let pdus: Vec<_> = events
1218 .iter()
1219 .filter_map(|pdu| match pdu {
1220 | SendingEvent::Pdu(pdu) => Some(pdu),
1221 | _ => None,
1222 })
1223 .stream()
1224 .wide_filter_map(|pdu_id| {
1225 self.services
1226 .timeline
1227 .get_pdu_json_from_id(pdu_id)
1228 .ok()
1229 })
1230 .wide_then(|pdu| {
1231 self.services
1232 .federation
1233 .format_pdu_into(pdu, None)
1234 })
1235 .collect()
1236 .await;
1237
1238 let edus: Vec<Raw<Edu>> = events
1239 .iter()
1240 .filter_map(|edu| match edu {
1241 | SendingEvent::Edu(edu) => Some(edu.as_ref()),
1242 | _ => None,
1243 })
1244 .map(serde_json::from_slice)
1245 .filter_map(Result::ok)
1246 .collect();
1247
1248 if pdus.is_empty() && edus.is_empty() {
1249 return Ok(Destination::Federation(server));
1250 }
1251
1252 let preimage = pdus
1253 .iter()
1254 .map(|raw| raw.get().as_bytes())
1255 .chain(edus.iter().map(|raw| raw.json().get().as_bytes()));
1256
1257 let txn_hash = calculate_hash(preimage);
1258 let txn_id = &*URL_SAFE_NO_PAD.encode(txn_hash);
1259 let request = send_transaction_message::v1::Request {
1260 transaction_id: txn_id.into(),
1261 origin: self.server.name.clone(),
1262 origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
1263 pdus,
1264 edus,
1265 };
1266
1267 let result = self
1268 .services
1269 .federation
1270 .execute_on(&self.services.client.sender, &server, request)
1271 .await;
1272
1273 for (event_id, result) in result.iter().flat_map(|resp| resp.pdus.iter()) {
1274 if let Err(e) = result {
1275 warn!(
1276 %txn_id, %server,
1277 "error sending PDU {event_id} to remote server: {e:?}"
1278 );
1279 }
1280 }
1281
1282 match result {
1283 | Err(error) => Err((Destination::Federation(server), error)),
1284 | Ok(_) => Ok(Destination::Federation(server)),
1285 }
1286 }
1287}