Skip to main content

tuwunel_service/rooms/state_res/
topological_sort.rs

1//! Sorts the given event graph using reverse topological power ordering.
2//!
3//! Definition in the specification:
4//!
5//! The reverse topological power ordering of a set of events is the
6//! lexicographically smallest topological ordering based on the DAG formed by
7//! referenced events (prev or auth, determined by caller). The reverse
8//! topological power ordering is ordered from earliest event to latest. For
9//! comparing two equal topological orderings to determine which is the
10//! lexicographically smallest, the following comparison relation on events is
11//! used: for events x and y, x < y if
12//!
13//! 1. x’s sender has greater power level than y’s sender, when looking at their
14//!    respective referenced events; or
15//! 2. the senders have the same power level, but x’s origin_server_ts is less
16//!    than y’s origin_server_ts; or
17//! 3. the senders have the same power level and the events have the same
18//!    origin_server_ts, but x’s event_id is less than y’s event_id.
19//!
20//! The reverse topological power ordering can be found by sorting the events
21//! using Kahn’s algorithm for topological sorting, and at each step selecting,
22//! among all the candidate vertices, the smallest vertex using the above
23//! comparison relation.
24
25use std::{
26	cmp::{Ordering, Reverse},
27	collections::{BinaryHeap, HashMap, HashSet},
28};
29
30use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
31use ruma::{
32	MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
33};
34use tuwunel_core::{
35	Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
36};
37
38pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
39type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
40
41#[derive(PartialEq, Eq)]
42struct TieBreaker {
43	event_id: OwnedEventId,
44	power_level: UserPowerLevel,
45	origin_server_ts: MilliSecondsSinceUnixEpoch,
46}
47
48// NOTE: the power level comparison is "backwards" intentionally.
49impl Ord for TieBreaker {
50	fn cmp(&self, other: &Self) -> Ordering {
51		other
52			.power_level
53			.cmp(&self.power_level)
54			.then(self.origin_server_ts.cmp(&other.origin_server_ts))
55			.then(self.event_id.cmp(&other.event_id))
56	}
57}
58
59impl PartialOrd for TieBreaker {
60	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
61}
62
63/// Sorts the given event graph using reverse topological power ordering.
64///
65/// ## Arguments
66///
67/// * `graph` - The graph to sort. A map of event ID to its referenced events
68///   that are in the full conflicted set.
69///
70/// * `query` - Function to obtain a (power level, origin_server_ts) of an event
71///   for breaking ties.
72///
73/// ## Returns
74///
75/// Returns the ordered list of event IDs from earliest to latest. Every event
76/// in the graph appears exactly once; a reference to an event absent from the
77/// graph is treated as a non-edge rather than dropping the referencing event.
78///
79/// We consider that the DAG is directed from most recent events to oldest
80/// events, so an event is an incoming edge to its referenced events.
81/// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
82/// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
83/// of events that reference it in its referenced events.
84#[tracing::instrument(
85	level = "debug",
86	skip_all,
87	fields(
88		graph = graph.len(),
89	)
90)]
91#[expect(clippy::implicit_hasher)]
92pub async fn topological_sort<Query, Fut>(
93	graph: HashMap<OwnedEventId, ReferencedIds>,
94	query: &Query,
95) -> Result<Vec<OwnedEventId>>
96where
97	Query: Fn(OwnedEventId) -> Fut + Sync,
98	Fut: Future<Output = Result<PduInfo>> + Send,
99{
100	let query = async |event_id: OwnedEventId| {
101		let (power_level, origin_server_ts) = query(event_id.clone()).await?;
102		Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
103	};
104
105	let max_edges = graph
106		.values()
107		.map(ReferencedIds::len)
108		.fold(graph.len(), |a, c| validated!(a + c));
109
110	let incoming = graph
111		.iter()
112		.flat_map(|(event_id, out)| {
113			out.iter()
114				.map(move |reference| (event_id, reference))
115		})
116		.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
117			let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
118
119			if !references.contains(event_id) {
120				references.push(event_id.clone());
121			}
122
123			incoming
124		});
125
126	// A reference absent from the graph is unresolvable and not an out-edge.
127	let horizon = graph
128		.iter()
129		.filter(|(_, references)| {
130			!references
131				.iter()
132				.any(|reference| graph.contains_key(reference))
133		})
134		.try_stream()
135		.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
136		.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
137		.await?;
138
139	kahn_sort(horizon, graph, &incoming, &query)
140		.try_collect()
141		.await
142}
143
144// Apply Kahn's algorithm.
145// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
146// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
147#[tracing::instrument(
148	level = "debug",
149	skip_all,
150	fields(
151		heap = %heap.len(),
152		graph = %graph.len(),
153	)
154)]
155fn kahn_sort<Query, Fut>(
156	heap: BinaryHeap<Reverse<TieBreaker>>,
157	graph: HashMap<OwnedEventId, ReferencedIds>,
158	incoming: &HashMap<OwnedEventId, ReferencedIds>,
159	query: &Query,
160) -> impl Stream<Item = Result<OwnedEventId>> + Send
161where
162	Query: Fn(OwnedEventId) -> Fut + Sync,
163	Fut: Future<Output = Result<TieBreaker>> + Send,
164{
165	try_unfold((heap, graph), move |(mut heap, graph)| async move {
166		let Some(Reverse(item)) = heap.pop() else {
167			return Ok(None);
168		};
169
170		let references = incoming.get(&item.event_id).cloned();
171		let state = (item.event_id, (heap, graph));
172		references
173			.into_iter()
174			.flatten()
175			.try_stream()
176			.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
177				graph
178					.get_mut(&parent_id)
179					.expect("contains all parent_ids")
180					.retain(is_not_equal_to!(&event_id));
181
182				// References to absent events never resolve; gate on present out-edges only.
183				if !graph[&parent_id]
184					.iter()
185					.any(|reference| graph.contains_key(reference))
186				{
187					heap.push(Reverse(query(parent_id.clone()).await?));
188				}
189
190				Ok::<_, Error>((event_id, (heap, graph)))
191			})
192			.map_ok(Some)
193			.await
194	})
195}
196
197/// Tests whether the events in `order` are in reverse topological power
198/// ordering with respect to `graph`: each event appears after every event it
199/// references that is present in `graph`.
200///
201/// A reference absent from `graph` is a non-edge, as in [`topological_sort`].
202/// The check covers relative order only: it certifies that the events present
203/// in `order` are mutually consistent, not that `order` is complete or
204/// duplicate-free, and (lacking the tie-breaker) not that it equals the exact
205/// sequence [`topological_sort`] would select.
206#[expect(clippy::implicit_hasher)]
207pub fn is_topologically_sorted<'a, Order>(
208	order: Order,
209	graph: &HashMap<OwnedEventId, ReferencedIds>,
210) -> bool
211where
212	Order: IntoIterator<Item = &'a OwnedEventId>,
213{
214	order
215		.into_iter()
216		.try_fold(HashSet::with_capacity(graph.len()), |mut seen, event_id| {
217			let satisfied = graph
218				.get(event_id)
219				.into_iter()
220				.flatten()
221				.filter(|reference| graph.contains_key(*reference))
222				.all(|reference| seen.contains(reference));
223
224			seen.insert(event_id);
225			satisfied.then_some(seen)
226		})
227		.is_some()
228}
229
/// Whether `items` are already in reverse topological order, reading each
/// item's references in place instead of from a prebuilt graph. An item must
/// follow every occurrence of each item it references that is also present
/// among `items`; a reference to an absent item is a non-edge. `id` reads an
/// item's identifier, `references` its outgoing references.
///
/// An item that references itself, or whose referenced identifier also occurs
/// at or after its own position, makes the sequence unsorted. This holds for
/// every length, including a lone self-referencing item. An empty slice is
/// sorted.
///
/// This allocates nothing, scanning the remaining items for each reference
/// rather than building a seen-set. The scan is quadratic, so it suits short
/// sequences; [`is_topologically_sorted`] is the better pick for a large
/// sequence or one whose graph is already built.
pub fn is_topologically_sorted_in_place<'a, T, Id, Refs, Ref>(
	items: &'a [T],
	id: Id,
	references: Refs,
) -> bool
where
	Id: Fn(&'a T) -> &'a str,
	Refs: Fn(&'a T) -> Ref,
	Ref: Iterator<Item = &'a str>,
{
	// No length shortcut: a single item referencing itself must be rejected
	// just as it is in a longer sequence. `all` over an empty slice is `true`.
	items.iter().enumerate().all(|(position, item)| {
		// The suffix starting at `position` includes the item itself, so a
		// self-reference is caught together with references to later items.
		let pending = &items[position..];

		references(item).all(|reference| pending.iter().all(|other| id(other) != reference))
	})
}
262
#[cfg(test)]
mod tests {
	use super::is_topologically_sorted_in_place;

	/// Runs the in-place check over `(id, references)` pairs.
	fn in_order(items: &[(&str, &[&str])]) -> bool {
		is_topologically_sorted_in_place(
			items,
			|&(id, _): &(&str, &[&str])| id,
			|&(_, references): &(&str, &[&str])| references.iter().copied(),
		)
	}

	/// Trivial sequences need no ordering.
	#[test]
	fn empty_or_single_is_sorted() {
		assert!(in_order(&[]));
		assert!(in_order(&[("a", &[])]));
	}

	/// Every item follows all of the items it references.
	#[test]
	fn parents_before_children() {
		assert!(in_order(&[("a", &[]), ("b", &["a"]), ("c", &["a", "b"])]));
	}

	/// A referencing item placed ahead of its reference is rejected.
	#[test]
	fn child_before_parent_is_unsorted() {
		assert!(!in_order(&[("b", &["a"]), ("a", &[])]));
	}

	/// References to items outside the sequence impose no constraint.
	#[test]
	fn absent_reference_is_a_non_edge() {
		assert!(in_order(&[("b", &["x"]), ("a", &["x"])]));
	}

	/// An item can never come after itself.
	#[test]
	fn self_reference_is_unsorted() {
		assert!(!in_order(&[("a", &["a"]), ("b", &[])]));
	}

	/// A repeat of a referenced item after its referrer breaks the order.
	#[test]
	fn later_duplicate_of_a_parent_is_unsorted() {
		assert!(!in_order(&[("a", &[]), ("b", &["a"]), ("a", &[])]));
	}
}