Skip to main content

tuwunel_service/rooms/state_res/
resolve.rs

1#[cfg(test)]
2mod tests;
3
4mod auth_difference;
5mod conflicted_subgraph;
6mod iterative_auth_check;
7mod mainline_sort;
8mod power_sort;
9mod split_conflicted;
10
11use std::{
12	collections::{BTreeMap, BTreeSet, HashSet},
13	ops::Deref,
14};
15
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
17use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules};
18use tuwunel_core::{
19	Result, debug,
20	itertools::Itertools,
21	matrix::{Event, TypeStateKey},
22	smallvec::SmallVec,
23	trace,
24	utils::{
25		BoolExt,
26		stream::{BroadbandExt, IterStream},
27	},
28};
29
30use self::{
31	auth_difference::auth_difference, conflicted_subgraph::conflicted_subgraph_dfs,
32	iterative_auth_check::iterative_auth_check, mainline_sort::mainline_sort,
33	power_sort::power_sort, split_conflicted::split_conflicted_state,
34};
35#[cfg(test)]
36use super::test_utils;
37
38/// A mapping of event type and state_key to some value `T`, usually an
39/// `EventId`.
40pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
41
42/// Full recursive set of `auth_events` for each event in a StateMap.
43pub type AuthSet<Id> = BTreeSet<Id>;
44
45/// ConflictMap of OwnedEventId specifically.
46pub type ConflictMap<Id> = StateMap<ConflictVec<Id>>;
47
48/// List of conflicting event_ids
49type ConflictVec<Id> = SmallVec<[Id; 2]>;
50
51/// Apply the [state resolution] algorithm introduced in room version 2 to
52/// resolve the state of a room.
53///
54/// ## Arguments
55///
56/// * `rules` - The rules to apply for the version of the current room.
57///
58/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents
59///   a possible fork in the state of a room.
60///
61/// * `auth_chains` - The list of full recursive sets of `auth_events` for each
62///   event in the `state_maps`.
63///
64/// * `fetch_event` - Function to fetch an event in the room given its event ID.
65///
66/// ## Invariants
67///
68/// The caller of `resolve` must ensure that all the events are from the same
69/// room.
70///
71/// ## Returns
72///
73/// The resolved room state.
74///
75/// [state resolution]: https://spec.matrix.org/latest/rooms/v2/#state-resolution
76#[tracing::instrument(level = "debug", skip_all)]
77pub async fn resolve<States, AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
78	rules: &RoomVersionRules,
79	state_maps: States,
80	auth_sets: AuthSets,
81	fetch: &FetchEvent,
82	exists: &FetchExists,
83	hydra_backports: bool,
84) -> Result<StateMap<OwnedEventId>>
85where
86	States: Stream<Item = StateMap<OwnedEventId>> + Send,
87	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
88	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
89	ExistsFut: Future<Output = bool> + Send,
90	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
91	EventFut: Future<Output = Result<Pdu>> + Send,
92	Pdu: Event + Clone,
93{
94	// Split the unconflicted state map and the conflicted state set.
95	let (unconflicted_state, conflicted_states) = split_conflicted_state(state_maps).await;
96
97	debug!(
98		unconflicted = unconflicted_state.len(),
99		conflicted_states = conflicted_states.len(),
100		conflicted_events = conflicted_states
101			.values()
102			.fold(0_usize, |a, s| a.saturating_add(s.len())),
103		"unresolved states"
104	);
105
106	trace!(
107		?unconflicted_state,
108		?conflicted_states,
109		unconflicted = unconflicted_state.len(),
110		conflicted_states = conflicted_states.len(),
111		"unresolved states"
112	);
113
114	if conflicted_states.is_empty() {
115		return Ok(unconflicted_state.into_iter().collect());
116	}
117
118	// 0. The full conflicted set is the union of the conflicted state set and the
119	//    auth difference. Don't honor events that don't exist.
120	let full_conflicted_set =
121		full_conflicted_set(rules, conflicted_states, auth_sets, fetch, exists, hydra_backports)
122			.await;
123
124	// 1. Select the set X of all power events that appear in the full conflicted
125	//    set. For each such power event P, enlarge X by adding the events in the
126	//    auth chain of P which also belong to the full conflicted set. Sort X into
127	//    a list using the reverse topological power ordering.
128	let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
129		.inspect_ok(|list| debug!(count = list.len(), "sorted power events"))
130		.inspect_ok(|list| trace!(?list, "sorted power events"))
131		.boxed()
132		.await?;
133
134	let power_set_event_ids: Vec<_> = sorted_power_set
135		.iter()
136		.sorted_unstable()
137		.collect();
138
139	let sorted_power_set = sorted_power_set
140		.iter()
141		.stream()
142		.map(AsRef::as_ref);
143
144	let begin_with_empty_state_map = rules
145		.state_res
146		.v2_rules()
147		.is_some_and(|r| r.begin_iterative_auth_checks_with_empty_state_map)
148		|| hydra_backports;
149
150	let initial_state = begin_with_empty_state_map
151		.is_false()
152		.then(|| unconflicted_state.clone())
153		.unwrap_or_default();
154
155	// 2. Apply the iterative auth checks algorithm, starting from the unconflicted
156	//    state map, to the list of events from the previous step to get a partially
157	//    resolved state.
158	let partially_resolved_state =
159		iterative_auth_check(rules, sorted_power_set, initial_state, fetch)
160			.inspect_ok(|map| debug!(count = map.len(), "partially resolved power state"))
161			.inspect_ok(|map| trace!(?map, "partially resolved power state"))
162			.boxed()
163			.await?;
164
165	// This "epochs" power level event
166	let power_ty_sk = (StateEventType::RoomPowerLevels, "".into());
167	let power_event = partially_resolved_state.get(&power_ty_sk);
168	debug!(event_id = ?power_event, "epoch power event");
169
170	let remaining_events: Vec<_> = full_conflicted_set
171		.into_iter()
172		.filter(|id| power_set_event_ids.binary_search(&id).is_err())
173		.collect();
174
175	debug!(count = remaining_events.len(), "remaining events");
176	trace!(list = ?remaining_events, "remaining events");
177
178	let have_remaining_events = !remaining_events.is_empty();
179	let remaining_events = remaining_events
180		.iter()
181		.stream()
182		.map(AsRef::as_ref);
183
184	// 3. Take all remaining events that weren’t picked in step 1 and order them by
185	//    the mainline ordering based on the power level in the partially resolved
186	//    state obtained in step 2.
187	let sorted_remaining_events = have_remaining_events
188		.then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch))
189		.boxed();
190
191	let sorted_remaining_events = sorted_remaining_events
192		.await
193		.unwrap_or(Ok(Vec::new()))?;
194
195	debug!(count = sorted_remaining_events.len(), "sorted remaining events");
196	trace!(list = ?sorted_remaining_events, "sorted remaining events");
197
198	let sorted_remaining_events = sorted_remaining_events
199		.iter()
200		.stream()
201		.map(AsRef::as_ref);
202
203	// 4. Apply the iterative auth checks algorithm on the partial resolved state
204	//    and the list of events from the previous step.
205	let mut resolved_state =
206		iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch)
207			.boxed()
208			.await?;
209
210	// 5. Update the result by replacing any event with the event with the same key
211	//    from the unconflicted state map, if such an event exists, to get the final
212	//    resolved state.
213	resolved_state.extend(unconflicted_state);
214
215	debug!(resolved_state = resolved_state.len(), "resolved state");
216	trace!(?resolved_state, "resolved state");
217
218	Ok(resolved_state)
219}
220
221#[tracing::instrument(
222	name = "conflicted",
223	level = "debug",
224	skip_all,
225	fields(
226		states = conflicted_states.len(),
227		events = conflicted_states.values().flatten().count()
228	),
229)]
230async fn full_conflicted_set<AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
231	rules: &RoomVersionRules,
232	conflicted_states: ConflictMap<OwnedEventId>,
233	auth_sets: AuthSets,
234	fetch: &FetchEvent,
235	exists: &FetchExists,
236	hydra_backports: bool,
237) -> HashSet<OwnedEventId>
238where
239	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
240	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
241	ExistsFut: Future<Output = bool> + Send,
242	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
243	EventFut: Future<Output = Result<Pdu>> + Send,
244	Pdu: Event,
245{
246	let consider_conflicted_subgraph = rules
247		.state_res
248		.v2_rules()
249		.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
250		|| hydra_backports;
251
252	let conflicted_state_set: Vec<_> = conflicted_states
253		.values()
254		.flatten()
255		.sorted_unstable()
256		.dedup()
257		.collect();
258
259	// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
260	let conflicted_subgraph = consider_conflicted_subgraph
261		.then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
262		.map(Option::into_iter)
263		.map(IterStream::stream)
264		.flatten_stream()
265		.flatten()
266		.boxed();
267
268	let conflicted_state_ids = conflicted_state_set
269		.iter()
270		.map(Deref::deref)
271		.cloned()
272		.stream();
273
274	auth_difference(auth_sets)
275		.chain(conflicted_state_ids)
276		.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
277		.chain(conflicted_subgraph)
278		.collect::<HashSet<_>>()
279		.inspect(|set| debug!(count = set.len(), "full conflicted set"))
280		.inspect(|set| trace!(?set, "full conflicted set"))
281		.await
282}