// tuwunel_service/rooms/event_handler/fetch_auth.rs

1use std::{
2	collections::{HashSet, VecDeque},
3	ops::Range,
4	time::Duration,
5};
6
7use futures::{FutureExt, StreamExt, TryFutureExt};
8use ruma::{
9	CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, RoomVersionId,
10	ServerName, api::federation::event::get_event,
11};
12use tuwunel_core::{
13	debug, debug_error, debug_warn, expected, implement,
14	matrix::{PduEvent, event::gen_event_id_canonical_json, pdu::MAX_AUTH_EVENTS},
15	trace,
16	utils::stream::{BroadbandExt, IterStream, ReadyExt},
17	warn,
18};
19
20/// Find the event and auth it. Once the event is validated (steps 1 - 8)
21/// it is appended to the outliers Tree.
22///
23/// Returns pdu and if we fetched it over federation the raw json.
24///
25/// a. Look in the main timeline (pduid_pdu tree)
26/// b. Look at outlier pdu tree
27/// c. Ask origin server over federation
28/// d. TODO: Ask other servers over federation?
29#[implement(super::Service)]
30#[tracing::instrument(
31	level = "debug",
32	skip_all,
33	fields(
34		%origin,
35		events = %events.clone().count(),
36		lev = %recursion_level,
37	),
38)]
39pub(super) async fn fetch_auth<'a, Events>(
40	&self,
41	origin: &ServerName,
42	room_id: &RoomId,
43	events: Events,
44	room_version: &RoomVersionId,
45	recursion_level: usize,
46) -> Vec<(PduEvent, Option<CanonicalJsonObject>)>
47where
48	Events: Iterator<Item = &'a EventId> + Clone + Send,
49{
50	let events_with_auth_events: Vec<_> = events
51		.stream()
52		.broad_then(|event_id| self.fetch_auth_chain(origin, room_id, event_id, room_version))
53		.collect()
54		.boxed()
55		.await;
56
57	events_with_auth_events
58		.into_iter()
59		.stream()
60		.fold(Vec::new(), async |mut pdus, (id, local_pdu, events_in_reverse_order)| {
61			// a. Look in the main timeline (pduid_pdu tree)
62			// b. Look at outlier pdu tree
63			// (get_pdu_json checks both)
64			if let Some(local_pdu) = local_pdu {
65				pdus.push((local_pdu, None));
66			}
67
68			events_in_reverse_order
69				.into_iter()
70				.rev()
71				.stream()
72				.ready_filter(|(next_id, _)| {
73					let backed_off = self.is_backed_off(next_id, Range {
74						start: Duration::from_mins(5),
75						end: Duration::from_hours(24),
76					});
77
78					!backed_off
79				})
80				.fold(pdus, async |mut pdus, (next_id, value)| {
81					let outlier = Box::pin(self.handle_outlier_pdu(
82						origin,
83						room_id,
84						&next_id,
85						value.clone(),
86						room_version,
87						expected!(recursion_level + 1),
88						true,
89					));
90
91					if let Ok((pdu, json)) = outlier
92						.await
93						.inspect_err(|e| warn!("Authentication of event {next_id} failed: {e:?}"))
94					{
95						if next_id == id {
96							pdus.push((pdu, Some(json)));
97						}
98						self.cancel_back_off(&next_id);
99					} else {
100						self.back_off(&next_id);
101					}
102
103					pdus
104				})
105				.await
106		})
107		.await
108}
109
110#[implement(super::Service)]
111#[tracing::instrument(
112	name = "chain",
113	level = "trace",
114	skip_all,
115	fields(%event_id),
116)]
117async fn fetch_auth_chain(
118	&self,
119	origin: &ServerName,
120	_room_id: &RoomId,
121	event_id: &EventId,
122	room_version: &RoomVersionId,
123) -> (OwnedEventId, Option<PduEvent>, Vec<(OwnedEventId, CanonicalJsonObject)>) {
124	// a. Look in the main timeline (pduid_pdu tree)
125	// b. Look at outlier pdu tree
126	// (get_pdu_json checks both)
127	if let Ok(local_pdu) = self.services.timeline.get_pdu(event_id).await {
128		trace!(?event_id, "Found in database");
129		return (event_id.to_owned(), Some(local_pdu), vec![]);
130	}
131
132	// c. Ask origin server over federation
133	// We also handle its auth chain here so we don't get a stack overflow in
134	// handle_outlier_pdu.
135	let mut events_all = HashSet::new();
136	let mut events_in_reverse_order = Vec::new();
137	let mut todo_auth_events: VecDeque<_> = [event_id.to_owned()].into();
138	while let Some(next_id) = todo_auth_events.pop_front() {
139		if events_all.contains(&next_id) {
140			continue;
141		}
142
143		if self.is_backed_off(&next_id, Range {
144			start: Duration::from_mins(2),
145			end: Duration::from_hours(8),
146		}) {
147			debug_warn!("Backed off from {next_id}");
148			continue;
149		}
150
151		if self.services.timeline.pdu_exists(&next_id).await {
152			trace!(?next_id, "Found in database");
153			continue;
154		}
155
156		debug!("Fetching {next_id} over federation.");
157		let Ok(res) = self
158			.services
159			.federation
160			.execute(origin, get_event::v1::Request { event_id: next_id.clone() })
161			.inspect_err(|e| debug_error!(?next_id, "Failed to fetch event: {e}"))
162			.inspect_ok(|_| {
163				self.cancel_back_off(&next_id);
164			})
165			.await
166		else {
167			debug_warn!("Backing off from {next_id}");
168			self.back_off(&next_id);
169			continue;
170		};
171
172		debug!("Got {next_id} over federation");
173		let Ok((calculated_event_id, value)) =
174			gen_event_id_canonical_json(&res.pdu, room_version)
175		else {
176			self.back_off(&next_id);
177			continue;
178		};
179
180		if calculated_event_id != next_id {
181			warn!(
182				"Server returned wrong event id: requested {next_id}, got \
183				 {calculated_event_id}. Event: {:?}",
184				&res.pdu
185			);
186
187			self.back_off(&next_id);
188			continue;
189		}
190
191		self.cancel_back_off(&next_id);
192		value
193			.get("auth_events")
194			.and_then(CanonicalJsonValue::as_array)
195			.into_iter()
196			.flatten()
197			.filter_map(|auth_event| auth_event.try_into().ok())
198			.take(MAX_AUTH_EVENTS)
199			.for_each(|auth_event: &EventId| {
200				todo_auth_events.push_back(auth_event.to_owned());
201			});
202
203		events_in_reverse_order.push((next_id.clone(), value));
204		events_all.insert(next_id);
205	}
206
207	(event_id.to_owned(), None, events_in_reverse_order)
208}