Skip to main content

tuwunel_service/rooms/state_res/
topological_sort.rs

1//! Sorts the given event graph using reverse topological power ordering.
2//!
3//! Definition in the specification:
4//!
5//! The reverse topological power ordering of a set of events is the
6//! lexicographically smallest topological ordering based on the DAG formed by
7//! referenced events (prev or auth, determined by caller). The reverse
8//! topological power ordering is ordered from earliest event to latest. For
9//! comparing two equal topological orderings to determine which is the
10//! lexicographically smallest, the following comparison relation on events is
11//! used: for events x and y, x < y if
12//!
13//! 1. x’s sender has greater power level than y’s sender, when looking at their
14//!    respective referenced events; or
15//! 2. the senders have the same power level, but x’s origin_server_ts is less
16//!    than y’s origin_server_ts; or
17//! 3. the senders have the same power level and the events have the same
18//!    origin_server_ts, but x’s event_id is less than y’s event_id.
19//!
20//! The reverse topological power ordering can be found by sorting the events
21//! using Kahn’s algorithm for topological sorting, and at each step selecting,
22//! among all the candidate vertices, the smallest vertex using the above
23//! comparison relation.
24
25use std::{
26	cmp::{Ordering, Reverse},
27	collections::{BinaryHeap, HashMap},
28};
29
30use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
31use ruma::{
32	MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
33};
34use tuwunel_core::{
35	Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
36};
37
38pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
39type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
40
41#[derive(PartialEq, Eq)]
42struct TieBreaker {
43	event_id: OwnedEventId,
44	power_level: UserPowerLevel,
45	origin_server_ts: MilliSecondsSinceUnixEpoch,
46}
47
48// NOTE: the power level comparison is "backwards" intentionally.
49impl Ord for TieBreaker {
50	fn cmp(&self, other: &Self) -> Ordering {
51		other
52			.power_level
53			.cmp(&self.power_level)
54			.then(self.origin_server_ts.cmp(&other.origin_server_ts))
55			.then(self.event_id.cmp(&other.event_id))
56	}
57}
58
59impl PartialOrd for TieBreaker {
60	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
61}
62
63/// Sorts the given event graph using reverse topological power ordering.
64///
65/// ## Arguments
66///
67/// * `graph` - The graph to sort. A map of event ID to its referenced events
68///   that are in the full conflicted set.
69///
70/// * `query` - Function to obtain a (power level, origin_server_ts) of an event
71///   for breaking ties.
72///
73/// ## Returns
74///
75/// Returns the ordered list of event IDs from earliest to latest.
76///
77/// We consider that the DAG is directed from most recent events to oldest
78/// events, so an event is an incoming edge to its referenced events.
79/// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
80/// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
81/// of events that reference it in its referenced events.
82#[tracing::instrument(
83	level = "debug",
84	skip_all,
85	fields(
86		graph = graph.len(),
87	)
88)]
89#[expect(clippy::implicit_hasher)]
90pub async fn topological_sort<Query, Fut>(
91	graph: &HashMap<OwnedEventId, ReferencedIds>,
92	query: &Query,
93) -> Result<Vec<OwnedEventId>>
94where
95	Query: Fn(OwnedEventId) -> Fut + Sync,
96	Fut: Future<Output = Result<PduInfo>> + Send,
97{
98	let query = async |event_id: OwnedEventId| {
99		let (power_level, origin_server_ts) = query(event_id.clone()).await?;
100		Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
101	};
102
103	let max_edges = graph
104		.values()
105		.map(ReferencedIds::len)
106		.fold(graph.len(), |a, c| validated!(a + c));
107
108	let incoming = graph
109		.iter()
110		.flat_map(|(event_id, out)| {
111			out.iter()
112				.map(move |reference| (event_id, reference))
113		})
114		.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
115			let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
116
117			if !references.contains(event_id) {
118				references.push(event_id.clone());
119			}
120
121			incoming
122		});
123
124	let horizon = graph
125		.iter()
126		.filter(|(_, references)| references.is_empty())
127		.try_stream()
128		.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
129		.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
130		.await?;
131
132	kahn_sort(horizon, graph.clone(), &incoming, &query)
133		.try_collect()
134		.await
135}
136
137// Apply Kahn's algorithm.
138// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
139// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
140#[tracing::instrument(
141	level = "debug",
142	skip_all,
143	fields(
144		heap = %heap.len(),
145		graph = %graph.len(),
146	)
147)]
148fn kahn_sort<Query, Fut>(
149	heap: BinaryHeap<Reverse<TieBreaker>>,
150	graph: HashMap<OwnedEventId, ReferencedIds>,
151	incoming: &HashMap<OwnedEventId, ReferencedIds>,
152	query: &Query,
153) -> impl Stream<Item = Result<OwnedEventId>> + Send
154where
155	Query: Fn(OwnedEventId) -> Fut + Sync,
156	Fut: Future<Output = Result<TieBreaker>> + Send,
157{
158	try_unfold((heap, graph), move |(mut heap, graph)| async move {
159		let Some(Reverse(item)) = heap.pop() else {
160			return Ok(None);
161		};
162
163		let references = incoming.get(&item.event_id).cloned();
164		let state = (item.event_id, (heap, graph));
165		references
166			.into_iter()
167			.flatten()
168			.try_stream()
169			.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
170				let out = graph
171					.get_mut(&parent_id)
172					.expect("contains all parent_ids");
173
174				out.retain(is_not_equal_to!(&event_id));
175
176				// Push on the heap once all the outgoing edges have been removed.
177				if out.is_empty() {
178					heap.push(Reverse(query(parent_id.clone()).await?));
179				}
180
181				Ok::<_, Error>((event_id, (heap, graph)))
182			})
183			.map_ok(Some)
184			.await
185	})
186}