tuwunel_service/rooms/event_handler/
fetch_auth.rs1use std::{
2 collections::{HashSet, VecDeque},
3 ops::Range,
4 time::Duration,
5};
6
7use futures::{FutureExt, StreamExt, TryFutureExt};
8use ruma::{
9 CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, RoomVersionId,
10 ServerName, api::federation::event::get_event,
11};
12use tuwunel_core::{
13 debug, debug_error, debug_warn, expected, implement,
14 matrix::{PduEvent, event::gen_event_id_canonical_json, pdu::MAX_AUTH_EVENTS},
15 trace,
16 utils::stream::{BroadbandExt, IterStream, ReadyExt},
17 warn,
18};
19
20#[implement(super::Service)]
30#[tracing::instrument(
31 level = "debug",
32 skip_all,
33 fields(
34 %origin,
35 events = %events.clone().count(),
36 lev = %recursion_level,
37 ),
38)]
39pub(super) async fn fetch_auth<'a, Events>(
40 &self,
41 origin: &ServerName,
42 room_id: &RoomId,
43 events: Events,
44 room_version: &RoomVersionId,
45 recursion_level: usize,
46) -> Vec<(PduEvent, Option<CanonicalJsonObject>)>
47where
48 Events: Iterator<Item = &'a EventId> + Clone + Send,
49{
50 let events_with_auth_events: Vec<_> = events
51 .stream()
52 .broad_then(|event_id| self.fetch_auth_chain(origin, room_id, event_id, room_version))
53 .collect()
54 .boxed()
55 .await;
56
57 events_with_auth_events
58 .into_iter()
59 .stream()
60 .fold(Vec::new(), async |mut pdus, (id, local_pdu, events_in_reverse_order)| {
61 if let Some(local_pdu) = local_pdu {
65 pdus.push((local_pdu, None));
66 }
67
68 events_in_reverse_order
69 .into_iter()
70 .rev()
71 .stream()
72 .ready_filter(|(next_id, _)| {
73 let backed_off = self.is_backed_off(next_id, Range {
74 start: Duration::from_mins(5),
75 end: Duration::from_hours(24),
76 });
77
78 !backed_off
79 })
80 .fold(pdus, async |mut pdus, (next_id, value)| {
81 let outlier = Box::pin(self.handle_outlier_pdu(
82 origin,
83 room_id,
84 &next_id,
85 value.clone(),
86 room_version,
87 expected!(recursion_level + 1),
88 true,
89 ));
90
91 if let Ok((pdu, json)) = outlier
92 .await
93 .inspect_err(|e| warn!("Authentication of event {next_id} failed: {e:?}"))
94 {
95 if next_id == id {
96 pdus.push((pdu, Some(json)));
97 }
98 self.cancel_back_off(&next_id);
99 } else {
100 self.back_off(&next_id);
101 }
102
103 pdus
104 })
105 .await
106 })
107 .await
108}
109
110#[implement(super::Service)]
111#[tracing::instrument(
112 name = "chain",
113 level = "trace",
114 skip_all,
115 fields(%event_id),
116)]
117async fn fetch_auth_chain(
118 &self,
119 origin: &ServerName,
120 _room_id: &RoomId,
121 event_id: &EventId,
122 room_version: &RoomVersionId,
123) -> (OwnedEventId, Option<PduEvent>, Vec<(OwnedEventId, CanonicalJsonObject)>) {
124 if let Ok(local_pdu) = self.services.timeline.get_pdu(event_id).await {
128 trace!(?event_id, "Found in database");
129 return (event_id.to_owned(), Some(local_pdu), vec![]);
130 }
131
132 let mut events_all = HashSet::new();
136 let mut events_in_reverse_order = Vec::new();
137 let mut todo_auth_events: VecDeque<_> = [event_id.to_owned()].into();
138 while let Some(next_id) = todo_auth_events.pop_front() {
139 if events_all.contains(&next_id) {
140 continue;
141 }
142
143 if self.is_backed_off(&next_id, Range {
144 start: Duration::from_mins(2),
145 end: Duration::from_hours(8),
146 }) {
147 debug_warn!("Backed off from {next_id}");
148 continue;
149 }
150
151 if self.services.timeline.pdu_exists(&next_id).await {
152 trace!(?next_id, "Found in database");
153 continue;
154 }
155
156 debug!("Fetching {next_id} over federation.");
157 let Ok(res) = self
158 .services
159 .federation
160 .execute(origin, get_event::v1::Request { event_id: next_id.clone() })
161 .inspect_err(|e| debug_error!(?next_id, "Failed to fetch event: {e}"))
162 .inspect_ok(|_| {
163 self.cancel_back_off(&next_id);
164 })
165 .await
166 else {
167 debug_warn!("Backing off from {next_id}");
168 self.back_off(&next_id);
169 continue;
170 };
171
172 debug!("Got {next_id} over federation");
173 let Ok((calculated_event_id, value)) =
174 gen_event_id_canonical_json(&res.pdu, room_version)
175 else {
176 self.back_off(&next_id);
177 continue;
178 };
179
180 if calculated_event_id != next_id {
181 warn!(
182 "Server returned wrong event id: requested {next_id}, got \
183 {calculated_event_id}. Event: {:?}",
184 &res.pdu
185 );
186
187 self.back_off(&next_id);
188 continue;
189 }
190
191 self.cancel_back_off(&next_id);
192 value
193 .get("auth_events")
194 .and_then(CanonicalJsonValue::as_array)
195 .into_iter()
196 .flatten()
197 .filter_map(|auth_event| auth_event.try_into().ok())
198 .take(MAX_AUTH_EVENTS)
199 .for_each(|auth_event: &EventId| {
200 todo_auth_events.push_back(auth_event.to_owned());
201 });
202
203 events_in_reverse_order.push((next_id.clone(), value));
204 events_all.insert(next_id);
205 }
206
207 (event_id.to_owned(), None, events_in_reverse_order)
208}