tuwunel_service/rooms/state_res/
resolve.rs1#[cfg(test)]
2mod tests;
3
4mod auth_difference;
5mod conflicted_subgraph;
6mod iterative_auth_check;
7mod mainline_sort;
8mod power_sort;
9mod split_conflicted;
10
11use std::{
12 collections::{BTreeMap, BTreeSet, HashSet},
13 ops::Deref,
14};
15
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
17use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules};
18use tuwunel_core::{
19 Result, debug,
20 itertools::Itertools,
21 matrix::{Event, TypeStateKey},
22 smallvec::SmallVec,
23 trace,
24 utils::{
25 BoolExt,
26 stream::{BroadbandExt, IterStream},
27 },
28};
29
30use self::{
31 auth_difference::auth_difference, conflicted_subgraph::conflicted_subgraph_dfs,
32 iterative_auth_check::iterative_auth_check, mainline_sort::mainline_sort,
33 power_sort::power_sort, split_conflicted::split_conflicted_state,
34};
35#[cfg(test)]
36use super::test_utils;
37
38pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
41
42pub type AuthSet<Id> = BTreeSet<Id>;
44
45pub type ConflictMap<Id> = StateMap<ConflictVec<Id>>;
47
48type ConflictVec<Id> = SmallVec<[Id; 2]>;
50
#[tracing::instrument(level = "debug", skip_all)]
/// Resolve several room state maps into a single state map.
///
/// The steps, in the order implemented below:
/// 1. Split `state_maps` into an unconflicted state and a map of conflicted
///    state keys to their candidate event IDs (`split_conflicted_state`).
/// 2. If nothing conflicts, return the unconflicted state as-is.
/// 3. Build the full conflicted set (see `full_conflicted_set`) from the
///    conflicted events, the auth difference of `auth_sets` and, when enabled,
///    the conflicted state subgraph.
/// 4. Order the power events of that set with `power_sort` and run
///    `iterative_auth_check` over them. The check starts either from the
///    unconflicted state or from an empty state map, as selected by the room
///    version rules.
/// 5. Mainline-sort every remaining event of the full conflicted set against
///    the `m.room.power_levels` event of the partially resolved state. Then run
///    `iterative_auth_check` over those events, starting from the partially
///    resolved state.
/// 6. Overlay the unconflicted state on the result. Unconflicted entries
///    replace any resolved entry with the same key.
///
/// # Arguments
/// * `rules` - room version rules; they select the v2 state-resolution
///   variations applied here and in `full_conflicted_set`.
/// * `state_maps` - the state maps to resolve.
/// * `auth_sets` - auth ID sets fed to `auth_difference` (presumably one auth
///   chain per input state map — confirm with callers).
/// * `fetch` - loads an event by ID.
/// * `exists` - reports whether an event is available. Auth-difference and
///   conflicted-state events it rejects are left out of the full conflicted set.
/// * `backport_css` - when true, the conflicted state subgraph is considered
///   even if the room version rules do not enable it.
///
/// # Errors
/// Propagates any error returned by `power_sort`, `mainline_sort` or
/// `iterative_auth_check`.
pub async fn resolve<States, AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
	rules: &RoomVersionRules,
	state_maps: States,
	auth_sets: AuthSets,
	fetch: &FetchEvent,
	exists: &FetchExists,
	backport_css: bool,
) -> Result<StateMap<OwnedEventId>>
where
	States: Stream<Item = StateMap<OwnedEventId>> + Send,
	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
	ExistsFut: Future<Output = bool> + Send,
	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
	EventFut: Future<Output = Result<Pdu>> + Send,
	Pdu: Event + Clone,
{
	// Step 1: partition the input states into agreed and conflicted keys.
	let (unconflicted_state, conflicted_states) = split_conflicted_state(state_maps).await;

	debug!(
		unconflicted = unconflicted_state.len(),
		conflicted_states = conflicted_states.len(),
		// Total number of candidate event IDs across all conflicted keys.
		conflicted_events = conflicted_states
			.values()
			.fold(0_usize, |a, s| a.saturating_add(s.len())),
		"unresolved states"
	);

	trace!(
		?unconflicted_state,
		?conflicted_states,
		unconflicted = unconflicted_state.len(),
		conflicted_states = conflicted_states.len(),
		"unresolved states"
	);

	// Step 2: nothing to resolve; the unconflicted state is the answer.
	if conflicted_states.is_empty() {
		return Ok(unconflicted_state.into_iter().collect());
	}

	// Step 3: the set of events that take part in the resolution below.
	let full_conflicted_set =
		full_conflicted_set(rules, conflicted_states, auth_sets, fetch, exists, backport_css)
			.await;

	// Step 4a: power events of the full conflicted set, in the order produced
	// by `power_sort`.
	let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
		.inspect_ok(|list| debug!(count = list.len(), "sorted power events"))
		.inspect_ok(|list| trace!(?list, "sorted power events"))
		.boxed()
		.await?;

	// A separately sorted view of the power event IDs. It exists only so the
	// "remaining events" filter below can test membership with `binary_search`.
	// The auth-check order still comes from `sorted_power_set`.
	let power_set_event_ids: Vec<_> = sorted_power_set
		.iter()
		.sorted_unstable()
		.collect();

	let sorted_power_set = sorted_power_set
		.iter()
		.stream()
		.map(AsRef::as_ref);

	// The first auth pass starts from the unconflicted state, unless the room
	// version's v2 rules set `begin_iterative_auth_checks_with_empty_state_map`.
	// No v2 rules at all also means "start from the unconflicted state".
	let start_with_incoming_state = rules
		.state_res
		.v2_rules()
		.is_none_or(|r| !r.begin_iterative_auth_checks_with_empty_state_map);

	// Cloned because `unconflicted_state` is still needed for the final overlay.
	// When not starting from it, this is the type's `Default` (an empty map).
	let initial_state = start_with_incoming_state
		.then(|| unconflicted_state.clone())
		.unwrap_or_default();

	// Step 4b: apply the auth checks to the power events in sorted order.
	let partially_resolved_state =
		iterative_auth_check(rules, sorted_power_set, initial_state, fetch)
			.inspect_ok(|map| debug!(count = map.len(), "partially resolved power state"))
			.inspect_ok(|map| trace!(?map, "partially resolved power state"))
			.boxed()
			.await?;

	// The `m.room.power_levels` event (empty state key) of the partially
	// resolved state, if any; it anchors the mainline sort below.
	let power_ty_sk = (StateEventType::RoomPowerLevels, "".into());
	let power_event = partially_resolved_state.get(&power_ty_sk);
	debug!(event_id = ?power_event, "epoch power event");

	// Step 5a: every event of the full conflicted set not already handled as a
	// power event.
	let remaining_events: Vec<_> = full_conflicted_set
		.into_iter()
		.filter(|id| power_set_event_ids.binary_search(&id).is_err())
		.collect();

	debug!(count = remaining_events.len(), "remaining events");
	trace!(list = ?remaining_events, "remaining events");

	let have_remaining_events = !remaining_events.is_empty();
	let remaining_events = remaining_events
		.iter()
		.stream()
		.map(AsRef::as_ref);

	// Only run the mainline sort when something remains. When nothing does,
	// `then_async` gives `None`, which becomes an empty list below.
	let sorted_remaining_events = have_remaining_events
		.then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch))
		.boxed();

	let sorted_remaining_events = sorted_remaining_events
		.await
		.unwrap_or(Ok(Vec::new()))?;

	debug!(count = sorted_remaining_events.len(), "sorted remaining events");
	trace!(list = ?sorted_remaining_events, "sorted remaining events");

	let sorted_remaining_events = sorted_remaining_events
		.iter()
		.stream()
		.map(AsRef::as_ref);

	// Step 5b: auth-check the remaining events on top of the power state.
	let mut resolved_state =
		iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch)
			.boxed()
			.await?;

	// Step 6: `BTreeMap::extend` replaces existing values, so every unconflicted
	// entry overrides whatever was resolved for the same key.
	resolved_state.extend(unconflicted_state);

	debug!(resolved_state = resolved_state.len(), "resolved state");
	trace!(?resolved_state, "resolved state");

	Ok(resolved_state)
}
218
/// Build the full conflicted set used by `resolve`.
///
/// The result is the deduplicated union of:
/// - the auth difference of `auth_sets` (`auth_difference`);
/// - every event ID that appears as a candidate in `conflicted_states`;
/// - the conflicted state subgraph (`conflicted_subgraph_dfs`), included only
///   when the room version's v2 rules set `consider_conflicted_state_subgraph`
///   or `backport_css` is true.
///
/// The first two sources are filtered through `exists`: an ID is kept only if
/// `exists(id)` resolves to `true`. Subgraph events are appended after that
/// filter, so they are not checked with `exists`.
#[tracing::instrument(
	name = "conflicted",
	level = "debug",
	skip_all,
	fields(
		states = conflicted_states.len(),
		events = conflicted_states.values().flatten().count()
	),
)]
async fn full_conflicted_set<AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
	rules: &RoomVersionRules,
	conflicted_states: ConflictMap<OwnedEventId>,
	auth_sets: AuthSets,
	fetch: &FetchEvent,
	exists: &FetchExists,
	backport_css: bool,
) -> HashSet<OwnedEventId>
where
	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
	FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
	ExistsFut: Future<Output = bool> + Send,
	FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
	EventFut: Future<Output = Result<Pdu>> + Send,
	Pdu: Event,
{
	// The room version enables the conflicted subgraph, or the caller forces it
	// on via `backport_css`.
	let consider_conflicted_subgraph = rules
		.state_res
		.v2_rules()
		.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
		|| backport_css;

	// All candidate event IDs across every conflicted key, sorted and with
	// duplicates removed (an event may be a candidate for several keys).
	let conflicted_state_set: Vec<_> = conflicted_states
		.values()
		.flatten()
		.sorted_unstable()
		.dedup()
		.collect();

	// `Option<stream>` flattened into a plain stream of event IDs. It is empty
	// when the subgraph is not considered.
	let conflicted_subgraph = consider_conflicted_subgraph
		.then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
		.map(Option::into_iter)
		.map(IterStream::stream)
		.flatten_stream()
		.flatten()
		.boxed();

	let conflicted_state_ids = conflicted_state_set
		.iter()
		.map(Deref::deref)
		.cloned()
		.stream();

	// The auth difference and the conflicted state IDs go through the `exists`
	// filter. The subgraph is chained after it, unfiltered. The `HashSet` then
	// removes duplicates between the three sources.
	auth_difference(auth_sets)
		.chain(conflicted_state_ids)
		.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
		.chain(conflicted_subgraph)
		.collect::<HashSet<_>>()
		.inspect(|set| debug!(count = set.len(), "full conflicted set"))
		.inspect(|set| trace!(?set, "full conflicted set"))
		.await
}