tuwunel_service/rooms/event_handler/
upgrade_outlier_pdu.rs1use std::{borrow::Borrow, iter::once, sync::Arc, time::Instant};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5 CanonicalJsonObject, EventId, OwnedEventId, RoomId, RoomVersionId, ServerName,
6 events::StateEventType,
7};
8use tuwunel_core::{
9 Err, Result, debug, debug_info, err, implement, is_equal_to,
10 matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
11 trace,
12 utils::stream::{BroadbandExt, ReadyExt},
13 warn,
14};
15
16use super::policy_server::PolicyCheck;
17use crate::rooms::{
18 state_compressor::{CompressedState, HashSetCompressStateEvent},
19 state_res,
20 timeline::RawPduId,
21};
22
23#[implement(super::Service)]
24#[tracing::instrument(
25 name = "upgrade",
26 level = "debug",
27 ret(level = "debug"),
28 skip_all,
29 fields(lev = %recursion_level)
30)]
31#[expect(clippy::too_many_arguments)]
32pub(super) async fn upgrade_outlier_to_timeline_pdu(
33 &self,
34 origin: &ServerName,
35 room_id: &RoomId,
36 incoming_pdu: PduEvent,
37 mut pdu_json: CanonicalJsonObject,
38 room_version: &RoomVersionId,
39 recursion_level: usize,
40 create_event_id: &EventId,
41) -> Result<Option<(RawPduId, bool)>> {
42 if let Ok(pdu_id) = self
44 .services
45 .timeline
46 .get_pdu_id(incoming_pdu.event_id())
47 .await
48 {
49 debug!(?pdu_id, "Exists.");
50 return Ok(Some((pdu_id, false)));
51 }
52
53 if self
54 .services
55 .pdu_metadata
56 .is_event_soft_failed(incoming_pdu.event_id())
57 .await
58 {
59 return Err!(Request(InvalidParam("Event has been soft failed")));
60 }
61
62 trace!("Upgrading to timeline pdu");
63
64 let timer = Instant::now();
65 let room_rules = room_version::rules(room_version)?;
66
67 trace!(format = ?room_rules.event_format, "Checking format");
68 check_rules(&pdu_json, &room_rules.event_format)?;
69
70 trace!("Resolving state at event");
74
75 let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
76 self.state_at_incoming_degree_one(&incoming_pdu)
77 .await?
78 } else {
79 self.state_at_incoming_resolved(&incoming_pdu, room_id, room_version)
80 .boxed()
81 .await?
82 };
83
84 if state_at_incoming_event.is_none() {
85 state_at_incoming_event = self
86 .fetch_state(
87 origin,
88 room_id,
89 incoming_pdu.event_id(),
90 room_version,
91 recursion_level,
92 create_event_id,
93 )
94 .boxed()
95 .await?;
96 }
97
98 let state_at_incoming_event =
99 state_at_incoming_event.expect("we always set this to some above");
100
101 let state_fetch = async |k: StateEventType, s: StateKey| {
104 let shortstatekey = self
105 .services
106 .short
107 .get_shortstatekey(&k, s.as_str())
108 .await?;
109
110 let event_id = state_at_incoming_event
111 .get(&shortstatekey)
112 .ok_or_else(|| {
113 err!(Request(NotFound(
114 "shortstatekey {shortstatekey:?} not found for ({k:?},{s:?})"
115 )))
116 })?;
117
118 self.services.timeline.get_pdu(event_id).await
119 };
120
121 let event_fetch = async |event_id: OwnedEventId| self.event_fetch(&event_id).await;
122
123 trace!("Performing auth check");
124 state_res::auth_check(&room_rules, &incoming_pdu, &event_fetch, &state_fetch).await?;
125
126 trace!("Gathering auth events");
127 let auth_events = self
128 .services
129 .state
130 .get_auth_events(
131 room_id,
132 incoming_pdu.kind(),
133 incoming_pdu.sender(),
134 incoming_pdu.state_key(),
135 incoming_pdu.content(),
136 &room_rules.authorization,
137 true,
138 )
139 .await?;
140
141 let state_fetch = async |k: StateEventType, s: StateKey| {
142 auth_events
143 .get(&k.with_state_key(s.as_str()))
144 .map(ToOwned::to_owned)
145 .ok_or_else(|| err!(Request(NotFound("state event not found"))))
146 };
147
148 trace!("Performing auth check");
149 state_res::auth_check(&room_rules, &incoming_pdu, &event_fetch, &state_fetch).await?;
150
151 trace!("Performing soft-fail check");
153 let soft_fail_redact = match incoming_pdu.redacts_id(&room_rules) {
154 | None => false,
155 | Some(redact_id) =>
156 !self
157 .services
158 .state_accessor
159 .user_can_redact(&redact_id, incoming_pdu.sender(), incoming_pdu.room_id(), true)
160 .await?,
161 };
162
163 let soft_fail = soft_fail_redact
165 || matches!(
166 self.verify_or_fetch_inbound_policy_signature(&mut pdu_json, &incoming_pdu)
167 .await,
168 PolicyCheck::Invalid,
169 );
170
171 trace!("Locking the room");
174 let state_lock = self.services.state.mutex.lock(room_id).await;
175
176 trace!("Calculating extremities");
179 let extremities: Vec<_> = self
180 .services
181 .state
182 .get_forward_extremities(room_id)
183 .map(ToOwned::to_owned)
184 .ready_filter(|event_id| {
185 !incoming_pdu
187 .prev_events()
188 .any(is_equal_to!(event_id))
189 })
190 .broad_filter_map(async |event_id| {
191 self.services
193 .pdu_metadata
194 .is_event_referenced(room_id, &event_id)
195 .await
196 .eq(&false)
197 .then_some(event_id)
198 })
199 .collect()
200 .await;
201
202 debug!(
203 retained = extremities.len(),
204 prev_events = incoming_pdu.prev_events().count(),
205 "Retained extremities checked against prev_events.",
206 );
207
208 trace!("Compressing state...");
209 let state_ids_compressed: Arc<CompressedState> = self
210 .services
211 .state_compressor
212 .compress_state_events(
213 state_at_incoming_event
214 .iter()
215 .map(|(ssk, eid)| (ssk, eid.borrow())),
216 )
217 .collect()
218 .map(Arc::new)
219 .await;
220
221 if incoming_pdu.state_key().is_some() {
222 let mut state_after = state_at_incoming_event.clone();
224 if let Some(state_key) = incoming_pdu.state_key() {
225 let event_id = incoming_pdu.event_id();
226 let event_type = incoming_pdu.kind();
227 let shortstatekey = self
228 .services
229 .short
230 .get_or_create_shortstatekey(&event_type.to_string().into(), state_key)
231 .await;
232
233 state_after.insert(shortstatekey, event_id.to_owned());
234 debug!(
236 ?event_id,
237 ?event_type,
238 ?state_key,
239 ?shortstatekey,
240 state_after = state_after.len(),
241 "Adding event to state."
242 );
243 }
244
245 trace!("Resolving new room state.");
246 let new_room_state = self
247 .resolve_state(room_id, room_version, state_after)
248 .boxed()
249 .await?;
250
251 trace!("Saving resolved state.");
253 let HashSetCompressStateEvent { shortstatehash, added, removed } = self
254 .services
255 .state_compressor
256 .save_state(room_id, new_room_state)
257 .await?;
258
259 debug!(
260 ?shortstatehash,
261 added = added.len(),
262 removed = removed.len(),
263 "Forcing new room state."
264 );
265 self.services
266 .state
267 .force_state(room_id, shortstatehash, added, removed, &state_lock)
268 .await?;
269 }
270
271 trace!("Appending pdu to timeline");
278
279 let incoming_extremity = once(incoming_pdu.event_id()).filter(|_| !soft_fail);
281
282 let extremities = extremities
283 .iter()
284 .map(Borrow::borrow)
285 .chain(incoming_extremity);
286
287 let pdu_id = self
288 .services
289 .timeline
290 .append_incoming_pdu(
291 &incoming_pdu,
292 pdu_json,
293 extremities,
294 state_ids_compressed,
295 soft_fail,
296 &state_lock,
297 )
298 .await?;
299
300 debug_assert!(
301 pdu_id.is_some() || soft_fail,
302 "Ok(None) returned by timeline for soft-failed PDU's"
303 );
304
305 if soft_fail {
306 self.services
307 .pdu_metadata
308 .mark_event_soft_failed(incoming_pdu.event_id());
309
310 drop(state_lock);
311 warn!(
312 elapsed = ?timer.elapsed(),
313 "Event was soft failed: {:?}",
314 incoming_pdu.event_id()
315 );
316
317 return Err!(Request(InvalidParam("Event has been soft failed")));
318 }
319
320 drop(state_lock);
321 debug_info!(
322 elapsed = ?timer.elapsed(),
323 "Accepted",
324 );
325
326 Ok(pdu_id.zip(Some(true)))
327}