Skip to main content

tuwunel_service/rooms/event_handler/
upgrade_outlier_pdu.rs

1use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc, time::Instant};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5	CanonicalJsonObject, EventId, OwnedEventId, RoomId, RoomVersionId, ServerName,
6	events::StateEventType, room_version_rules::RoomVersionRules,
7};
8use tuwunel_core::{
9	Err, Result, debug, debug_info, err, implement, is_equal_to,
10	matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
11	trace,
12	utils::stream::{BroadbandExt, ReadyExt},
13	warn,
14};
15
16use super::policy_server::PolicyCheck;
17use crate::rooms::{
18	state::RoomMutexGuard,
19	state_compressor::{CompressedState, HashSetCompressStateEvent},
20	state_res,
21	timeline::RawPduId,
22};
23
24#[implement(super::Service)]
25#[tracing::instrument(
26	name = "upgrade",
27	level = "debug",
28	ret(level = "debug"),
29	skip_all,
30	fields(lev = %recursion_level)
31)]
32#[expect(clippy::too_many_arguments)]
33pub(super) async fn upgrade_outlier_to_timeline_pdu(
34	&self,
35	origin: &ServerName,
36	room_id: &RoomId,
37	incoming_pdu: PduEvent,
38	mut pdu_json: CanonicalJsonObject,
39	room_version: &RoomVersionId,
40	recursion_level: usize,
41	create_event_id: &EventId,
42) -> Result<Option<(RawPduId, bool)>> {
43	// Skip the PDU if we already have it as a timeline event
44	if let Ok(pdu_id) = self
45		.services
46		.timeline
47		.get_pdu_id(incoming_pdu.event_id())
48		.await
49	{
50		debug!(?pdu_id, "Exists.");
51		return Ok(Some((pdu_id, false)));
52	}
53
54	if self
55		.services
56		.pdu_metadata
57		.is_event_soft_failed(incoming_pdu.event_id())
58		.await
59	{
60		return Err!(Request(InvalidParam("Event has been soft failed")));
61	}
62
63	trace!("Upgrading to timeline pdu");
64
65	let timer = Instant::now();
66	let room_rules = room_version::rules(room_version)?;
67
68	trace!(format = ?room_rules.event_format, "Checking format");
69	check_rules(&pdu_json, &room_rules.event_format)?;
70
71	let (state_at_incoming_event, via_fetch) = self
72		.resolve_state_at_incoming_event(
73			origin,
74			room_id,
75			&incoming_pdu,
76			room_version,
77			recursion_level,
78			create_event_id,
79		)
80		.await?;
81
82	self.auth_check_outlier_pdu(room_id, &incoming_pdu, &room_rules, &state_at_incoming_event)
83		.await?;
84
85	let soft_fail = self
86		.compute_soft_fail(&incoming_pdu, &room_rules, &mut pdu_json)
87		.await?;
88
89	// 13. Use state resolution to find new room state
90	// We start looking at current room state now, so lets lock the room
91	trace!("Locking the room");
92	let state_lock = self.services.state.mutex.lock(room_id).await;
93
94	let extremities = self
95		.compute_remaining_extremities(room_id, &incoming_pdu)
96		.await;
97
98	trace!("Compressing state...");
99	let state_ids_compressed: Arc<CompressedState> = self
100		.services
101		.state_compressor
102		.compress_state_events(
103			state_at_incoming_event
104				.iter()
105				.map(|(ssk, eid)| (ssk, eid.borrow())),
106		)
107		.collect()
108		.map(Arc::new)
109		.await;
110
111	if via_fetch && !soft_fail {
112		self.cache_resolved_state(room_id, incoming_pdu.event_id(), state_ids_compressed.clone())
113			.await;
114	}
115
116	if incoming_pdu.state_key().is_some() {
117		self.resolve_and_force_state_after(
118			room_id,
119			room_version,
120			&incoming_pdu,
121			&state_at_incoming_event,
122			&state_lock,
123		)
124		.boxed()
125		.await?;
126	}
127
128	// 14. Check if the event passes auth based on the "current state" of the room,
129	//     if not soft fail it
130	//
131	// Now that the event has passed all auth it is added into the timeline.
132	// We use the `state_at_event` instead of `state_after` so we accurately
133	// represent the state for this event.
134	trace!("Appending pdu to timeline");
135
136	// Incoming event will be referenced in prev_events unless soft-failed.
137	let incoming_extremity = once(incoming_pdu.event_id()).filter(|_| !soft_fail);
138
139	let extremities = extremities
140		.iter()
141		.map(Borrow::borrow)
142		.chain(incoming_extremity);
143
144	let pdu_id = self
145		.services
146		.timeline
147		.append_incoming_pdu(
148			&incoming_pdu,
149			pdu_json,
150			extremities,
151			state_ids_compressed,
152			soft_fail,
153			&state_lock,
154		)
155		.await?;
156
157	debug_assert!(
158		pdu_id.is_some() || soft_fail,
159		"Ok(None) returned by timeline for soft-failed PDU's"
160	);
161
162	if soft_fail {
163		self.services
164			.pdu_metadata
165			.mark_event_soft_failed(incoming_pdu.event_id());
166
167		drop(state_lock);
168		warn!(
169			elapsed = ?timer.elapsed(),
170			"Event was soft failed: {:?}",
171			incoming_pdu.event_id()
172		);
173
174		return Err!(Request(InvalidParam("Event has been soft failed")));
175	}
176
177	drop(state_lock);
178	debug_info!(
179		elapsed = ?timer.elapsed(),
180		"Accepted",
181	);
182
183	Ok(pdu_id.zip(Some(true)))
184}
185
186#[implement(super::Service)]
187async fn resolve_state_at_incoming_event(
188	&self,
189	origin: &ServerName,
190	room_id: &RoomId,
191	incoming_pdu: &PduEvent,
192	room_version: &RoomVersionId,
193	recursion_level: usize,
194	create_event_id: &EventId,
195) -> Result<(HashMap<u64, OwnedEventId>, bool)> {
196	// 10. Fetch missing state and auth chain events by calling /state_ids at
197	//     backwards extremities doing all the checks in this list starting at 1.
198	//     These are not timeline events.
199	trace!("Resolving state at event");
200
201	let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
202		self.state_at_incoming_degree_one(incoming_pdu)
203			.await?
204	} else {
205		self.state_at_incoming_resolved(incoming_pdu, room_id, room_version)
206			.boxed()
207			.await?
208	};
209
210	if state_at_incoming_event.is_none() {
211		state_at_incoming_event = self
212			.cached_resolved_state(incoming_pdu.event_id())
213			.await;
214	}
215
216	let (state_at_incoming_event, via_fetch) = match state_at_incoming_event {
217		| Some(state) => (state, false),
218		| None => {
219			let state = self
220				.fetch_state(
221					origin,
222					room_id,
223					incoming_pdu.event_id(),
224					room_version,
225					recursion_level,
226					create_event_id,
227				)
228				.boxed()
229				.await?
230				.expect("fetch_state always resolves state to some");
231
232			(state, true)
233		},
234	};
235
236	Ok((state_at_incoming_event, via_fetch))
237}
238
239#[implement(super::Service)]
240async fn auth_check_outlier_pdu(
241	&self,
242	room_id: &RoomId,
243	incoming_pdu: &PduEvent,
244	room_rules: &RoomVersionRules,
245	state_at_incoming_event: &HashMap<u64, OwnedEventId>,
246) -> Result {
247	// 11. Check the auth of the event passes based on the state of the event
248
249	let state_fetch = async |k: StateEventType, s: StateKey| {
250		let shortstatekey = self
251			.services
252			.short
253			.get_shortstatekey(&k, s.as_str())
254			.await?;
255
256		let event_id = state_at_incoming_event
257			.get(&shortstatekey)
258			.ok_or_else(|| {
259				err!(Request(NotFound(
260					"shortstatekey {shortstatekey:?} not found for ({k:?},{s:?})"
261				)))
262			})?;
263
264		self.services.timeline.get_pdu(event_id).await
265	};
266
267	let event_fetch = async |event_id: OwnedEventId| self.event_fetch(&event_id).await;
268
269	trace!("Performing auth check");
270	state_res::auth_check(room_rules, incoming_pdu, &event_fetch, &state_fetch).await?;
271
272	trace!("Gathering auth events");
273	let auth_events = self
274		.services
275		.state
276		.get_auth_events(
277			room_id,
278			incoming_pdu.kind(),
279			incoming_pdu.sender(),
280			incoming_pdu.state_key(),
281			incoming_pdu.content(),
282			&room_rules.authorization,
283			true,
284		)
285		.await?;
286
287	let state_fetch = async |k: StateEventType, s: StateKey| {
288		auth_events
289			.get(&k.with_state_key(s.as_str()))
290			.map(ToOwned::to_owned)
291			.ok_or_else(|| err!(Request(NotFound("state event not found"))))
292	};
293
294	trace!("Performing auth check");
295	state_res::auth_check(room_rules, incoming_pdu, &event_fetch, &state_fetch).await?;
296
297	Ok(())
298}
299
300#[implement(super::Service)]
301async fn compute_soft_fail(
302	&self,
303	incoming_pdu: &PduEvent,
304	room_rules: &RoomVersionRules,
305	pdu_json: &mut CanonicalJsonObject,
306) -> Result<bool> {
307	// Soft fail check before doing state res
308	trace!("Performing soft-fail check");
309	let soft_fail_redact = match incoming_pdu.redacts_id(room_rules) {
310		| None => false,
311		| Some(redact_id) =>
312			!self
313				.services
314				.state_accessor
315				.user_can_redact(&redact_id, incoming_pdu.sender(), incoming_pdu.room_id(), true)
316				.await?,
317	};
318
319	// MSC4284: soft-fail when the policy server rejects the event.
320	Ok(soft_fail_redact
321		|| matches!(
322			self.verify_or_fetch_inbound_policy_signature(pdu_json, incoming_pdu)
323				.await,
324			PolicyCheck::Invalid,
325		))
326}
327
328#[implement(super::Service)]
329async fn compute_remaining_extremities(
330	&self,
331	room_id: &RoomId,
332	incoming_pdu: &PduEvent,
333) -> Vec<OwnedEventId> {
334	// Now we calculate the set of extremities this room has after the incoming
335	// event has been applied. We start with the previous extremities (aka leaves)
336	trace!("Calculating extremities");
337	let extremities: Vec<_> = self
338		.services
339		.state
340		.get_forward_extremities(room_id)
341		.map(ToOwned::to_owned)
342		.ready_filter(|event_id| {
343			// Remove any that are referenced by this incoming event's prev_events
344			!incoming_pdu
345				.prev_events()
346				.any(is_equal_to!(event_id))
347		})
348		.broad_filter_map(async |event_id| {
349			// Only keep those extremities were not referenced yet
350			self.services
351				.pdu_metadata
352				.is_event_referenced(room_id, &event_id)
353				.await
354				.eq(&false)
355				.then_some(event_id)
356		})
357		.collect()
358		.await;
359
360	debug!(
361		retained = extremities.len(),
362		prev_events = incoming_pdu.prev_events().count(),
363		"Retained extremities checked against prev_events.",
364	);
365
366	extremities
367}
368
369#[implement(super::Service)]
370async fn resolve_and_force_state_after(
371	&self,
372	room_id: &RoomId,
373	room_version: &RoomVersionId,
374	incoming_pdu: &PduEvent,
375	state_at_incoming_event: &HashMap<u64, OwnedEventId>,
376	state_lock: &RoomMutexGuard,
377) -> Result {
378	// We also add state after incoming event to the fork states
379	let mut state_after = state_at_incoming_event.clone();
380	if let Some(state_key) = incoming_pdu.state_key() {
381		let event_id = incoming_pdu.event_id();
382		let event_type = incoming_pdu.kind();
383		let shortstatekey = self
384			.services
385			.short
386			.get_or_create_shortstatekey(&event_type.to_string().into(), state_key)
387			.await;
388
389		state_after.insert(shortstatekey, event_id.to_owned());
390		// Now it's the state after the event.
391		debug!(
392			?event_id,
393			?event_type,
394			?state_key,
395			?shortstatekey,
396			state_after = state_after.len(),
397			"Adding event to state."
398		);
399	}
400
401	trace!("Resolving new room state.");
402	let new_room_state = self
403		.resolve_state(room_id, room_version, state_after)
404		.boxed()
405		.await?;
406
407	// Set the new room state to the resolved state
408	trace!("Saving resolved state.");
409	let HashSetCompressStateEvent { shortstatehash, added, removed } = self
410		.services
411		.state_compressor
412		.save_state(room_id, new_room_state)
413		.await?;
414
415	debug!(
416		?shortstatehash,
417		added = added.len(),
418		removed = removed.len(),
419		"Forcing new room state."
420	);
421	self.services
422		.state
423		.force_state(room_id, shortstatehash, added, removed, state_lock)
424		.await?;
425
426	Ok(())
427}