tuwunel_service/rooms/event_handler/
fetch_prev.rs1use std::{collections::HashMap, iter::once};
2
3use futures::{FutureExt, StreamExt, stream::FuturesOrdered};
4use ruma::{
5 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
6 RoomVersionId, ServerName, int, uint,
7};
8use tuwunel_core::{
9 Result, debug_warn, err, implement,
10 matrix::{
11 Event, PduEvent,
12 pdu::{MAX_PREV_EVENTS, check_room_id},
13 },
14 utils::stream::IterStream,
15};
16
17use crate::rooms::state_res;
18
19#[implement(super::Service)]
20#[tracing::instrument(
21 level = "debug",
22 skip_all,
23 fields(
24 %origin,
25 events = %initial_set.clone().count(),
26 ),
27)]
28#[expect(clippy::type_complexity)]
29pub(super) async fn fetch_prev<'a, Events>(
30 &self,
31 origin: &ServerName,
32 room_id: &RoomId,
33 initial_set: Events,
34 room_version: &RoomVersionId,
35 recursion_level: usize,
36 first_ts_in_room: MilliSecondsSinceUnixEpoch,
37) -> Result<(Vec<OwnedEventId>, HashMap<OwnedEventId, (PduEvent, CanonicalJsonObject)>)>
38where
39 Events: Iterator<Item = &'a EventId> + Clone + Send,
40{
41 let mut todo_outlier_stack: FuturesOrdered<_> = initial_set
42 .stream()
43 .map(ToOwned::to_owned)
44 .filter_map(async |event_id| {
45 self.services
46 .timeline
47 .non_outlier_pdu_exists(&event_id)
48 .await
49 .is_err()
50 .then_some(event_id)
51 })
52 .map(async |event_id| {
53 let events = once(event_id.as_ref());
54 let auth = self
55 .fetch_auth(origin, room_id, events, room_version, recursion_level)
56 .await;
57
58 (event_id, auth)
59 })
60 .map(FutureExt::boxed)
61 .collect()
62 .await;
63
64 let mut amount = 0;
65 let mut eventid_info = HashMap::new();
66 let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(todo_outlier_stack.len());
67 while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
68 let Some((pdu, mut json_opt)) = outlier.pop() else {
69 graph.insert(prev_event_id.clone(), Default::default());
71 continue;
72 };
73
74 check_room_id(&pdu, room_id)?;
75
76 let limit = self.services.server.config.max_fetch_prev_events;
77 if amount > limit {
78 debug_warn!(?limit, "Max prev event limit reached!");
79 graph.insert(prev_event_id.clone(), Default::default());
80 continue;
81 }
82
83 if json_opt.is_none() {
84 json_opt = self
85 .services
86 .timeline
87 .get_outlier_pdu_json(&prev_event_id)
88 .await
89 .ok();
90 }
91
92 let Some(json) = json_opt else {
93 graph.insert(prev_event_id.clone(), Default::default());
95 continue;
96 };
97
98 if pdu.origin_server_ts() > first_ts_in_room {
99 amount = amount.saturating_add(1);
100 debug_assert!(
101 pdu.prev_events().count() <= MAX_PREV_EVENTS,
102 "PduEvent {prev_event_id} has too many prev_events"
103 );
104
105 for prev_prev in pdu.prev_events() {
106 if graph.contains_key(prev_prev) {
107 continue;
108 }
109
110 let prev_prev = prev_prev.to_owned();
111 let fetch = async move {
112 let fetch = self
113 .fetch_auth(
114 origin,
115 room_id,
116 once(prev_prev.as_ref()),
117 room_version,
118 recursion_level,
119 )
120 .await;
121
122 (prev_prev, fetch)
123 };
124
125 todo_outlier_stack.push_back(fetch.boxed());
126 }
127
128 graph.insert(
129 prev_event_id.clone(),
130 pdu.prev_events().map(ToOwned::to_owned).collect(),
131 );
132 } else {
133 graph.insert(prev_event_id.clone(), Default::default());
135 }
136
137 eventid_info.insert(prev_event_id.clone(), (pdu, json));
138 self.services.server.check_running()?;
139 }
140
141 let event_fetch = async |event_id: OwnedEventId| {
142 let origin_server_ts = eventid_info
143 .get(&event_id)
144 .map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
145
146 Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(origin_server_ts)))
150 };
151
152 let sorted = state_res::topological_sort(&graph, &event_fetch)
153 .await
154 .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;
155
156 debug_assert_eq!(
157 sorted.len(),
158 graph.len(),
159 "topological sort returned a different number of outputs than inputs"
160 );
161
162 debug_assert!(
163 sorted.len() >= eventid_info.len(),
164 "returned topologically sorted events differ from eventid_info"
165 );
166
167 Ok((sorted, eventid_info))
168}