tuwunel_service/rooms/event_handler/
fetch_auth.rs1use std::{
2 collections::{HashSet, VecDeque},
3 time::Duration,
4};
5
6use futures::{FutureExt, StreamExt, TryFutureExt};
7use ruma::{
8 CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, RoomVersionId,
9 ServerName,
10};
11use tuwunel_core::{
12 debug, debug_error, debug_warn, expected, implement,
13 matrix::{PduEvent, pdu::MAX_AUTH_EVENTS},
14 trace,
15 utils::stream::{BroadbandExt, IterStream},
16 warn,
17};
18
19use super::backoff::{Context, Disposition};
20use crate::fetcher::{Op, Opts};
21
22#[implement(super::Service)]
32#[tracing::instrument(
33 level = "debug",
34 skip_all,
35 fields(
36 %origin,
37 events = %events.clone().count(),
38 lev = %recursion_level,
39 ),
40)]
41pub(super) async fn fetch_auth<'a, Events>(
42 &self,
43 origin: &ServerName,
44 room_id: &RoomId,
45 events: Events,
46 room_version: &RoomVersionId,
47 recursion_level: usize,
48) -> Vec<(PduEvent, Option<CanonicalJsonObject>)>
49where
50 Events: Iterator<Item = &'a EventId> + Clone + Send,
51{
52 let events_with_auth_events: Vec<_> = events
53 .stream()
54 .broad_then(|event_id| self.fetch_auth_chain(origin, room_id, event_id, room_version))
55 .collect()
56 .boxed()
57 .await;
58
59 events_with_auth_events
60 .into_iter()
61 .stream()
62 .fold(Vec::new(), async |mut pdus, (id, local_pdu, events_in_reverse_order)| {
63 if self.services.server.check_running().is_err() {
64 return pdus;
65 }
66
67 if let Some(local_pdu) = local_pdu {
71 pdus.push((local_pdu, None));
72 }
73
74 events_in_reverse_order
75 .into_iter()
76 .rev()
77 .stream()
78 .fold(pdus, async |mut pdus, (next_id, value)| {
79 if self
80 .is_suppressed(
81 Context::Auth,
82 &next_id,
83 Duration::from_mins(5)..Duration::from_hours(24),
84 )
85 .await
86 .is_deny()
87 {
88 return pdus;
89 }
90
91 let outlier = Box::pin(self.handle_outlier_pdu(
92 origin,
93 room_id,
94 &next_id,
95 value.clone(),
96 room_version,
97 expected!(recursion_level + 1),
98 true,
99 ));
100
101 if let Ok((pdu, json)) = outlier
102 .await
103 .inspect_err(|e| warn!("Authentication of event {next_id} failed: {e:?}"))
104 {
105 if next_id == id {
106 pdus.push((pdu, Some(json)));
107 }
108 self.record_success(Context::Auth, &next_id).await;
109 } else {
110 self.record_outcome(Context::Auth, &next_id, Disposition::Transient);
111 }
112
113 pdus
114 })
115 .await
116 })
117 .await
118}
119
120#[implement(super::Service)]
121#[tracing::instrument(
122 name = "chain",
123 level = "trace",
124 skip_all,
125 fields(%event_id),
126)]
127async fn fetch_auth_chain(
128 &self,
129 origin: &ServerName,
130 room_id: &RoomId,
131 event_id: &EventId,
132 room_version: &RoomVersionId,
133) -> (OwnedEventId, Option<PduEvent>, Vec<(OwnedEventId, CanonicalJsonObject)>) {
134 if let Ok(local_pdu) = self.services.timeline.get_pdu(event_id).await {
138 trace!(?event_id, "Found in database");
139 return (event_id.to_owned(), Some(local_pdu), vec![]);
140 }
141
142 let mut events_all = HashSet::new();
146 let mut events_in_reverse_order = Vec::new();
147 let mut todo_auth_events: VecDeque<_> = [event_id.to_owned()].into();
148 while let Some(next_id) = todo_auth_events.pop_front() {
149 if events_all.contains(&next_id) {
150 continue;
151 }
152
153 if self
154 .is_suppressed(
155 Context::Fetch,
156 &next_id,
157 Duration::from_mins(2)..Duration::from_hours(8),
158 )
159 .await
160 .is_deny()
161 {
162 debug_warn!("Backed off from {next_id}");
163 continue;
164 }
165
166 if self.services.timeline.pdu_exists(&next_id).await {
167 trace!(?next_id, "Found in database");
168 continue;
169 }
170
171 if self.services.server.check_running().is_err() {
172 debug_warn!(?next_id, "Server shutting down");
173 break;
174 }
175
176 debug!("Fetching {next_id} over federation.");
177 let opts = Opts::new(Op::AuthEvent, room_id.to_owned())
178 .event_id(next_id.clone())
179 .hint(origin.to_owned())
180 .room_version(room_version.to_owned())
181 .attempt_limit(super::EVENT_FETCH_ATTEMPT_LIMIT)
182 .fanout_for_op();
183
184 let Ok(outcome) = self
185 .services
186 .fetcher
187 .fetch(opts)
188 .inspect_err(|e| debug_error!(?next_id, "Failed to fetch event: {e}"))
189 .await
190 else {
191 debug_warn!("Backing off from {next_id}");
192 self.record_outcome(Context::Fetch, &next_id, Disposition::Transient);
193 continue;
194 };
195
196 let Ok(value) = serde_json::from_slice::<CanonicalJsonObject>(&outcome.bytes) else {
197 self.record_outcome(Context::Fetch, &next_id, Disposition::Transient);
198 continue;
199 };
200
201 debug!("Got {next_id} over federation");
202 self.record_success(Context::Fetch, &next_id)
203 .await;
204 value
205 .get("auth_events")
206 .and_then(CanonicalJsonValue::as_array)
207 .into_iter()
208 .flatten()
209 .filter_map(|auth_event| auth_event.try_into().ok())
210 .take(MAX_AUTH_EVENTS)
211 .for_each(|auth_event: &EventId| {
212 todo_auth_events.push_back(auth_event.to_owned());
213 });
214
215 events_in_reverse_order.push((next_id.clone(), value));
216 events_all.insert(next_id);
217 }
218
219 (event_id.to_owned(), None, events_in_reverse_order)
220}