tuwunel_service/rooms/event_handler/
mod.rs1mod acl_check;
2mod fetch_auth;
3mod fetch_prev;
4mod fetch_state;
5mod handle_incoming_pdu;
6mod handle_outlier_pdu;
7mod handle_prev_pdu;
8mod parse_incoming_pdu;
9mod policy_server;
10mod resolve_state;
11mod state_at_incoming;
12mod upgrade_outlier_pdu;
13
14use std::{
15 collections::{HashMap, hash_map},
16 fmt::Write,
17 ops::Range,
18 sync::{Arc, RwLock},
19 time::{Duration, Instant},
20};
21
22use async_trait::async_trait;
23use ruma::{EventId, OwnedEventId, OwnedRoomId};
24use tuwunel_core::{
25 Result, implement,
26 matrix::PduEvent,
27 utils::{MutexMap, bytes::pretty, continue_exponential_backoff},
28};
29use tuwunel_database::Map;
30
31pub struct Service {
32 pub mutex_federation: RoomMutexMap,
33 services: Arc<crate::services::OnceServices>,
34 bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
35 db: Data,
36}
37
38struct Data {
39 eventid_policysigstate: Arc<Map>,
40}
41
42type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
43type RateLimitState = (Instant, u32); #[async_trait]
46impl crate::Service for Service {
47 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
48 Ok(Arc::new(Self {
49 mutex_federation: RoomMutexMap::new(),
50 services: args.services.clone(),
51 bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
52 db: Data {
53 eventid_policysigstate: args.db["eventid_policysigstate"].clone(),
54 },
55 }))
56 }
57
58 async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
59 let mutex_federation = self.mutex_federation.len();
60 writeln!(out, "federation_mutex: {mutex_federation}")?;
61
62 let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold(
63 (0_usize, 0_usize),
64 |(mut count, mut bytes), (event_id, _)| {
65 bytes = bytes.saturating_add(event_id.capacity());
66 bytes = bytes.saturating_add(size_of::<RateLimitState>());
67 count = count.saturating_add(1);
68 (count, bytes)
69 },
70 );
71
72 writeln!(out, "bad_event_ratelimiter: {ber_count} ({})", pretty(ber_bytes))?;
73
74 Ok(())
75 }
76
77 async fn clear_cache(&self) {
78 self.bad_event_ratelimiter
79 .write()
80 .expect("locked for writing")
81 .clear();
82 }
83
84 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
85}
86
87#[implement(Service)]
88fn cancel_back_off(&self, event_id: &EventId) -> bool {
89 self.bad_event_ratelimiter
90 .write()
91 .expect("locked")
92 .remove(event_id)
93 .is_some()
94}
95
96#[implement(Service)]
97fn back_off(&self, event_id: &EventId) -> bool {
98 use hash_map::Entry::{Occupied, Vacant};
99
100 match self
101 .bad_event_ratelimiter
102 .write()
103 .expect("locked")
104 .entry(event_id.into())
105 {
106 | Vacant(e) => {
107 e.insert((Instant::now(), 1));
108 true
109 },
110 | Occupied(mut e) => {
111 *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
112 false
113 },
114 }
115}
116
117#[implement(Service)]
118fn is_backed_off(&self, event_id: &EventId, range: Range<Duration>) -> bool {
119 let Some((time, tries)) = self
120 .bad_event_ratelimiter
121 .read()
122 .expect("locked")
123 .get(event_id)
124 .copied()
125 else {
126 return false;
127 };
128
129 if !continue_exponential_backoff(range.start, range.end, time.elapsed(), tries) {
130 return false;
131 }
132
133 true
134}
135
136#[implement(Service)]
137#[tracing::instrument(
138 name = "exists",
139 level = "trace",
140 ret(level = "trace"),
141 skip_all,
142 fields(%event_id)
143)]
144async fn event_exists(&self, event_id: &EventId) -> bool {
145 self.services.timeline.pdu_exists(event_id).await
146}
147
148#[implement(Service)]
149#[tracing::instrument(
150 name = "fetch",
151 level = "trace",
152 skip_all,
153 fields(%event_id)
154)]
155async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
156 self.services.timeline.get_pdu(event_id).await
157}