tuwunel_service/rooms/state_res/
topological_sort.rs1use std::{
26 cmp::{Ordering, Reverse},
27 collections::{BinaryHeap, HashMap, HashSet},
28};
29
30use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
31use ruma::{
32 MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
33};
34use tuwunel_core::{
35 Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
36};
37
/// List of event ids attached to one event: the values of both the forward
/// graph (ids an event references) and the reverse index (ids referencing an
/// event). Backed by a `SmallVec` with an inline capacity of three ids.
pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
/// `(power level, origin_server_ts)` pair produced by the caller-supplied
/// query of `topological_sort`; combined with the event id it forms a
/// `TieBreaker`.
type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
40
/// Sort key of a single event, used to choose between events that are ready
/// for output at the same time. The precedence of the fields is defined by
/// the `Ord` impl, not by their declaration order.
#[derive(PartialEq, Eq)]
struct TieBreaker {
    /// Id of the event this key describes; compared last.
    event_id: OwnedEventId,
    /// Power level reported by the query for this event; compared first, with
    /// the operands swapped.
    power_level: UserPowerLevel,
    /// Timestamp reported by the query for this event; compared second.
    origin_server_ts: MilliSecondsSinceUnixEpoch,
}
47
48impl Ord for TieBreaker {
50 fn cmp(&self, other: &Self) -> Ordering {
51 other
52 .power_level
53 .cmp(&self.power_level)
54 .then(self.origin_server_ts.cmp(&other.origin_server_ts))
55 .then(self.event_id.cmp(&other.event_id))
56 }
57}
58
impl PartialOrd for TieBreaker {
    /// Delegates to `Ord::cmp`, so the partial order always agrees with the
    /// total order and never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
62
63#[tracing::instrument(
85 level = "debug",
86 skip_all,
87 fields(
88 graph = graph.len(),
89 )
90)]
91#[expect(clippy::implicit_hasher)]
92pub async fn topological_sort<Query, Fut>(
93 graph: HashMap<OwnedEventId, ReferencedIds>,
94 query: &Query,
95) -> Result<Vec<OwnedEventId>>
96where
97 Query: Fn(OwnedEventId) -> Fut + Sync,
98 Fut: Future<Output = Result<PduInfo>> + Send,
99{
100 let query = async |event_id: OwnedEventId| {
101 let (power_level, origin_server_ts) = query(event_id.clone()).await?;
102 Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
103 };
104
105 let max_edges = graph
106 .values()
107 .map(ReferencedIds::len)
108 .fold(graph.len(), |a, c| validated!(a + c));
109
110 let incoming = graph
111 .iter()
112 .flat_map(|(event_id, out)| {
113 out.iter()
114 .map(move |reference| (event_id, reference))
115 })
116 .fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
117 let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
118
119 if !references.contains(event_id) {
120 references.push(event_id.clone());
121 }
122
123 incoming
124 });
125
126 let horizon = graph
128 .iter()
129 .filter(|(_, references)| {
130 !references
131 .iter()
132 .any(|reference| graph.contains_key(reference))
133 })
134 .try_stream()
135 .and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
136 .try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
137 .await?;
138
139 kahn_sort(horizon, graph, &incoming, &query)
140 .try_collect()
141 .await
142}
143
/// Streams event ids in topological order using Kahn's algorithm, with a
/// priority queue choosing which ready event is emitted next.
///
/// * `heap`: the events that are ready at the start. The `Reverse` wrapper
///   makes the `BinaryHeap` pop the smallest `TieBreaker` first.
/// * `graph`: event id -> ids it references. It is edited while the stream
///   runs: an emitted id is removed from the reference list of every event
///   recorded in `incoming` as referencing it.
/// * `incoming`: reverse edges, event id -> ids of the events referencing it.
/// * `query`: builds the `TieBreaker` of an event that has just become ready.
///
/// Every item is the next event id, or the error returned by `query`. The
/// stream ends when the heap is empty.
///
/// # Panics
///
/// Panics if `incoming` lists, for an emitted event, a referrer that is not
/// a key of `graph`.
#[tracing::instrument(
    level = "debug",
    skip_all,
    fields(
        heap = %heap.len(),
        graph = %graph.len(),
    )
)]
fn kahn_sort<Query, Fut>(
    heap: BinaryHeap<Reverse<TieBreaker>>,
    graph: HashMap<OwnedEventId, ReferencedIds>,
    incoming: &HashMap<OwnedEventId, ReferencedIds>,
    query: &Query,
) -> impl Stream<Item = Result<OwnedEventId>> + Send
where
    Query: Fn(OwnedEventId) -> Fut + Sync,
    Fut: Future<Output = Result<TieBreaker>> + Send,
{
    // The unfold state is the pair (ready heap, remaining graph); each step
    // takes ownership of it and hands it back alongside the emitted id.
    try_unfold((heap, graph), move |(mut heap, graph)| async move {
        // Nothing ready any more: end the stream.
        let Some(Reverse(item)) = heap.pop() else {
            return Ok(None);
        };

        // Events that reference the one being emitted; `None` when no event
        // references it, in which case the fold below runs zero times.
        let references = incoming.get(&item.event_id).cloned();
        let state = (item.event_id, (heap, graph));
        references
            .into_iter()
            .flatten()
            .try_stream()
            .try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
                // Drop the emitted id from this referrer's list.
                // NOTE(review): `is_not_equal_to!` is presumably a predicate
                // keeping the entries that differ from `event_id` - confirm.
                graph
                    .get_mut(&parent_id)
                    .expect("contains all parent_ids")
                    .retain(is_not_equal_to!(&event_id));

                // The referrer is ready once none of its remaining references
                // is a key of `graph`; references outside the graph never
                // block it.
                if !graph[&parent_id]
                    .iter()
                    .any(|reference| graph.contains_key(reference))
                {
                    heap.push(Reverse(query(parent_id.clone()).await?));
                }

                Ok::<_, Error>((event_id, (heap, graph)))
            })
            // Shape the fold result as `Some((item, next state))` for unfold.
            .map_ok(Some)
            .await
    })
}
196
197#[expect(clippy::implicit_hasher)]
207pub fn is_topologically_sorted<'a, Order>(
208 order: Order,
209 graph: &HashMap<OwnedEventId, ReferencedIds>,
210) -> bool
211where
212 Order: IntoIterator<Item = &'a OwnedEventId>,
213{
214 order
215 .into_iter()
216 .try_fold(HashSet::with_capacity(graph.len()), |mut seen, event_id| {
217 let satisfied = graph
218 .get(event_id)
219 .into_iter()
220 .flatten()
221 .filter(|reference| graph.contains_key(*reference))
222 .all(|reference| seen.contains(reference));
223
224 seen.insert(event_id);
225 satisfied.then_some(seen)
226 })
227 .is_some()
228}
229
/// Reports whether `items` is topologically sorted, without allocating.
///
/// * `id` yields the identifier of an item.
/// * `references` yields the identifiers an item refers to.
///
/// The slice is sorted when no item refers to the identifier of an item at
/// the same or a later position. It follows that:
///
/// * references to identifiers that belong to no item are ignored;
/// * an item referring to itself is unsorted, whatever the slice length;
/// * a referenced identifier that appears again later is unsorted.
///
/// An empty slice is sorted. The check compares every reference against the
/// rest of the slice, so the cost grows quadratically with `items.len()`.
pub fn is_topologically_sorted_in_place<'a, T, Id, Refs, Ref>(
    items: &'a [T],
    id: Id,
    references: Refs,
) -> bool
where
    Id: Fn(&'a T) -> &'a str,
    Refs: Fn(&'a T) -> Ref,
    Ref: Iterator<Item = &'a str>,
{
    // No length shortcut: a one-item slice must go through the same check so
    // that a lone self-referencing item is rejected exactly like it is inside
    // a longer slice.
    items.iter().enumerate().all(|(position, item)| {
        // The item itself and everything after it; `position` comes from
        // `enumerate`, so the range is always in bounds.
        let not_before = &items[position..];

        references(item).all(|reference| {
            not_before
                .iter()
                .all(|other| id(other) != reference)
        })
    })
}
262
#[cfg(test)]
mod tests {
    use super::is_topologically_sorted_in_place;

    /// Runs the in-place check over `(id, referenced ids)` fixtures.
    fn is_sorted(nodes: &[(&str, &[&str])]) -> bool {
        // The closures stay inline so their signatures are inferred from the
        // bounds of the function under test.
        is_topologically_sorted_in_place(
            nodes,
            |node: &(&str, &[&str])| node.0,
            |node: &(&str, &[&str])| node.1.iter().copied(),
        )
    }

    #[test]
    fn empty_or_single_is_sorted() {
        let nothing: &[(&str, &[&str])] = &[];
        let lone: &[(&str, &[&str])] = &[("a", &[])];

        assert!(is_sorted(nothing));
        assert!(is_sorted(lone));
    }

    #[test]
    fn parents_before_children() {
        let chain: &[(&str, &[&str])] = &[("a", &[]), ("b", &["a"]), ("c", &["a", "b"])];

        assert!(is_sorted(chain));
    }

    #[test]
    fn child_before_parent_is_unsorted() {
        let swapped: &[(&str, &[&str])] = &[("b", &["a"]), ("a", &[])];

        assert!(!is_sorted(swapped));
    }

    #[test]
    fn absent_reference_is_a_non_edge() {
        // "x" is the id of no item, so it cannot constrain the order.
        let dangling: &[(&str, &[&str])] = &[("b", &["x"]), ("a", &["x"])];

        assert!(is_sorted(dangling));
    }

    #[test]
    fn self_reference_is_unsorted() {
        let looped: &[(&str, &[&str])] = &[("a", &["a"]), ("b", &[])];

        assert!(!is_sorted(looped));
    }

    #[test]
    fn later_duplicate_of_a_parent_is_unsorted() {
        // "a" appears again after "b", which refers to it.
        let repeated: &[(&str, &[&str])] = &[("a", &[]), ("b", &["a"]), ("a", &[])];

        assert!(!is_sorted(repeated));
    }
}