Skip to main content

tuwunel_service/rooms/event_handler/
fetch_prev.rs

1use std::{collections::HashMap, iter::once};
2
3use futures::{FutureExt, StreamExt, stream::FuturesOrdered};
4use ruma::{
5	CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
6	RoomVersionId, ServerName, int, uint,
7};
8use tuwunel_core::{
9	Result, debug_warn, err, implement,
10	matrix::{
11		Event, PduEvent,
12		pdu::{MAX_PREV_EVENTS, check_room_id},
13	},
14	utils::stream::IterStream,
15};
16
17use crate::rooms::state_res;
18
19#[implement(super::Service)]
20#[tracing::instrument(
21	level = "debug",
22	skip_all,
23	fields(
24		%origin,
25		events = %initial_set.clone().count(),
26	),
27)]
28#[expect(clippy::type_complexity)]
29pub(super) async fn fetch_prev<'a, Events>(
30	&self,
31	origin: &ServerName,
32	room_id: &RoomId,
33	initial_set: Events,
34	room_version: &RoomVersionId,
35	recursion_level: usize,
36	first_ts_in_room: MilliSecondsSinceUnixEpoch,
37) -> Result<(Vec<OwnedEventId>, HashMap<OwnedEventId, (PduEvent, CanonicalJsonObject)>)>
38where
39	Events: Iterator<Item = &'a EventId> + Clone + Send,
40{
41	let mut todo_outlier_stack: FuturesOrdered<_> = initial_set
42		.stream()
43		.map(ToOwned::to_owned)
44		.filter_map(async |event_id| {
45			self.services
46				.timeline
47				.non_outlier_pdu_exists(&event_id)
48				.await
49				.is_err()
50				.then_some(event_id)
51		})
52		.map(async |event_id| {
53			let events = once(event_id.as_ref());
54			let auth = self
55				.fetch_auth(origin, room_id, events, room_version, recursion_level)
56				.await;
57
58			(event_id, auth)
59		})
60		.map(FutureExt::boxed)
61		.collect()
62		.await;
63
64	let mut amount = 0;
65	let mut eventid_info = HashMap::new();
66	let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(todo_outlier_stack.len());
67	while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
68		let Some((pdu, mut json_opt)) = outlier.pop() else {
69			// Fetch and handle failed
70			graph.insert(prev_event_id.clone(), Default::default());
71			continue;
72		};
73
74		check_room_id(&pdu, room_id)?;
75
76		let limit = self.services.server.config.max_fetch_prev_events;
77		if amount > limit {
78			debug_warn!(?limit, "Max prev event limit reached!");
79			graph.insert(prev_event_id.clone(), Default::default());
80			continue;
81		}
82
83		if json_opt.is_none() {
84			json_opt = self
85				.services
86				.timeline
87				.get_outlier_pdu_json(&prev_event_id)
88				.await
89				.ok();
90		}
91
92		let Some(json) = json_opt else {
93			// Get json failed, so this was not fetched over federation
94			graph.insert(prev_event_id.clone(), Default::default());
95			continue;
96		};
97
98		if pdu.origin_server_ts() > first_ts_in_room {
99			amount = amount.saturating_add(1);
100			debug_assert!(
101				pdu.prev_events().count() <= MAX_PREV_EVENTS,
102				"PduEvent {prev_event_id} has too many prev_events"
103			);
104
105			for prev_prev in pdu.prev_events() {
106				if graph.contains_key(prev_prev) {
107					continue;
108				}
109
110				let prev_prev = prev_prev.to_owned();
111				let fetch = async move {
112					let fetch = self
113						.fetch_auth(
114							origin,
115							room_id,
116							once(prev_prev.as_ref()),
117							room_version,
118							recursion_level,
119						)
120						.await;
121
122					(prev_prev, fetch)
123				};
124
125				todo_outlier_stack.push_back(fetch.boxed());
126			}
127
128			graph.insert(
129				prev_event_id.clone(),
130				pdu.prev_events().map(ToOwned::to_owned).collect(),
131			);
132		} else {
133			// Time based check failed
134			graph.insert(prev_event_id.clone(), Default::default());
135		}
136
137		eventid_info.insert(prev_event_id.clone(), (pdu, json));
138		self.services.server.check_running()?;
139	}
140
141	let event_fetch = async |event_id: OwnedEventId| {
142		let origin_server_ts = eventid_info
143			.get(&event_id)
144			.map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
145
146		// This return value is the key used for sorting events,
147		// events are then sorted by power level, time,
148		// and lexically by event_id.
149		Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(origin_server_ts)))
150	};
151
152	let sorted = state_res::topological_sort(&graph, &event_fetch)
153		.await
154		.map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;
155
156	debug_assert_eq!(
157		sorted.len(),
158		graph.len(),
159		"topological sort returned a different number of outputs than inputs"
160	);
161
162	debug_assert!(
163		sorted.len() >= eventid_info.len(),
164		"returned topologically sorted events differ from eventid_info"
165	);
166
167	Ok((sorted, eventid_info))
168}