1use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc, time::Instant};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5 CanonicalJsonObject, EventId, OwnedEventId, RoomId, RoomVersionId, ServerName,
6 events::StateEventType, room_version_rules::RoomVersionRules,
7};
8use tuwunel_core::{
9 Err, Result, debug, debug_info, err, implement, is_equal_to,
10 matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
11 trace,
12 utils::stream::{BroadbandExt, ReadyExt},
13 warn,
14};
15
16use super::policy_server::PolicyCheck;
17use crate::rooms::{
18 state::RoomMutexGuard,
19 state_compressor::{CompressedState, HashSetCompressStateEvent},
20 state_res,
21 timeline::RawPduId,
22};
23
24#[implement(super::Service)]
25#[tracing::instrument(
26 name = "upgrade",
27 level = "debug",
28 ret(level = "debug"),
29 skip_all,
30 fields(lev = %recursion_level)
31)]
32#[expect(clippy::too_many_arguments)]
33pub(super) async fn upgrade_outlier_to_timeline_pdu(
34 &self,
35 origin: &ServerName,
36 room_id: &RoomId,
37 incoming_pdu: PduEvent,
38 mut pdu_json: CanonicalJsonObject,
39 room_version: &RoomVersionId,
40 recursion_level: usize,
41 create_event_id: &EventId,
42) -> Result<Option<(RawPduId, bool)>> {
43 if let Ok(pdu_id) = self
45 .services
46 .timeline
47 .get_pdu_id(incoming_pdu.event_id())
48 .await
49 {
50 debug!(?pdu_id, "Exists.");
51 return Ok(Some((pdu_id, false)));
52 }
53
54 if self
55 .services
56 .pdu_metadata
57 .is_event_soft_failed(incoming_pdu.event_id())
58 .await
59 {
60 return Err!(Request(InvalidParam("Event has been soft failed")));
61 }
62
63 trace!("Upgrading to timeline pdu");
64
65 let timer = Instant::now();
66 let room_rules = room_version::rules(room_version)?;
67
68 trace!(format = ?room_rules.event_format, "Checking format");
69 check_rules(&pdu_json, &room_rules.event_format)?;
70
71 let (state_at_incoming_event, via_fetch) = self
72 .resolve_state_at_incoming_event(
73 origin,
74 room_id,
75 &incoming_pdu,
76 room_version,
77 recursion_level,
78 create_event_id,
79 )
80 .await?;
81
82 self.auth_check_outlier_pdu(room_id, &incoming_pdu, &room_rules, &state_at_incoming_event)
83 .await?;
84
85 let soft_fail = self
86 .compute_soft_fail(&incoming_pdu, &room_rules, &mut pdu_json)
87 .await?;
88
89 trace!("Locking the room");
92 let state_lock = self.services.state.mutex.lock(room_id).await;
93
94 let extremities = self
95 .compute_remaining_extremities(room_id, &incoming_pdu)
96 .await;
97
98 trace!("Compressing state...");
99 let state_ids_compressed: Arc<CompressedState> = self
100 .services
101 .state_compressor
102 .compress_state_events(
103 state_at_incoming_event
104 .iter()
105 .map(|(ssk, eid)| (ssk, eid.borrow())),
106 )
107 .collect()
108 .map(Arc::new)
109 .await;
110
111 if via_fetch && !soft_fail {
112 self.cache_resolved_state(room_id, incoming_pdu.event_id(), state_ids_compressed.clone())
113 .await;
114 }
115
116 if incoming_pdu.state_key().is_some() {
117 self.resolve_and_force_state_after(
118 room_id,
119 room_version,
120 &incoming_pdu,
121 &state_at_incoming_event,
122 &state_lock,
123 )
124 .boxed()
125 .await?;
126 }
127
128 trace!("Appending pdu to timeline");
135
136 let incoming_extremity = once(incoming_pdu.event_id()).filter(|_| !soft_fail);
138
139 let extremities = extremities
140 .iter()
141 .map(Borrow::borrow)
142 .chain(incoming_extremity);
143
144 let pdu_id = self
145 .services
146 .timeline
147 .append_incoming_pdu(
148 &incoming_pdu,
149 pdu_json,
150 extremities,
151 state_ids_compressed,
152 soft_fail,
153 &state_lock,
154 )
155 .await?;
156
157 debug_assert!(
158 pdu_id.is_some() || soft_fail,
159 "Ok(None) returned by timeline for soft-failed PDU's"
160 );
161
162 if soft_fail {
163 self.services
164 .pdu_metadata
165 .mark_event_soft_failed(incoming_pdu.event_id());
166
167 drop(state_lock);
168 warn!(
169 elapsed = ?timer.elapsed(),
170 "Event was soft failed: {:?}",
171 incoming_pdu.event_id()
172 );
173
174 return Err!(Request(InvalidParam("Event has been soft failed")));
175 }
176
177 drop(state_lock);
178 debug_info!(
179 elapsed = ?timer.elapsed(),
180 "Accepted",
181 );
182
183 Ok(pdu_id.zip(Some(true)))
184}
185
186#[implement(super::Service)]
187async fn resolve_state_at_incoming_event(
188 &self,
189 origin: &ServerName,
190 room_id: &RoomId,
191 incoming_pdu: &PduEvent,
192 room_version: &RoomVersionId,
193 recursion_level: usize,
194 create_event_id: &EventId,
195) -> Result<(HashMap<u64, OwnedEventId>, bool)> {
196 trace!("Resolving state at event");
200
201 let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
202 self.state_at_incoming_degree_one(incoming_pdu)
203 .await?
204 } else {
205 self.state_at_incoming_resolved(incoming_pdu, room_id, room_version)
206 .boxed()
207 .await?
208 };
209
210 if state_at_incoming_event.is_none() {
211 state_at_incoming_event = self
212 .cached_resolved_state(incoming_pdu.event_id())
213 .await;
214 }
215
216 let (state_at_incoming_event, via_fetch) = match state_at_incoming_event {
217 | Some(state) => (state, false),
218 | None => {
219 let state = self
220 .fetch_state(
221 origin,
222 room_id,
223 incoming_pdu.event_id(),
224 room_version,
225 recursion_level,
226 create_event_id,
227 )
228 .boxed()
229 .await?
230 .expect("fetch_state always resolves state to some");
231
232 (state, true)
233 },
234 };
235
236 Ok((state_at_incoming_event, via_fetch))
237}
238
239#[implement(super::Service)]
240async fn auth_check_outlier_pdu(
241 &self,
242 room_id: &RoomId,
243 incoming_pdu: &PduEvent,
244 room_rules: &RoomVersionRules,
245 state_at_incoming_event: &HashMap<u64, OwnedEventId>,
246) -> Result {
247 let state_fetch = async |k: StateEventType, s: StateKey| {
250 let shortstatekey = self
251 .services
252 .short
253 .get_shortstatekey(&k, s.as_str())
254 .await?;
255
256 let event_id = state_at_incoming_event
257 .get(&shortstatekey)
258 .ok_or_else(|| {
259 err!(Request(NotFound(
260 "shortstatekey {shortstatekey:?} not found for ({k:?},{s:?})"
261 )))
262 })?;
263
264 self.services.timeline.get_pdu(event_id).await
265 };
266
267 let event_fetch = async |event_id: OwnedEventId| self.event_fetch(&event_id).await;
268
269 trace!("Performing auth check");
270 state_res::auth_check(room_rules, incoming_pdu, &event_fetch, &state_fetch).await?;
271
272 trace!("Gathering auth events");
273 let auth_events = self
274 .services
275 .state
276 .get_auth_events(
277 room_id,
278 incoming_pdu.kind(),
279 incoming_pdu.sender(),
280 incoming_pdu.state_key(),
281 incoming_pdu.content(),
282 &room_rules.authorization,
283 true,
284 )
285 .await?;
286
287 let state_fetch = async |k: StateEventType, s: StateKey| {
288 auth_events
289 .get(&k.with_state_key(s.as_str()))
290 .map(ToOwned::to_owned)
291 .ok_or_else(|| err!(Request(NotFound("state event not found"))))
292 };
293
294 trace!("Performing auth check");
295 state_res::auth_check(room_rules, incoming_pdu, &event_fetch, &state_fetch).await?;
296
297 Ok(())
298}
299
300#[implement(super::Service)]
301async fn compute_soft_fail(
302 &self,
303 incoming_pdu: &PduEvent,
304 room_rules: &RoomVersionRules,
305 pdu_json: &mut CanonicalJsonObject,
306) -> Result<bool> {
307 trace!("Performing soft-fail check");
309 let soft_fail_redact = match incoming_pdu.redacts_id(room_rules) {
310 | None => false,
311 | Some(redact_id) =>
312 !self
313 .services
314 .state_accessor
315 .user_can_redact(&redact_id, incoming_pdu.sender(), incoming_pdu.room_id(), true)
316 .await?,
317 };
318
319 Ok(soft_fail_redact
321 || matches!(
322 self.verify_or_fetch_inbound_policy_signature(pdu_json, incoming_pdu)
323 .await,
324 PolicyCheck::Invalid,
325 ))
326}
327
328#[implement(super::Service)]
329async fn compute_remaining_extremities(
330 &self,
331 room_id: &RoomId,
332 incoming_pdu: &PduEvent,
333) -> Vec<OwnedEventId> {
334 trace!("Calculating extremities");
337 let extremities: Vec<_> = self
338 .services
339 .state
340 .get_forward_extremities(room_id)
341 .map(ToOwned::to_owned)
342 .ready_filter(|event_id| {
343 !incoming_pdu
345 .prev_events()
346 .any(is_equal_to!(event_id))
347 })
348 .broad_filter_map(async |event_id| {
349 self.services
351 .pdu_metadata
352 .is_event_referenced(room_id, &event_id)
353 .await
354 .eq(&false)
355 .then_some(event_id)
356 })
357 .collect()
358 .await;
359
360 debug!(
361 retained = extremities.len(),
362 prev_events = incoming_pdu.prev_events().count(),
363 "Retained extremities checked against prev_events.",
364 );
365
366 extremities
367}
368
369#[implement(super::Service)]
370async fn resolve_and_force_state_after(
371 &self,
372 room_id: &RoomId,
373 room_version: &RoomVersionId,
374 incoming_pdu: &PduEvent,
375 state_at_incoming_event: &HashMap<u64, OwnedEventId>,
376 state_lock: &RoomMutexGuard,
377) -> Result {
378 let mut state_after = state_at_incoming_event.clone();
380 if let Some(state_key) = incoming_pdu.state_key() {
381 let event_id = incoming_pdu.event_id();
382 let event_type = incoming_pdu.kind();
383 let shortstatekey = self
384 .services
385 .short
386 .get_or_create_shortstatekey(&event_type.to_string().into(), state_key)
387 .await;
388
389 state_after.insert(shortstatekey, event_id.to_owned());
390 debug!(
392 ?event_id,
393 ?event_type,
394 ?state_key,
395 ?shortstatekey,
396 state_after = state_after.len(),
397 "Adding event to state."
398 );
399 }
400
401 trace!("Resolving new room state.");
402 let new_room_state = self
403 .resolve_state(room_id, room_version, state_after)
404 .boxed()
405 .await?;
406
407 trace!("Saving resolved state.");
409 let HashSetCompressStateEvent { shortstatehash, added, removed } = self
410 .services
411 .state_compressor
412 .save_state(room_id, new_room_state)
413 .await?;
414
415 debug!(
416 ?shortstatehash,
417 added = added.len(),
418 removed = removed.len(),
419 "Forcing new room state."
420 );
421 self.services
422 .state
423 .force_state(room_id, shortstatehash, added, removed, state_lock)
424 .await?;
425
426 Ok(())
427}