Skip to main content

tuwunel_service/rooms/event_handler/
upgrade_outlier_pdu.rs

1use std::{borrow::Borrow, iter::once, sync::Arc, time::Instant};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5	CanonicalJsonObject, EventId, OwnedEventId, RoomId, RoomVersionId, ServerName,
6	events::StateEventType,
7};
8use tuwunel_core::{
9	Err, Result, debug, debug_info, err, implement, is_equal_to,
10	matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
11	trace,
12	utils::stream::{BroadbandExt, ReadyExt},
13	warn,
14};
15
16use super::policy_server::PolicyCheck;
17use crate::rooms::{
18	state_compressor::{CompressedState, HashSetCompressStateEvent},
19	state_res,
20	timeline::RawPduId,
21};
22
23#[implement(super::Service)]
24#[tracing::instrument(
25	name = "upgrade",
26	level = "debug",
27	ret(level = "debug"),
28	skip_all,
29	fields(lev = %recursion_level)
30)]
31#[expect(clippy::too_many_arguments)]
32pub(super) async fn upgrade_outlier_to_timeline_pdu(
33	&self,
34	origin: &ServerName,
35	room_id: &RoomId,
36	incoming_pdu: PduEvent,
37	mut pdu_json: CanonicalJsonObject,
38	room_version: &RoomVersionId,
39	recursion_level: usize,
40	create_event_id: &EventId,
41) -> Result<Option<(RawPduId, bool)>> {
42	// Skip the PDU if we already have it as a timeline event
43	if let Ok(pdu_id) = self
44		.services
45		.timeline
46		.get_pdu_id(incoming_pdu.event_id())
47		.await
48	{
49		debug!(?pdu_id, "Exists.");
50		return Ok(Some((pdu_id, false)));
51	}
52
53	if self
54		.services
55		.pdu_metadata
56		.is_event_soft_failed(incoming_pdu.event_id())
57		.await
58	{
59		return Err!(Request(InvalidParam("Event has been soft failed")));
60	}
61
62	trace!("Upgrading to timeline pdu");
63
64	let timer = Instant::now();
65	let room_rules = room_version::rules(room_version)?;
66
67	trace!(format = ?room_rules.event_format, "Checking format");
68	check_rules(&pdu_json, &room_rules.event_format)?;
69
70	// 10. Fetch missing state and auth chain events by calling /state_ids at
71	//     backwards extremities doing all the checks in this list starting at 1.
72	//     These are not timeline events.
73	trace!("Resolving state at event");
74
75	let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
76		self.state_at_incoming_degree_one(&incoming_pdu)
77			.await?
78	} else {
79		self.state_at_incoming_resolved(&incoming_pdu, room_id, room_version)
80			.boxed()
81			.await?
82	};
83
84	if state_at_incoming_event.is_none() {
85		state_at_incoming_event = self
86			.fetch_state(
87				origin,
88				room_id,
89				incoming_pdu.event_id(),
90				room_version,
91				recursion_level,
92				create_event_id,
93			)
94			.boxed()
95			.await?;
96	}
97
98	let state_at_incoming_event =
99		state_at_incoming_event.expect("we always set this to some above");
100
101	// 11. Check the auth of the event passes based on the state of the event
102
103	let state_fetch = async |k: StateEventType, s: StateKey| {
104		let shortstatekey = self
105			.services
106			.short
107			.get_shortstatekey(&k, s.as_str())
108			.await?;
109
110		let event_id = state_at_incoming_event
111			.get(&shortstatekey)
112			.ok_or_else(|| {
113				err!(Request(NotFound(
114					"shortstatekey {shortstatekey:?} not found for ({k:?},{s:?})"
115				)))
116			})?;
117
118		self.services.timeline.get_pdu(event_id).await
119	};
120
121	let event_fetch = async |event_id: OwnedEventId| self.event_fetch(&event_id).await;
122
123	trace!("Performing auth check");
124	state_res::auth_check(&room_rules, &incoming_pdu, &event_fetch, &state_fetch).await?;
125
126	trace!("Gathering auth events");
127	let auth_events = self
128		.services
129		.state
130		.get_auth_events(
131			room_id,
132			incoming_pdu.kind(),
133			incoming_pdu.sender(),
134			incoming_pdu.state_key(),
135			incoming_pdu.content(),
136			&room_rules.authorization,
137			true,
138		)
139		.await?;
140
141	let state_fetch = async |k: StateEventType, s: StateKey| {
142		auth_events
143			.get(&k.with_state_key(s.as_str()))
144			.map(ToOwned::to_owned)
145			.ok_or_else(|| err!(Request(NotFound("state event not found"))))
146	};
147
148	trace!("Performing auth check");
149	state_res::auth_check(&room_rules, &incoming_pdu, &event_fetch, &state_fetch).await?;
150
151	// Soft fail check before doing state res
152	trace!("Performing soft-fail check");
153	let soft_fail_redact = match incoming_pdu.redacts_id(&room_rules) {
154		| None => false,
155		| Some(redact_id) =>
156			!self
157				.services
158				.state_accessor
159				.user_can_redact(&redact_id, incoming_pdu.sender(), incoming_pdu.room_id(), true)
160				.await?,
161	};
162
163	// MSC4284: soft-fail when the policy server rejects the event.
164	let soft_fail = soft_fail_redact
165		|| matches!(
166			self.verify_or_fetch_inbound_policy_signature(&mut pdu_json, &incoming_pdu)
167				.await,
168			PolicyCheck::Invalid,
169		);
170
171	// 13. Use state resolution to find new room state
172	// We start looking at current room state now, so lets lock the room
173	trace!("Locking the room");
174	let state_lock = self.services.state.mutex.lock(room_id).await;
175
176	// Now we calculate the set of extremities this room has after the incoming
177	// event has been applied. We start with the previous extremities (aka leaves)
178	trace!("Calculating extremities");
179	let extremities: Vec<_> = self
180		.services
181		.state
182		.get_forward_extremities(room_id)
183		.map(ToOwned::to_owned)
184		.ready_filter(|event_id| {
185			// Remove any that are referenced by this incoming event's prev_events
186			!incoming_pdu
187				.prev_events()
188				.any(is_equal_to!(event_id))
189		})
190		.broad_filter_map(async |event_id| {
191			// Only keep those extremities were not referenced yet
192			self.services
193				.pdu_metadata
194				.is_event_referenced(room_id, &event_id)
195				.await
196				.eq(&false)
197				.then_some(event_id)
198		})
199		.collect()
200		.await;
201
202	debug!(
203		retained = extremities.len(),
204		prev_events = incoming_pdu.prev_events().count(),
205		"Retained extremities checked against prev_events.",
206	);
207
208	trace!("Compressing state...");
209	let state_ids_compressed: Arc<CompressedState> = self
210		.services
211		.state_compressor
212		.compress_state_events(
213			state_at_incoming_event
214				.iter()
215				.map(|(ssk, eid)| (ssk, eid.borrow())),
216		)
217		.collect()
218		.map(Arc::new)
219		.await;
220
221	if incoming_pdu.state_key().is_some() {
222		// We also add state after incoming event to the fork states
223		let mut state_after = state_at_incoming_event.clone();
224		if let Some(state_key) = incoming_pdu.state_key() {
225			let event_id = incoming_pdu.event_id();
226			let event_type = incoming_pdu.kind();
227			let shortstatekey = self
228				.services
229				.short
230				.get_or_create_shortstatekey(&event_type.to_string().into(), state_key)
231				.await;
232
233			state_after.insert(shortstatekey, event_id.to_owned());
234			// Now it's the state after the event.
235			debug!(
236				?event_id,
237				?event_type,
238				?state_key,
239				?shortstatekey,
240				state_after = state_after.len(),
241				"Adding event to state."
242			);
243		}
244
245		trace!("Resolving new room state.");
246		let new_room_state = self
247			.resolve_state(room_id, room_version, state_after)
248			.boxed()
249			.await?;
250
251		// Set the new room state to the resolved state
252		trace!("Saving resolved state.");
253		let HashSetCompressStateEvent { shortstatehash, added, removed } = self
254			.services
255			.state_compressor
256			.save_state(room_id, new_room_state)
257			.await?;
258
259		debug!(
260			?shortstatehash,
261			added = added.len(),
262			removed = removed.len(),
263			"Forcing new room state."
264		);
265		self.services
266			.state
267			.force_state(room_id, shortstatehash, added, removed, &state_lock)
268			.await?;
269	}
270
271	// 14. Check if the event passes auth based on the "current state" of the room,
272	//     if not soft fail it
273	//
274	// Now that the event has passed all auth it is added into the timeline.
275	// We use the `state_at_event` instead of `state_after` so we accurately
276	// represent the state for this event.
277	trace!("Appending pdu to timeline");
278
279	// Incoming event will be referenced in prev_events unless soft-failed.
280	let incoming_extremity = once(incoming_pdu.event_id()).filter(|_| !soft_fail);
281
282	let extremities = extremities
283		.iter()
284		.map(Borrow::borrow)
285		.chain(incoming_extremity);
286
287	let pdu_id = self
288		.services
289		.timeline
290		.append_incoming_pdu(
291			&incoming_pdu,
292			pdu_json,
293			extremities,
294			state_ids_compressed,
295			soft_fail,
296			&state_lock,
297		)
298		.await?;
299
300	debug_assert!(
301		pdu_id.is_some() || soft_fail,
302		"Ok(None) returned by timeline for soft-failed PDU's"
303	);
304
305	if soft_fail {
306		self.services
307			.pdu_metadata
308			.mark_event_soft_failed(incoming_pdu.event_id());
309
310		drop(state_lock);
311		warn!(
312			elapsed = ?timer.elapsed(),
313			"Event was soft failed: {:?}",
314			incoming_pdu.event_id()
315		);
316
317		return Err!(Request(InvalidParam("Event has been soft failed")));
318	}
319
320	drop(state_lock);
321	debug_info!(
322		elapsed = ?timer.elapsed(),
323		"Accepted",
324	);
325
326	Ok(pdu_id.zip(Some(true)))
327}