tuwunel_service/rooms/event_handler/
fetch_prev.rs1use std::{collections::HashMap, iter::once, time::Duration};
2
3use futures::{
4 FutureExt, StreamExt,
5 stream::{FuturesOrdered, FuturesUnordered},
6};
7use ruma::{
8 CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
9 RoomId, RoomVersionId, ServerName, int, uint,
10};
11use serde_json::value::RawValue as RawJsonValue;
12use tokio::time::{Instant, timeout_at};
13use tuwunel_core::{
14 Result, debug_warn, err, implement,
15 matrix::{
16 Event, PduEvent,
17 event::gen_event_id,
18 pdu::{MAX_PREV_EVENTS, check_room_id},
19 },
20 utils::{
21 BoolExt,
22 stream::{IterStream, automatic_width},
23 },
24};
25
26use crate::{
27 fetcher::{EventWindow, Op, Opts},
28 rooms::state_res::topological_sort,
29};
30
31#[implement(super::Service)]
32#[tracing::instrument(
33 level = "debug",
34 skip_all,
35 fields(
36 %origin,
37 events = %initial_set.clone().count(),
38 ),
39)]
40#[expect(clippy::type_complexity, clippy::too_many_arguments)]
41pub(super) async fn fetch_prev<'a, Events>(
42 &self,
43 origin: &ServerName,
44 room_id: &RoomId,
45 incoming_event_id: &EventId,
46 initial_set: Events,
47 room_version: &RoomVersionId,
48 recursion_level: usize,
49 first_ts_in_room: MilliSecondsSinceUnixEpoch,
50) -> Result<(Vec<OwnedEventId>, HashMap<OwnedEventId, (PduEvent, CanonicalJsonObject)>)>
51where
52 Events: Iterator<Item = &'a EventId> + Clone + Send,
53{
54 let has_gap = initial_set
55 .clone()
56 .stream()
57 .any(async |event_id| !self.services.timeline.pdu_exists(event_id).await)
58 .await;
59
60 let wait_ms = self.services.server.config.fetch_prev_wait_ms;
61 let has_gap = (has_gap && wait_ms > 0)
62 .then_async(|| self.await_prev_gap(initial_set.clone(), Duration::from_millis(wait_ms)))
63 .await
64 .unwrap_or(has_gap);
65
66 has_gap
67 .then_async(|| {
68 self.prefetch_missing_events(
69 origin,
70 room_id,
71 incoming_event_id,
72 room_version,
73 recursion_level,
74 )
75 })
76 .await;
77
78 let mut todo_outlier_stack: FuturesOrdered<_> = initial_set
79 .stream()
80 .map(ToOwned::to_owned)
81 .filter_map(async |event_id| {
82 self.services
83 .timeline
84 .non_outlier_pdu_exists(&event_id)
85 .await
86 .is_err()
87 .then_some(event_id)
88 })
89 .map(async |event_id| {
90 let events = once(event_id.as_ref());
91 let auth = self
92 .fetch_auth(origin, room_id, events, room_version, recursion_level)
93 .await;
94
95 (event_id, auth)
96 })
97 .map(FutureExt::boxed)
98 .collect()
99 .await;
100
101 let mut amount = 0;
102 let mut eventid_info = HashMap::new();
103 let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(todo_outlier_stack.len());
104 while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
105 self.services.server.check_running()?;
106
107 let Some((pdu, mut json_opt)) = outlier.pop() else {
108 graph.insert(prev_event_id.clone(), Default::default());
110 continue;
111 };
112
113 check_room_id(&pdu, room_id)?;
114
115 let limit = self.services.server.config.max_fetch_prev_events;
116 if amount > limit {
117 debug_warn!(?limit, "Max prev event limit reached!");
118 graph.insert(prev_event_id.clone(), Default::default());
119 continue;
120 }
121
122 if json_opt.is_none() {
123 json_opt = self
124 .services
125 .timeline
126 .get_outlier_pdu_json(&prev_event_id)
127 .await
128 .ok();
129 }
130
131 let Some(json) = json_opt else {
132 graph.insert(prev_event_id.clone(), Default::default());
134 continue;
135 };
136
137 if pdu.origin_server_ts() > first_ts_in_room {
138 amount = amount.saturating_add(1);
139 debug_assert!(
140 pdu.prev_events().count() <= MAX_PREV_EVENTS,
141 "PduEvent {prev_event_id} has too many prev_events"
142 );
143
144 for prev_prev in pdu.prev_events() {
145 if graph.contains_key(prev_prev) {
146 continue;
147 }
148
149 let prev_prev = prev_prev.to_owned();
150 let fetch = async move {
151 let fetch = self
152 .fetch_auth(
153 origin,
154 room_id,
155 once(prev_prev.as_ref()),
156 room_version,
157 recursion_level,
158 )
159 .await;
160
161 (prev_prev, fetch)
162 };
163
164 todo_outlier_stack.push_back(fetch.boxed());
165 }
166
167 graph.insert(
168 prev_event_id.clone(),
169 pdu.prev_events().map(ToOwned::to_owned).collect(),
170 );
171 } else {
172 graph.insert(prev_event_id.clone(), Default::default());
174 }
175
176 eventid_info.insert(prev_event_id.clone(), (pdu, json));
177 }
178
179 let event_fetch = async |event_id: OwnedEventId| {
180 let origin_server_ts = eventid_info
181 .get(&event_id)
182 .map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
183
184 Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(origin_server_ts)))
188 };
189
190 let graph_len = graph.len();
191 let sorted = topological_sort(graph, &event_fetch)
192 .await
193 .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;
194
195 debug_assert_eq!(
196 sorted.len(),
197 graph_len,
198 "topological sort returned a different number of outputs than inputs"
199 );
200
201 debug_assert!(
202 sorted.len() >= eventid_info.len(),
203 "returned topologically sorted events differ from eventid_info"
204 );
205
206 Ok((sorted, eventid_info))
207}
208
209#[implement(super::Service)]
210async fn await_prev_gap<'a, Events>(&self, initial_set: Events, wait: Duration) -> bool
211where
212 Events: Iterator<Item = &'a EventId> + Send,
213{
214 let deadline = Instant::now()
215 .checked_add(wait)
216 .expect("wait deadline overflows");
217
218 let pending: FuturesUnordered<_> = initial_set
221 .map(|event_id| (event_id, self.services.timeline.watch_event(event_id)))
222 .stream()
223 .filter_map(async |(event_id, watcher)| {
224 (!self.services.timeline.pdu_exists(event_id).await).then_some(watcher)
225 })
226 .collect()
227 .await;
228
229 if pending.is_empty() {
230 return false;
231 }
232
233 timeout_at(deadline, pending.count())
234 .await
235 .is_err()
236}
237
238#[implement(super::Service)]
245#[tracing::instrument(name = "missing", level = "debug", skip_all)]
246async fn prefetch_missing_events(
247 &self,
248 origin: &ServerName,
249 room_id: &RoomId,
250 incoming_event_id: &EventId,
251 room_version: &RoomVersionId,
252 recursion_level: usize,
253) {
254 let boundary: EventWindow = self
255 .services
256 .state
257 .get_forward_extremities(room_id)
258 .map(ToOwned::to_owned)
259 .collect()
260 .await;
261
262 let opts = Opts::new(Op::MissingEvents, room_id.to_owned())
263 .latest_events([incoming_event_id.to_owned()])
264 .earliest_events(boundary)
265 .hint(origin.to_owned())
266 .room_version(room_version.to_owned())
267 .attempt_limit(super::EVENT_FETCH_ATTEMPT_LIMIT)
268 .fanout_for_op();
269
270 let Ok(outcome) = self.services.fetcher.fetch(opts).await else {
271 return;
272 };
273
274 let Ok(events) = serde_json::from_slice::<Vec<Box<RawJsonValue>>>(&outcome.bytes) else {
275 return;
276 };
277
278 events
279 .into_iter()
280 .stream()
281 .for_each_concurrent(automatic_width(), async |pdu| {
282 self.land_missing_event(origin, room_id, &pdu, room_version, recursion_level)
283 .await
284 .ok();
285 })
286 .await;
287}
288
289#[implement(super::Service)]
292#[tracing::instrument(name = "land", level = "trace", skip_all)]
293async fn land_missing_event(
294 &self,
295 origin: &ServerName,
296 room_id: &RoomId,
297 pdu: &RawJsonValue,
298 room_version: &RoomVersionId,
299 recursion_level: usize,
300) -> Result {
301 let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
302 .map_err(|e| err!(BadServerResponse("missing-events pdu is not canonical json: {e}")))?;
303
304 value
305 .get("room_id")
306 .and_then(CanonicalJsonValue::as_str)
307 .is_some_and(|id| id == room_id.as_str())
308 .then_some(())
309 .ok_or_else(|| {
310 err!(Request(InvalidParam("missing-events pdu is for a different room")))
311 })?;
312
313 let event_id = gen_event_id(&value, room_version)?;
314
315 Box::pin(self.handle_outlier_pdu(
316 origin,
317 room_id,
318 &event_id,
319 value,
320 room_version,
321 recursion_level,
322 false,
323 ))
324 .await
325 .map(|_| ())
326}