Skip to main content

tuwunel_service/rooms/state_res/
resolve.rs

1#[cfg(test)]
2mod tests;
3
4mod auth_difference;
5mod conflicted_subgraph;
6mod iterative_auth_check;
7mod mainline_sort;
8mod power_sort;
9mod split_conflicted;
10
11use std::{
12	collections::{BTreeMap, BTreeSet, HashSet},
13	ops::Deref,
14};
15
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
17use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules};
18use tuwunel_core::{
19	Result, debug,
20	itertools::Itertools,
21	matrix::{Event, TypeStateKey},
22	smallvec::SmallVec,
23	trace,
24	utils::{
25		BoolExt,
26		stream::{BroadbandExt, IterStream},
27	},
28};
29
30use self::{
31	auth_difference::auth_difference, conflicted_subgraph::conflicted_subgraph_dfs,
32	iterative_auth_check::iterative_auth_check, mainline_sort::mainline_sort,
33	power_sort::power_sort, split_conflicted::split_conflicted_state,
34};
35#[cfg(test)]
36use super::test_utils;
37
38/// A mapping of event type and state_key to some value `T`, usually an
39/// `EventId`.
40pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
41
42/// Full recursive set of `auth_events` for each event in a StateMap.
43pub type AuthSet<Id> = BTreeSet<Id>;
44
45/// ConflictMap of OwnedEventId specifically.
46pub type ConflictMap<Id> = StateMap<ConflictVec<Id>>;
47
48/// List of conflicting event_ids
49type ConflictVec<Id> = SmallVec<[Id; 2]>;
50
51/// Apply the [state resolution] algorithm introduced in room version 2 to
52/// resolve the state of a room.
53///
54/// ## Arguments
55///
56/// * `rules` - The rules to apply for the version of the current room.
57///
58/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents
59///   a possible fork in the state of a room.
60///
61/// * `auth_chains` - The list of full recursive sets of `auth_events` for each
62///   event in the `state_maps`.
63///
64/// * `fetch_event` - Function to fetch an event in the room given its event ID.
65///
66/// ## Invariants
67///
68/// The caller of `resolve` must ensure that all the events are from the same
69/// room.
70///
71/// ## Returns
72///
73/// The resolved room state.
74///
75/// [state resolution]: https://spec.matrix.org/latest/rooms/v2/#state-resolution
76#[tracing::instrument(level = "debug", skip_all)]
77pub async fn resolve<States, AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
78	rules: &RoomVersionRules,
79	state_maps: States,
80	auth_sets: AuthSets,
81	fetch: &FetchEvent,
82	exists: &FetchExists,
83	backport_css: bool,
84) -> Result<StateMap<OwnedEventId>>
85where
86	States: Stream<Item = StateMap<OwnedEventId>> + Send,
87	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
88	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
89	ExistsFut: Future<Output = bool> + Send,
90	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
91	EventFut: Future<Output = Result<Pdu>> + Send,
92	Pdu: Event + Clone,
93{
94	// Split the unconflicted state map and the conflicted state set.
95	let (unconflicted_state, conflicted_states) = split_conflicted_state(state_maps).await;
96
97	debug!(
98		unconflicted = unconflicted_state.len(),
99		conflicted_states = conflicted_states.len(),
100		conflicted_events = conflicted_states
101			.values()
102			.fold(0_usize, |a, s| a.saturating_add(s.len())),
103		"unresolved states"
104	);
105
106	trace!(
107		?unconflicted_state,
108		?conflicted_states,
109		unconflicted = unconflicted_state.len(),
110		conflicted_states = conflicted_states.len(),
111		"unresolved states"
112	);
113
114	if conflicted_states.is_empty() {
115		return Ok(unconflicted_state.into_iter().collect());
116	}
117
118	// 0. The full conflicted set is the union of the conflicted state set and the
119	//    auth difference. Don't honor events that don't exist.
120	let full_conflicted_set =
121		full_conflicted_set(rules, conflicted_states, auth_sets, fetch, exists, backport_css)
122			.await;
123
124	// 1. Select the set X of all power events that appear in the full conflicted
125	//    set. For each such power event P, enlarge X by adding the events in the
126	//    auth chain of P which also belong to the full conflicted set. Sort X into
127	//    a list using the reverse topological power ordering.
128	let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
129		.inspect_ok(|list| debug!(count = list.len(), "sorted power events"))
130		.inspect_ok(|list| trace!(?list, "sorted power events"))
131		.boxed()
132		.await?;
133
134	let power_set_event_ids: Vec<_> = sorted_power_set
135		.iter()
136		.sorted_unstable()
137		.collect();
138
139	let sorted_power_set = sorted_power_set
140		.iter()
141		.stream()
142		.map(AsRef::as_ref);
143
144	let start_with_incoming_state = rules
145		.state_res
146		.v2_rules()
147		.is_none_or(|r| !r.begin_iterative_auth_checks_with_empty_state_map);
148
149	let initial_state = start_with_incoming_state
150		.then(|| unconflicted_state.clone())
151		.unwrap_or_default();
152
153	// 2. Apply the iterative auth checks algorithm, starting from the unconflicted
154	//    state map, to the list of events from the previous step to get a partially
155	//    resolved state.
156	let partially_resolved_state =
157		iterative_auth_check(rules, sorted_power_set, initial_state, fetch)
158			.inspect_ok(|map| debug!(count = map.len(), "partially resolved power state"))
159			.inspect_ok(|map| trace!(?map, "partially resolved power state"))
160			.boxed()
161			.await?;
162
163	// This "epochs" power level event
164	let power_ty_sk = (StateEventType::RoomPowerLevels, "".into());
165	let power_event = partially_resolved_state.get(&power_ty_sk);
166	debug!(event_id = ?power_event, "epoch power event");
167
168	let remaining_events: Vec<_> = full_conflicted_set
169		.into_iter()
170		.filter(|id| power_set_event_ids.binary_search(&id).is_err())
171		.collect();
172
173	debug!(count = remaining_events.len(), "remaining events");
174	trace!(list = ?remaining_events, "remaining events");
175
176	let have_remaining_events = !remaining_events.is_empty();
177	let remaining_events = remaining_events
178		.iter()
179		.stream()
180		.map(AsRef::as_ref);
181
182	// 3. Take all remaining events that weren’t picked in step 1 and order them by
183	//    the mainline ordering based on the power level in the partially resolved
184	//    state obtained in step 2.
185	let sorted_remaining_events = have_remaining_events
186		.then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch))
187		.boxed();
188
189	let sorted_remaining_events = sorted_remaining_events
190		.await
191		.unwrap_or(Ok(Vec::new()))?;
192
193	debug!(count = sorted_remaining_events.len(), "sorted remaining events");
194	trace!(list = ?sorted_remaining_events, "sorted remaining events");
195
196	let sorted_remaining_events = sorted_remaining_events
197		.iter()
198		.stream()
199		.map(AsRef::as_ref);
200
201	// 4. Apply the iterative auth checks algorithm on the partial resolved state
202	//    and the list of events from the previous step.
203	let mut resolved_state =
204		iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch)
205			.boxed()
206			.await?;
207
208	// 5. Update the result by replacing any event with the event with the same key
209	//    from the unconflicted state map, if such an event exists, to get the final
210	//    resolved state.
211	resolved_state.extend(unconflicted_state);
212
213	debug!(resolved_state = resolved_state.len(), "resolved state");
214	trace!(?resolved_state, "resolved state");
215
216	Ok(resolved_state)
217}
218
219#[tracing::instrument(
220	name = "conflicted",
221	level = "debug",
222	skip_all,
223	fields(
224		states = conflicted_states.len(),
225		events = conflicted_states.values().flatten().count()
226	),
227)]
228async fn full_conflicted_set<AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
229	rules: &RoomVersionRules,
230	conflicted_states: ConflictMap<OwnedEventId>,
231	auth_sets: AuthSets,
232	fetch: &FetchEvent,
233	exists: &FetchExists,
234	backport_css: bool,
235) -> HashSet<OwnedEventId>
236where
237	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
238	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
239	ExistsFut: Future<Output = bool> + Send,
240	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
241	EventFut: Future<Output = Result<Pdu>> + Send,
242	Pdu: Event,
243{
244	let consider_conflicted_subgraph = rules
245		.state_res
246		.v2_rules()
247		.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
248		|| backport_css;
249
250	let conflicted_state_set: Vec<_> = conflicted_states
251		.values()
252		.flatten()
253		.sorted_unstable()
254		.dedup()
255		.collect();
256
257	// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
258	let conflicted_subgraph = consider_conflicted_subgraph
259		.then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
260		.map(Option::into_iter)
261		.map(IterStream::stream)
262		.flatten_stream()
263		.flatten()
264		.boxed();
265
266	let conflicted_state_ids = conflicted_state_set
267		.iter()
268		.map(Deref::deref)
269		.cloned()
270		.stream();
271
272	auth_difference(auth_sets)
273		.chain(conflicted_state_ids)
274		.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
275		.chain(conflicted_subgraph)
276		.collect::<HashSet<_>>()
277		.inspect(|set| debug!(count = set.len(), "full conflicted set"))
278		.inspect(|set| trace!(?set, "full conflicted set"))
279		.await
280}