Skip to main content

tuwunel_service/rooms/event_handler/
fetch_auth.rs

1use std::{
2	collections::{HashSet, VecDeque},
3	time::Duration,
4};
5
6use futures::{FutureExt, StreamExt, TryFutureExt};
7use ruma::{
8	CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, RoomVersionId,
9	ServerName,
10};
11use tuwunel_core::{
12	debug, debug_error, debug_warn, expected, implement,
13	matrix::{PduEvent, pdu::MAX_AUTH_EVENTS},
14	trace,
15	utils::stream::{BroadbandExt, IterStream},
16	warn,
17};
18
19use super::backoff::{Context, Disposition};
20use crate::fetcher::{Op, Opts};
21
22/// Find the event and auth it. Once the event is validated (steps 1 - 8)
23/// it is appended to the outliers Tree.
24///
25/// Returns pdu and if we fetched it over federation the raw json.
26///
27/// a. Look in the main timeline (pduid_pdu tree)
28/// b. Look at outlier pdu tree
29/// c. Ask origin server over federation
30/// d. TODO: Ask other servers over federation?
31#[implement(super::Service)]
32#[tracing::instrument(
33	level = "debug",
34	skip_all,
35	fields(
36		%origin,
37		events = %events.clone().count(),
38		lev = %recursion_level,
39	),
40)]
41pub(super) async fn fetch_auth<'a, Events>(
42	&self,
43	origin: &ServerName,
44	room_id: &RoomId,
45	events: Events,
46	room_version: &RoomVersionId,
47	recursion_level: usize,
48) -> Vec<(PduEvent, Option<CanonicalJsonObject>)>
49where
50	Events: Iterator<Item = &'a EventId> + Clone + Send,
51{
52	let events_with_auth_events: Vec<_> = events
53		.stream()
54		.broad_then(|event_id| self.fetch_auth_chain(origin, room_id, event_id, room_version))
55		.collect()
56		.boxed()
57		.await;
58
59	events_with_auth_events
60		.into_iter()
61		.stream()
62		.fold(Vec::new(), async |mut pdus, (id, local_pdu, events_in_reverse_order)| {
63			if self.services.server.check_running().is_err() {
64				return pdus;
65			}
66
67			// a. Look in the main timeline (pduid_pdu tree)
68			// b. Look at outlier pdu tree
69			// (get_pdu_json checks both)
70			if let Some(local_pdu) = local_pdu {
71				pdus.push((local_pdu, None));
72			}
73
74			events_in_reverse_order
75				.into_iter()
76				.rev()
77				.stream()
78				.fold(pdus, async |mut pdus, (next_id, value)| {
79					if self
80						.is_suppressed(
81							Context::Auth,
82							&next_id,
83							Duration::from_mins(5)..Duration::from_hours(24),
84						)
85						.await
86						.is_deny()
87					{
88						return pdus;
89					}
90
91					let outlier = Box::pin(self.handle_outlier_pdu(
92						origin,
93						room_id,
94						&next_id,
95						value.clone(),
96						room_version,
97						expected!(recursion_level + 1),
98						true,
99					));
100
101					if let Ok((pdu, json)) = outlier
102						.await
103						.inspect_err(|e| warn!("Authentication of event {next_id} failed: {e:?}"))
104					{
105						if next_id == id {
106							pdus.push((pdu, Some(json)));
107						}
108						self.record_success(Context::Auth, &next_id).await;
109					} else {
110						self.record_outcome(Context::Auth, &next_id, Disposition::Transient);
111					}
112
113					pdus
114				})
115				.await
116		})
117		.await
118}
119
120#[implement(super::Service)]
121#[tracing::instrument(
122	name = "chain",
123	level = "trace",
124	skip_all,
125	fields(%event_id),
126)]
127async fn fetch_auth_chain(
128	&self,
129	origin: &ServerName,
130	room_id: &RoomId,
131	event_id: &EventId,
132	room_version: &RoomVersionId,
133) -> (OwnedEventId, Option<PduEvent>, Vec<(OwnedEventId, CanonicalJsonObject)>) {
134	// a. Look in the main timeline (pduid_pdu tree)
135	// b. Look at outlier pdu tree
136	// (get_pdu_json checks both)
137	if let Ok(local_pdu) = self.services.timeline.get_pdu(event_id).await {
138		trace!(?event_id, "Found in database");
139		return (event_id.to_owned(), Some(local_pdu), vec![]);
140	}
141
142	// c. Ask origin server over federation
143	// We also handle its auth chain here so we don't get a stack overflow in
144	// handle_outlier_pdu.
145	let mut events_all = HashSet::new();
146	let mut events_in_reverse_order = Vec::new();
147	let mut todo_auth_events: VecDeque<_> = [event_id.to_owned()].into();
148	while let Some(next_id) = todo_auth_events.pop_front() {
149		if events_all.contains(&next_id) {
150			continue;
151		}
152
153		if self
154			.is_suppressed(
155				Context::Fetch,
156				&next_id,
157				Duration::from_mins(2)..Duration::from_hours(8),
158			)
159			.await
160			.is_deny()
161		{
162			debug_warn!("Backed off from {next_id}");
163			continue;
164		}
165
166		if self.services.timeline.pdu_exists(&next_id).await {
167			trace!(?next_id, "Found in database");
168			continue;
169		}
170
171		if self.services.server.check_running().is_err() {
172			debug_warn!(?next_id, "Server shutting down");
173			break;
174		}
175
176		debug!("Fetching {next_id} over federation.");
177		let opts = Opts::new(Op::AuthEvent, room_id.to_owned())
178			.event_id(next_id.clone())
179			.hint(origin.to_owned())
180			.room_version(room_version.to_owned())
181			.attempt_limit(super::EVENT_FETCH_ATTEMPT_LIMIT)
182			.fanout_for_op();
183
184		let Ok(outcome) = self
185			.services
186			.fetcher
187			.fetch(opts)
188			.inspect_err(|e| debug_error!(?next_id, "Failed to fetch event: {e}"))
189			.await
190		else {
191			debug_warn!("Backing off from {next_id}");
192			self.record_outcome(Context::Fetch, &next_id, Disposition::Transient);
193			continue;
194		};
195
196		let Ok(value) = serde_json::from_slice::<CanonicalJsonObject>(&outcome.bytes) else {
197			self.record_outcome(Context::Fetch, &next_id, Disposition::Transient);
198			continue;
199		};
200
201		debug!("Got {next_id} over federation");
202		self.record_success(Context::Fetch, &next_id)
203			.await;
204		value
205			.get("auth_events")
206			.and_then(CanonicalJsonValue::as_array)
207			.into_iter()
208			.flatten()
209			.filter_map(|auth_event| auth_event.try_into().ok())
210			.take(MAX_AUTH_EVENTS)
211			.for_each(|auth_event: &EventId| {
212				todo_auth_events.push_back(auth_event.to_owned());
213			});
214
215		events_in_reverse_order.push((next_id.clone(), value));
216		events_all.insert(next_id);
217	}
218
219	(event_id.to_owned(), None, events_in_reverse_order)
220}