Skip to main content

tuwunel_service/rooms/event_handler/
fetch_prev.rs

1use std::{collections::HashMap, iter::once, time::Duration};
2
3use futures::{
4	FutureExt, StreamExt,
5	stream::{FuturesOrdered, FuturesUnordered},
6};
7use ruma::{
8	CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
9	RoomId, RoomVersionId, ServerName, int, uint,
10};
11use serde_json::value::RawValue as RawJsonValue;
12use tokio::time::{Instant, timeout_at};
13use tuwunel_core::{
14	Result, debug_warn, err, implement,
15	matrix::{
16		Event, PduEvent,
17		event::gen_event_id,
18		pdu::{MAX_PREV_EVENTS, check_room_id},
19	},
20	utils::{
21		BoolExt,
22		stream::{IterStream, automatic_width},
23	},
24};
25
26use crate::{
27	fetcher::{EventWindow, Op, Opts},
28	rooms::state_res::topological_sort,
29};
30
31#[implement(super::Service)]
32#[tracing::instrument(
33	level = "debug",
34	skip_all,
35	fields(
36		%origin,
37		events = %initial_set.clone().count(),
38	),
39)]
40#[expect(clippy::type_complexity, clippy::too_many_arguments)]
41pub(super) async fn fetch_prev<'a, Events>(
42	&self,
43	origin: &ServerName,
44	room_id: &RoomId,
45	incoming_event_id: &EventId,
46	initial_set: Events,
47	room_version: &RoomVersionId,
48	recursion_level: usize,
49	first_ts_in_room: MilliSecondsSinceUnixEpoch,
50) -> Result<(Vec<OwnedEventId>, HashMap<OwnedEventId, (PduEvent, CanonicalJsonObject)>)>
51where
52	Events: Iterator<Item = &'a EventId> + Clone + Send,
53{
54	let has_gap = initial_set
55		.clone()
56		.stream()
57		.any(async |event_id| !self.services.timeline.pdu_exists(event_id).await)
58		.await;
59
60	let wait_ms = self.services.server.config.fetch_prev_wait_ms;
61	let has_gap = (has_gap && wait_ms > 0)
62		.then_async(|| self.await_prev_gap(initial_set.clone(), Duration::from_millis(wait_ms)))
63		.await
64		.unwrap_or(has_gap);
65
66	has_gap
67		.then_async(|| {
68			self.prefetch_missing_events(
69				origin,
70				room_id,
71				incoming_event_id,
72				room_version,
73				recursion_level,
74			)
75		})
76		.await;
77
78	let mut todo_outlier_stack: FuturesOrdered<_> = initial_set
79		.stream()
80		.map(ToOwned::to_owned)
81		.filter_map(async |event_id| {
82			self.services
83				.timeline
84				.non_outlier_pdu_exists(&event_id)
85				.await
86				.is_err()
87				.then_some(event_id)
88		})
89		.map(async |event_id| {
90			let events = once(event_id.as_ref());
91			let auth = self
92				.fetch_auth(origin, room_id, events, room_version, recursion_level)
93				.await;
94
95			(event_id, auth)
96		})
97		.map(FutureExt::boxed)
98		.collect()
99		.await;
100
101	let mut amount = 0;
102	let mut eventid_info = HashMap::new();
103	let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(todo_outlier_stack.len());
104	while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
105		self.services.server.check_running()?;
106
107		let Some((pdu, mut json_opt)) = outlier.pop() else {
108			// Fetch and handle failed
109			graph.insert(prev_event_id.clone(), Default::default());
110			continue;
111		};
112
113		check_room_id(&pdu, room_id)?;
114
115		let limit = self.services.server.config.max_fetch_prev_events;
116		if amount > limit {
117			debug_warn!(?limit, "Max prev event limit reached!");
118			graph.insert(prev_event_id.clone(), Default::default());
119			continue;
120		}
121
122		if json_opt.is_none() {
123			json_opt = self
124				.services
125				.timeline
126				.get_outlier_pdu_json(&prev_event_id)
127				.await
128				.ok();
129		}
130
131		let Some(json) = json_opt else {
132			// Get json failed, so this was not fetched over federation
133			graph.insert(prev_event_id.clone(), Default::default());
134			continue;
135		};
136
137		if pdu.origin_server_ts() > first_ts_in_room {
138			amount = amount.saturating_add(1);
139			debug_assert!(
140				pdu.prev_events().count() <= MAX_PREV_EVENTS,
141				"PduEvent {prev_event_id} has too many prev_events"
142			);
143
144			for prev_prev in pdu.prev_events() {
145				if graph.contains_key(prev_prev) {
146					continue;
147				}
148
149				let prev_prev = prev_prev.to_owned();
150				let fetch = async move {
151					let fetch = self
152						.fetch_auth(
153							origin,
154							room_id,
155							once(prev_prev.as_ref()),
156							room_version,
157							recursion_level,
158						)
159						.await;
160
161					(prev_prev, fetch)
162				};
163
164				todo_outlier_stack.push_back(fetch.boxed());
165			}
166
167			graph.insert(
168				prev_event_id.clone(),
169				pdu.prev_events().map(ToOwned::to_owned).collect(),
170			);
171		} else {
172			// Time based check failed
173			graph.insert(prev_event_id.clone(), Default::default());
174		}
175
176		eventid_info.insert(prev_event_id.clone(), (pdu, json));
177	}
178
179	let event_fetch = async |event_id: OwnedEventId| {
180		let origin_server_ts = eventid_info
181			.get(&event_id)
182			.map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
183
184		// This return value is the key used for sorting events,
185		// events are then sorted by power level, time,
186		// and lexically by event_id.
187		Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(origin_server_ts)))
188	};
189
190	let graph_len = graph.len();
191	let sorted = topological_sort(graph, &event_fetch)
192		.await
193		.map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;
194
195	debug_assert_eq!(
196		sorted.len(),
197		graph_len,
198		"topological sort returned a different number of outputs than inputs"
199	);
200
201	debug_assert!(
202		sorted.len() >= eventid_info.len(),
203		"returned topologically sorted events differ from eventid_info"
204	);
205
206	Ok((sorted, eventid_info))
207}
208
209#[implement(super::Service)]
210async fn await_prev_gap<'a, Events>(&self, initial_set: Events, wait: Duration) -> bool
211where
212	Events: Iterator<Item = &'a EventId> + Send,
213{
214	let deadline = Instant::now()
215		.checked_add(wait)
216		.expect("wait deadline overflows");
217
218	// Each watcher registers before its existence recheck, so a prev that
219	// arrives during the recheck still wakes us.
220	let pending: FuturesUnordered<_> = initial_set
221		.map(|event_id| (event_id, self.services.timeline.watch_event(event_id)))
222		.stream()
223		.filter_map(async |(event_id, watcher)| {
224			(!self.services.timeline.pdu_exists(event_id).await).then_some(watcher)
225		})
226		.collect()
227		.await;
228
229	if pending.is_empty() {
230		return false;
231	}
232
233	timeout_at(deadline, pending.count())
234		.await
235		.is_err()
236}
237
238/// Fill the prev gap below `incoming_event_id` with one `/get_missing_events`
239/// batch, landing each returned event as a local outlier so the per-event walk
240/// resolves it without a federation fetch. `latest_events` is the held event
241/// the server walks back from, bounded by our forward extremities so it returns
242/// only the gap; best effort, so a failed batch or rejected event just leaves
243/// that id for the walk.
244#[implement(super::Service)]
245#[tracing::instrument(name = "missing", level = "debug", skip_all)]
246async fn prefetch_missing_events(
247	&self,
248	origin: &ServerName,
249	room_id: &RoomId,
250	incoming_event_id: &EventId,
251	room_version: &RoomVersionId,
252	recursion_level: usize,
253) {
254	let boundary: EventWindow = self
255		.services
256		.state
257		.get_forward_extremities(room_id)
258		.map(ToOwned::to_owned)
259		.collect()
260		.await;
261
262	let opts = Opts::new(Op::MissingEvents, room_id.to_owned())
263		.latest_events([incoming_event_id.to_owned()])
264		.earliest_events(boundary)
265		.hint(origin.to_owned())
266		.room_version(room_version.to_owned())
267		.attempt_limit(super::EVENT_FETCH_ATTEMPT_LIMIT)
268		.fanout_for_op();
269
270	let Ok(outcome) = self.services.fetcher.fetch(opts).await else {
271		return;
272	};
273
274	let Ok(events) = serde_json::from_slice::<Vec<Box<RawJsonValue>>>(&outcome.bytes) else {
275		return;
276	};
277
278	events
279		.into_iter()
280		.stream()
281		.for_each_concurrent(automatic_width(), async |pdu| {
282			self.land_missing_event(origin, room_id, &pdu, room_version, recursion_level)
283				.await
284				.ok();
285		})
286		.await;
287}
288
289/// Authenticate and persist one event from the missing-events batch as an
290/// outlier, deriving its id from content rather than trusting a requested id.
291#[implement(super::Service)]
292#[tracing::instrument(name = "land", level = "trace", skip_all)]
293async fn land_missing_event(
294	&self,
295	origin: &ServerName,
296	room_id: &RoomId,
297	pdu: &RawJsonValue,
298	room_version: &RoomVersionId,
299	recursion_level: usize,
300) -> Result {
301	let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
302		.map_err(|e| err!(BadServerResponse("missing-events pdu is not canonical json: {e}")))?;
303
304	value
305		.get("room_id")
306		.and_then(CanonicalJsonValue::as_str)
307		.is_some_and(|id| id == room_id.as_str())
308		.then_some(())
309		.ok_or_else(|| {
310			err!(Request(InvalidParam("missing-events pdu is for a different room")))
311		})?;
312
313	let event_id = gen_event_id(&value, room_version)?;
314
315	Box::pin(self.handle_outlier_pdu(
316		origin,
317		room_id,
318		&event_id,
319		value,
320		room_version,
321		recursion_level,
322		false,
323	))
324	.await
325	.map(|_| ())
326}