tuwunel_service/rooms/state_res/resolve/
conflicted_subgraph.rs1use std::{
2 collections::{HashMap as Map, hash_map::Entry},
3 iter::once,
4 ops::Deref,
5};
6
7use futures::{
8 Future, Stream, StreamExt,
9 stream::{FuturesUnordered, unfold},
10};
11use ruma::OwnedEventId;
12use tuwunel_core::{
13 Result, implement, is_equal_to,
14 itertools::Itertools,
15 matrix::{Event, pdu::AuthEvents},
16 smallvec::SmallVec,
17 utils::{
18 BoolExt,
19 stream::{IterStream, automatic_width},
20 },
21};
22
23#[derive(Default, Debug)]
24struct Global<Fut: Future + Send> {
25 subgraph: Subgraph,
26 todo: Todo<Fut>,
27 iter: usize,
28}
29
30#[derive(Default, Debug)]
31struct Local {
32 id: usize,
33 path: Path,
34 stack: Stack,
35}
36
/// Per-event flags recorded during the traversal.
#[derive(Debug, Default)]
struct Substate {
	/// Set once the event has been marked as lying on a path into the conflicted
	/// subgraph.
	subgraph: bool,
	/// Set once a descent has visited the event.
	seen: bool,
}
42
43type Todo<Fut> = FuturesUnordered<Fut>;
44type Subgraph = Map<OwnedEventId, Substate>;
45type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
46type Stack = SmallVec<[Frame; STACK_INLINE]>;
47type Frame = AuthEvents;
48
49const PATH_INLINE: usize = 32;
50const STACK_INLINE: usize = 32;
51const CAPACITY_MULTIPLIER: usize = 4;
52
53#[tracing::instrument(
54 name = "subgraph_dfs",
55 level = "debug",
56 skip_all,
57 fields(
58 starting_events = %conflicted_set.len(),
59 )
60)]
61pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
62 conflicted_set: &Vec<&OwnedEventId>,
63 fetch: &Fetch,
64) -> impl Stream<Item = OwnedEventId> + Send
65where
66 Fetch: Fn(OwnedEventId) -> Fut + Sync,
67 Fut: Future<Output = Result<Pdu>> + Send,
68 Pdu: Event,
69{
70 let initial_capacity = conflicted_set
71 .len()
72 .saturating_mul(CAPACITY_MULTIPLIER);
73
74 let state = Global {
75 subgraph: Map::with_capacity(initial_capacity),
76 todo: Todo::<_>::new(),
77 iter: 0,
78 };
79
80 let inputs = conflicted_set
81 .iter()
82 .map(Deref::deref)
83 .cloned()
84 .enumerate()
85 .map(Local::new)
86 .filter_map(Local::pop)
87 .map(|(local, event_id)| local.push(fetch, Some(event_id)));
88
89 unfold((inputs, state), async |(mut inputs, mut state)| {
90 debug_assert!(
91 state.todo.len() <= automatic_width(),
92 "Excessive items todo in FuturesUnordered"
93 );
94
95 while state.todo.len() < automatic_width()
96 && let Some(input) = inputs.next()
97 {
98 state.todo.push(input);
99 }
100
101 let outputs = state
102 .todo
103 .next()
104 .await?
105 .pop()
106 .map(|(local, event_id)| local.eval(&mut state, conflicted_set, event_id))
107 .map(|(local, next_id, outputs)| {
108 if !local.stack.is_empty() {
109 state.todo.push(local.push(fetch, next_id));
110 }
111
112 outputs
113 })
114 .into_iter()
115 .flatten()
116 .stream();
117
118 state.iter = state.iter.saturating_add(1);
119 Some((outputs, (inputs, state)))
120 })
121 .flatten()
122}
123
124#[implement(Local)]
125#[tracing::instrument(
126 name = "descent",
127 level = "trace",
128 skip_all,
129 fields(
130 i = state.iter,
131 s = ?state
132 .subgraph
133 .values()
134 .fold((0_u64, 0_u64), |(a, b), v| {
135 (a.saturating_add(u64::from(v.subgraph)), b.saturating_add(u64::from(v.seen)))
136 }),
137
138 %event_id,
139 id = self.id,
140 path = self.path.len(),
141 stack = self.stack.iter().flatten().count(),
142 )
143)]
144fn eval<Fut: Future + Send>(
145 mut self,
146 state: &mut Global<Fut>,
147 conflicted_event_ids: &Vec<&OwnedEventId>,
148 event_id: OwnedEventId,
149) -> (Self, Option<OwnedEventId>, Path) {
150 let Global { subgraph, .. } = state;
151
152 let insert_path_filter = |subgraph: &mut Subgraph, event_id: &OwnedEventId| match subgraph
153 .entry(event_id.clone())
154 {
155 | Entry::Occupied(state) if state.get().subgraph => false,
156 | Entry::Occupied(mut state) => {
157 state.get_mut().subgraph = true;
158 state.get().subgraph
159 },
160 | Entry::Vacant(state) =>
161 state
162 .insert(Substate { subgraph: true, seen: false })
163 .subgraph,
164 };
165
166 let insert_path = |subgraph: &mut Subgraph, local: &Local| {
167 local
168 .path
169 .iter()
170 .filter(|&event_id| insert_path_filter(subgraph, event_id))
171 .cloned()
172 .collect()
173 };
174
175 let is_conflicted = |event_id: &OwnedEventId| {
176 conflicted_event_ids
177 .binary_search(&event_id)
178 .is_ok()
179 };
180
181 let mut entry = subgraph.entry(event_id.clone());
182
183 if let Entry::Occupied(state) = &entry
184 && state.get().subgraph
185 {
186 let path = (self.path.len() > 1)
187 .then(|| insert_path(subgraph, &self))
188 .unwrap_or_default();
189
190 self.path.pop();
191 return (self, None, path);
192 }
193
194 if let Entry::Occupied(state) = &mut entry {
195 state.get_mut().seen = true;
196 return (self, None, Path::new());
197 }
198
199 if let Entry::Vacant(state) = entry {
200 state.insert(Substate { subgraph: false, seen: true });
201 }
202
203 let path = (self.path.len() > 1)
204 .and_if(|| is_conflicted(&event_id))
205 .then(|| insert_path(subgraph, &self))
206 .unwrap_or_default();
207
208 let next_id = self
209 .path
210 .iter()
211 .dropping_back(1)
212 .any(is_equal_to!(&event_id))
213 .is_false()
214 .then_some(event_id);
215
216 (self, next_id, path)
217}
218
219#[implement(Local)]
220async fn push<Fetch, Fut, Pdu>(mut self, fetch: &Fetch, event_id: Option<OwnedEventId>) -> Self
221where
222 Fetch: Fn(OwnedEventId) -> Fut + Sync,
223 Fut: Future<Output = Result<Pdu>> + Send,
224 Pdu: Event,
225{
226 if let Some(event_id) = event_id
227 && let Ok(event) = fetch(event_id).await
228 {
229 self.stack
230 .push(event.auth_events_into().into_iter().collect());
231 }
232
233 self
234}
235
236#[implement(Local)]
237fn pop(mut self) -> Option<(Self, OwnedEventId)> {
238 while self.stack.last().is_some_and(Frame::is_empty) {
239 self.stack.pop();
240 self.path.pop();
241 }
242
243 self.stack
244 .last_mut()
245 .and_then(Frame::pop)
246 .inspect(|event_id| self.path.push(event_id.clone()))
247 .map(move |event_id| (self, event_id))
248}
249
250#[implement(Local)]
251#[allow(clippy::redundant_clone)] fn new((id, conflicted_event_id): (usize, OwnedEventId)) -> Self {
253 Self {
254 id,
255 path: once(conflicted_event_id.clone()).collect(),
256 stack: once(once(conflicted_event_id).collect()).collect(),
257 ..Default::default()
258 }
259}