// tuwunel_service/rooms/state_res/resolve/conflicted_subgraph.rs

1use std::{
2	collections::{HashMap as Map, hash_map::Entry},
3	iter::once,
4	ops::Deref,
5};
6
7use futures::{
8	Future, Stream, StreamExt,
9	stream::{FuturesUnordered, unfold},
10};
11use ruma::OwnedEventId;
12use tuwunel_core::{
13	Result, implement, is_equal_to,
14	itertools::Itertools,
15	matrix::{Event, pdu::AuthEvents},
16	smallvec::SmallVec,
17	utils::{
18		BoolExt,
19		stream::{IterStream, automatic_width},
20	},
21};
22
23#[derive(Default, Debug)]
24struct Global<Fut: Future + Send> {
25	subgraph: Subgraph,
26	todo: Todo<Fut>,
27	iter: usize,
28}
29
30#[derive(Default, Debug)]
31struct Local {
32	id: usize,
33	path: Path,
34	stack: Stack,
35}
36
/// Marks recorded for each event the traversal has touched.
#[derive(Debug, Default)]
struct Substate {
	/// Set once the event has been placed in the conflicted subgraph.
	subgraph: bool,

	/// Set once the event has been visited during the descent.
	seen: bool,
}
42
43type Todo<Fut> = FuturesUnordered<Fut>;
44type Subgraph = Map<OwnedEventId, Substate>;
45type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
46type Stack = SmallVec<[Frame; STACK_INLINE]>;
47type Frame = AuthEvents;
48
/// Path entries stored inline before a `Path` spills to the heap.
const PATH_INLINE: usize = 32;

/// Frames stored inline before a `Stack` spills to the heap.
const STACK_INLINE: usize = 32;

/// Initial `Subgraph` capacity reserved per conflicted event.
const CAPACITY_MULTIPLIER: usize = 4;
52
53#[tracing::instrument(
54	name = "subgraph_dfs",
55	level = "debug",
56	skip_all,
57	fields(
58		starting_events = %conflicted_set.len(),
59	)
60)]
61pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
62	conflicted_set: &Vec<&OwnedEventId>,
63	fetch: &Fetch,
64) -> impl Stream<Item = OwnedEventId> + Send
65where
66	Fetch: Fn(OwnedEventId) -> Fut + Sync,
67	Fut: Future<Output = Result<Pdu>> + Send,
68	Pdu: Event,
69{
70	let initial_capacity = conflicted_set
71		.len()
72		.saturating_mul(CAPACITY_MULTIPLIER);
73
74	let state = Global {
75		subgraph: Map::with_capacity(initial_capacity),
76		todo: Todo::<_>::new(),
77		iter: 0,
78	};
79
80	let inputs = conflicted_set
81		.iter()
82		.map(Deref::deref)
83		.cloned()
84		.enumerate()
85		.map(Local::new)
86		.filter_map(Local::pop)
87		.map(|(local, event_id)| local.push(fetch, Some(event_id)));
88
89	unfold((inputs, state), async |(mut inputs, mut state)| {
90		debug_assert!(
91			state.todo.len() <= automatic_width(),
92			"Excessive items todo in FuturesUnordered"
93		);
94
95		while state.todo.len() < automatic_width()
96			&& let Some(input) = inputs.next()
97		{
98			state.todo.push(input);
99		}
100
101		let outputs = state
102			.todo
103			.next()
104			.await?
105			.pop()
106			.map(|(local, event_id)| local.eval(&mut state, conflicted_set, event_id))
107			.map(|(local, next_id, outputs)| {
108				if !local.stack.is_empty() {
109					state.todo.push(local.push(fetch, next_id));
110				}
111
112				outputs
113			})
114			.into_iter()
115			.flatten()
116			.stream();
117
118		state.iter = state.iter.saturating_add(1);
119		Some((outputs, (inputs, state)))
120	})
121	.flatten()
122}
123
124#[implement(Local)]
125#[tracing::instrument(
126	name = "descent",
127	level = "trace",
128	skip_all,
129	fields(
130		i = state.iter,
131		s = ?state
132			.subgraph
133			.values()
134			.fold((0_u64, 0_u64), |(a, b), v| {
135				(a.saturating_add(u64::from(v.subgraph)), b.saturating_add(u64::from(v.seen)))
136			}),
137
138		%event_id,
139		id = self.id,
140		path = self.path.len(),
141		stack = self.stack.iter().flatten().count(),
142	)
143)]
144fn eval<Fut: Future + Send>(
145	mut self,
146	state: &mut Global<Fut>,
147	conflicted_event_ids: &Vec<&OwnedEventId>,
148	event_id: OwnedEventId,
149) -> (Self, Option<OwnedEventId>, Path) {
150	let Global { subgraph, .. } = state;
151
152	let insert_path_filter = |subgraph: &mut Subgraph, event_id: &OwnedEventId| match subgraph
153		.entry(event_id.clone())
154	{
155		| Entry::Occupied(state) if state.get().subgraph => false,
156		| Entry::Occupied(mut state) => {
157			state.get_mut().subgraph = true;
158			state.get().subgraph
159		},
160		| Entry::Vacant(state) =>
161			state
162				.insert(Substate { subgraph: true, seen: false })
163				.subgraph,
164	};
165
166	let insert_path = |subgraph: &mut Subgraph, local: &Local| {
167		local
168			.path
169			.iter()
170			.filter(|&event_id| insert_path_filter(subgraph, event_id))
171			.cloned()
172			.collect()
173	};
174
175	let is_conflicted = |event_id: &OwnedEventId| {
176		conflicted_event_ids
177			.binary_search(&event_id)
178			.is_ok()
179	};
180
181	let mut entry = subgraph.entry(event_id.clone());
182
183	if let Entry::Occupied(state) = &entry
184		&& state.get().subgraph
185	{
186		let path = (self.path.len() > 1)
187			.then(|| insert_path(subgraph, &self))
188			.unwrap_or_default();
189
190		self.path.pop();
191		return (self, None, path);
192	}
193
194	if let Entry::Occupied(state) = &mut entry {
195		state.get_mut().seen = true;
196		return (self, None, Path::new());
197	}
198
199	if let Entry::Vacant(state) = entry {
200		state.insert(Substate { subgraph: false, seen: true });
201	}
202
203	let path = (self.path.len() > 1)
204		.and_if(|| is_conflicted(&event_id))
205		.then(|| insert_path(subgraph, &self))
206		.unwrap_or_default();
207
208	let next_id = self
209		.path
210		.iter()
211		.dropping_back(1)
212		.any(is_equal_to!(&event_id))
213		.is_false()
214		.then_some(event_id);
215
216	(self, next_id, path)
217}
218
219#[implement(Local)]
220async fn push<Fetch, Fut, Pdu>(mut self, fetch: &Fetch, event_id: Option<OwnedEventId>) -> Self
221where
222	Fetch: Fn(OwnedEventId) -> Fut + Sync,
223	Fut: Future<Output = Result<Pdu>> + Send,
224	Pdu: Event,
225{
226	if let Some(event_id) = event_id
227		&& let Ok(event) = fetch(event_id).await
228	{
229		self.stack
230			.push(event.auth_events_into().into_iter().collect());
231	}
232
233	self
234}
235
236#[implement(Local)]
237fn pop(mut self) -> Option<(Self, OwnedEventId)> {
238	while self.stack.last().is_some_and(Frame::is_empty) {
239		self.stack.pop();
240		self.path.pop();
241	}
242
243	self.stack
244		.last_mut()
245		.and_then(Frame::pop)
246		.inspect(|event_id| self.path.push(event_id.clone()))
247		.map(move |event_id| (self, event_id))
248}
249
250#[implement(Local)]
251#[allow(clippy::redundant_clone)] // buggy, nursery
252fn new((id, conflicted_event_id): (usize, OwnedEventId)) -> Self {
253	Self {
254		id,
255		path: once(conflicted_event_id.clone()).collect(),
256		stack: once(once(conflicted_event_id).collect()).collect(),
257		..Default::default()
258	}
259}