// tuwunel_service/rooms/event_handler/mod.rs

1mod acl_check;
2mod fetch_auth;
3mod fetch_prev;
4mod fetch_state;
5mod handle_incoming_pdu;
6mod handle_outlier_pdu;
7mod handle_prev_pdu;
8mod parse_incoming_pdu;
9mod policy_server;
10mod resolve_state;
11mod state_at_incoming;
12mod upgrade_outlier_pdu;
13
14use std::{
15	collections::{HashMap, hash_map},
16	fmt::Write,
17	ops::Range,
18	sync::{Arc, RwLock},
19	time::{Duration, Instant},
20};
21
22use async_trait::async_trait;
23use ruma::{EventId, OwnedEventId, OwnedRoomId};
24use tuwunel_core::{
25	Result, implement,
26	matrix::PduEvent,
27	utils::{MutexMap, bytes::pretty, continue_exponential_backoff},
28};
29use tuwunel_database::Map;
30
31pub struct Service {
32	pub mutex_federation: RoomMutexMap,
33	services: Arc<crate::services::OnceServices>,
34	bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
35	db: Data,
36}
37
38struct Data {
39	eventid_policysigstate: Arc<Map>,
40}
41
42type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
43type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
44
45#[async_trait]
46impl crate::Service for Service {
47	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
48		Ok(Arc::new(Self {
49			mutex_federation: RoomMutexMap::new(),
50			services: args.services.clone(),
51			bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
52			db: Data {
53				eventid_policysigstate: args.db["eventid_policysigstate"].clone(),
54			},
55		}))
56	}
57
58	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
59		let mutex_federation = self.mutex_federation.len();
60		writeln!(out, "federation_mutex: {mutex_federation}")?;
61
62		let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold(
63			(0_usize, 0_usize),
64			|(mut count, mut bytes), (event_id, _)| {
65				bytes = bytes.saturating_add(event_id.capacity());
66				bytes = bytes.saturating_add(size_of::<RateLimitState>());
67				count = count.saturating_add(1);
68				(count, bytes)
69			},
70		);
71
72		writeln!(out, "bad_event_ratelimiter: {ber_count} ({})", pretty(ber_bytes))?;
73
74		Ok(())
75	}
76
77	async fn clear_cache(&self) {
78		self.bad_event_ratelimiter
79			.write()
80			.expect("locked for writing")
81			.clear();
82	}
83
84	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
85}
86
87#[implement(Service)]
88fn cancel_back_off(&self, event_id: &EventId) -> bool {
89	self.bad_event_ratelimiter
90		.write()
91		.expect("locked")
92		.remove(event_id)
93		.is_some()
94}
95
96#[implement(Service)]
97fn back_off(&self, event_id: &EventId) -> bool {
98	use hash_map::Entry::{Occupied, Vacant};
99
100	match self
101		.bad_event_ratelimiter
102		.write()
103		.expect("locked")
104		.entry(event_id.into())
105	{
106		| Vacant(e) => {
107			e.insert((Instant::now(), 1));
108			true
109		},
110		| Occupied(mut e) => {
111			*e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
112			false
113		},
114	}
115}
116
117#[implement(Service)]
118fn is_backed_off(&self, event_id: &EventId, range: Range<Duration>) -> bool {
119	let Some((time, tries)) = self
120		.bad_event_ratelimiter
121		.read()
122		.expect("locked")
123		.get(event_id)
124		.copied()
125	else {
126		return false;
127	};
128
129	if !continue_exponential_backoff(range.start, range.end, time.elapsed(), tries) {
130		return false;
131	}
132
133	true
134}
135
136#[implement(Service)]
137#[tracing::instrument(
138	name = "exists",
139	level = "trace",
140	ret(level = "trace"),
141	skip_all,
142	fields(%event_id)
143)]
144async fn event_exists(&self, event_id: &EventId) -> bool {
145	self.services.timeline.pdu_exists(event_id).await
146}
147
148#[implement(Service)]
149#[tracing::instrument(
150	name = "fetch",
151	level = "trace",
152	skip_all,
153	fields(%event_id)
154)]
155async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
156	self.services.timeline.get_pdu(event_id).await
157}