Skip to main content

tuwunel_service/rooms/state_res/resolve/
mainline_sort.rs

1use futures::{
2	FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold,
3};
4use ruma::{EventId, OwnedEventId, events::TimelineEventType};
5use tuwunel_core::{
6	Error, Result, at, is_equal_to,
7	matrix::Event,
8	trace,
9	utils::stream::{BroadbandExt, IterStream, TryReadyExt},
10};
11
12/// Perform mainline ordering of the given events.
13///
14/// Definition in the spec:
15/// Given mainline positions calculated from P, the mainline ordering based on P
16/// of a set of events is the ordering, from smallest to largest, using the
17/// following comparison relation on events: for events x and y, x < y if
18///
19/// 1. the mainline position of x is greater than the mainline position of y
20///    (i.e. the auth chain of x is based on an earlier event in the mainline
21///    than y); or
22/// 2. the mainline positions of the events are the same, but x’s
23///    origin_server_ts is less than y’s origin_server_ts; or
24/// 3. the mainline positions of the events are the same and the events have the
25///    same origin_server_ts, but x’s event_id is less than y’s event_id.
26///
27/// ## Arguments
28///
29/// * `events` - The list of event IDs to sort.
30/// * `power_level` - The power level event in the current state.
31/// * `fetch_event` - Function to fetch an event in the room given its event ID.
32///
33/// ## Returns
34///
35/// Returns the sorted list of event IDs, or an `Err(_)` if one the event in the
36/// room has an unexpected format.
37#[tracing::instrument(
38	level = "debug",
39	skip_all,
40	fields(
41		power_levels = power_level_event_id
42			.as_deref()
43			.map(EventId::as_str)
44			.unwrap_or_default(),
45	)
46)]
47pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>(
48	power_level_event_id: Option<OwnedEventId>,
49	events: RemainingEvents,
50	fetch: &Fetch,
51) -> Result<Vec<OwnedEventId>>
52where
53	RemainingEvents: Stream<Item = &'a EventId> + Send,
54	Fetch: Fn(OwnedEventId) -> Fut + Sync,
55	Fut: Future<Output = Result<Pdu>> + Send,
56	Pdu: Event,
57{
58	// Populate the mainline of the power level.
59	let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| {
60		let Some(power_level_event_id) = power_level_event_id else {
61			return Ok::<_, Error>(None);
62		};
63
64		let power_level_event = fetch(power_level_event_id).await?;
65		let this_event_id = power_level_event.event_id().to_owned();
66		let next_event_id = get_power_levels_auth_event(&power_level_event, fetch)
67			.map_ok(|event| {
68				event
69					.as_ref()
70					.map(Event::event_id)
71					.map(ToOwned::to_owned)
72			})
73			.await?;
74
75		trace!(?this_event_id, ?next_event_id, "mainline descent",);
76
77		Ok(Some((this_event_id, next_event_id)))
78	})
79	.try_collect()
80	.await?;
81
82	let mainline = mainline.iter().rev().map(AsRef::as_ref);
83
84	events
85		.map(ToOwned::to_owned)
86		.broad_filter_map(async |event_id| {
87			let event = fetch(event_id.clone()).await.ok()?;
88			let origin_server_ts = event.origin_server_ts();
89			let position = mainline_position(Some(event), &mainline, fetch)
90				.await
91				.ok()?;
92
93			Some((event_id, (position, origin_server_ts)))
94		})
95		.inspect(|(event_id, (position, origin_server_ts))| {
96			trace!(position, ?origin_server_ts, ?event_id, "mainline position");
97		})
98		.collect()
99		.map(|mut vec: Vec<_>| {
100			vec.sort_by(|a, b| {
101				let (a_pos, a_ots) = &a.1;
102				let (b_pos, b_ots) = &b.1;
103				a_pos
104					.cmp(b_pos)
105					.then(a_ots.cmp(b_ots))
106					.then(a.cmp(b))
107			});
108
109			vec.into_iter().map(at!(0)).collect()
110		})
111		.map(Ok)
112		.await
113}
114
115/// Get the mainline position of the given event from the given mainline map.
116///
117/// ## Arguments
118///
119/// * `event` - The event to compute the mainline position of.
120/// * `mainline_map` - The mainline map of the m.room.power_levels event.
121/// * `fetch` - Function to fetch an event in the room given its event ID.
122///
123/// ## Returns
124///
125/// Returns the mainline position of the event, or an `Err(_)` if one of the
126/// events in the auth chain of the event was not found.
127#[tracing::instrument(
128	name = "position",
129	level = "trace",
130	ret(level = "trace"),
131	skip_all,
132	fields(
133		mainline = mainline.clone().count(),
134		event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned),
135	)
136)]
137async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>(
138	mut current_event: Option<Pdu>,
139	mainline: &Mainline,
140	fetch: &Fetch,
141) -> Result<usize>
142where
143	Mainline: Iterator<Item = &'a EventId> + Clone + Send + Sync,
144	Fetch: Fn(OwnedEventId) -> Fut + Sync,
145	Fut: Future<Output = Result<Pdu>> + Send,
146	Pdu: Event,
147{
148	while let Some(event) = current_event {
149		trace!(
150			event_id = ?event.event_id(),
151			"mainline position search",
152		);
153
154		// Real positions are 1..N (i + 1) so that 0 is free to mark
155		// "no power-levels in the auth chain". Without that, no-PL events
156		// would tie with events rooted at the oldest mainline PL.
157		if let Some(position) = mainline
158			.clone()
159			.position(is_equal_to!(event.event_id()))
160		{
161			return Ok(position.saturating_add(1));
162		}
163
164		// Look for the power levels event in the auth events.
165		current_event = get_power_levels_auth_event(&event, fetch).await?;
166	}
167
168	// No power-levels ancestor in the auth chain; sort before all
169	// chain-rooted events.
170	Ok(0)
171}
172
173#[expect(clippy::redundant_closure)]
174#[tracing::instrument(level = "trace", skip_all)]
175async fn get_power_levels_auth_event<Fetch, Fut, Pdu>(
176	event: &Pdu,
177	fetch: &Fetch,
178) -> Result<Option<Pdu>>
179where
180	Fetch: Fn(OwnedEventId) -> Fut + Sync,
181	Fut: Future<Output = Result<Pdu>> + Send,
182	Pdu: Event,
183{
184	let power_level_event = event
185		.auth_events()
186		.try_stream()
187		.map_ok(ToOwned::to_owned)
188		.and_then(|auth_event_id| fetch(auth_event_id))
189		.ready_try_skip_while(|auth_event| {
190			Ok(!auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, ""))
191		});
192
193	pin_mut!(power_level_event);
194	power_level_event.try_next().await
195}