tuwunel_service/rooms/state_res/
topological_sort.rs1use std::{
26 cmp::{Ordering, Reverse},
27 collections::{BinaryHeap, HashMap},
28};
29
30use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
31use ruma::{
32 MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
33};
34use tuwunel_core::{
35 Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
36};
37
38pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
39type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
40
41#[derive(PartialEq, Eq)]
42struct TieBreaker {
43 event_id: OwnedEventId,
44 power_level: UserPowerLevel,
45 origin_server_ts: MilliSecondsSinceUnixEpoch,
46}
47
48impl Ord for TieBreaker {
50 fn cmp(&self, other: &Self) -> Ordering {
51 other
52 .power_level
53 .cmp(&self.power_level)
54 .then(self.origin_server_ts.cmp(&other.origin_server_ts))
55 .then(self.event_id.cmp(&other.event_id))
56 }
57}
58
59impl PartialOrd for TieBreaker {
60 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
61}
62
63#[tracing::instrument(
83 level = "debug",
84 skip_all,
85 fields(
86 graph = graph.len(),
87 )
88)]
89#[expect(clippy::implicit_hasher)]
90pub async fn topological_sort<Query, Fut>(
91 graph: &HashMap<OwnedEventId, ReferencedIds>,
92 query: &Query,
93) -> Result<Vec<OwnedEventId>>
94where
95 Query: Fn(OwnedEventId) -> Fut + Sync,
96 Fut: Future<Output = Result<PduInfo>> + Send,
97{
98 let query = async |event_id: OwnedEventId| {
99 let (power_level, origin_server_ts) = query(event_id.clone()).await?;
100 Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
101 };
102
103 let max_edges = graph
104 .values()
105 .map(ReferencedIds::len)
106 .fold(graph.len(), |a, c| validated!(a + c));
107
108 let incoming = graph
109 .iter()
110 .flat_map(|(event_id, out)| {
111 out.iter()
112 .map(move |reference| (event_id, reference))
113 })
114 .fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
115 let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
116
117 if !references.contains(event_id) {
118 references.push(event_id.clone());
119 }
120
121 incoming
122 });
123
124 let horizon = graph
125 .iter()
126 .filter(|(_, references)| references.is_empty())
127 .try_stream()
128 .and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
129 .try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
130 .await?;
131
132 kahn_sort(horizon, graph.clone(), &incoming, &query)
133 .try_collect()
134 .await
135}
136
137#[tracing::instrument(
141 level = "debug",
142 skip_all,
143 fields(
144 heap = %heap.len(),
145 graph = %graph.len(),
146 )
147)]
148fn kahn_sort<Query, Fut>(
149 heap: BinaryHeap<Reverse<TieBreaker>>,
150 graph: HashMap<OwnedEventId, ReferencedIds>,
151 incoming: &HashMap<OwnedEventId, ReferencedIds>,
152 query: &Query,
153) -> impl Stream<Item = Result<OwnedEventId>> + Send
154where
155 Query: Fn(OwnedEventId) -> Fut + Sync,
156 Fut: Future<Output = Result<TieBreaker>> + Send,
157{
158 try_unfold((heap, graph), move |(mut heap, graph)| async move {
159 let Some(Reverse(item)) = heap.pop() else {
160 return Ok(None);
161 };
162
163 let references = incoming.get(&item.event_id).cloned();
164 let state = (item.event_id, (heap, graph));
165 references
166 .into_iter()
167 .flatten()
168 .try_stream()
169 .try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
170 let out = graph
171 .get_mut(&parent_id)
172 .expect("contains all parent_ids");
173
174 out.retain(is_not_equal_to!(&event_id));
175
176 if out.is_empty() {
178 heap.push(Reverse(query(parent_id.clone()).await?));
179 }
180
181 Ok::<_, Error>((event_id, (heap, graph)))
182 })
183 .map_ok(Some)
184 .await
185 })
186}