tuwunel_service/rooms/state_res/
resolve.rs1#[cfg(test)]
2mod tests;
3
4mod auth_difference;
5mod conflicted_subgraph;
6mod iterative_auth_check;
7mod mainline_sort;
8mod power_sort;
9mod split_conflicted;
10
11use std::{
12 collections::{BTreeMap, BTreeSet, HashSet},
13 ops::Deref,
14};
15
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
17use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules};
18use tuwunel_core::{
19 Result, debug,
20 itertools::Itertools,
21 matrix::{Event, TypeStateKey},
22 smallvec::SmallVec,
23 trace,
24 utils::{
25 BoolExt,
26 stream::{BroadbandExt, IterStream},
27 },
28};
29
30use self::{
31 auth_difference::auth_difference, conflicted_subgraph::conflicted_subgraph_dfs,
32 iterative_auth_check::iterative_auth_check, mainline_sort::mainline_sort,
33 power_sort::power_sort, split_conflicted::split_conflicted_state,
34};
35#[cfg(test)]
36use super::test_utils;
37
38pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
41
42pub type AuthSet<Id> = BTreeSet<Id>;
44
45pub type ConflictMap<Id> = StateMap<ConflictVec<Id>>;
47
48type ConflictVec<Id> = SmallVec<[Id; 2]>;
50
51#[tracing::instrument(level = "debug", skip_all)]
77pub async fn resolve<States, AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
78 rules: &RoomVersionRules,
79 state_maps: States,
80 auth_sets: AuthSets,
81 fetch: &FetchEvent,
82 exists: &FetchExists,
83 hydra_backports: bool,
84) -> Result<StateMap<OwnedEventId>>
85where
86 States: Stream<Item = StateMap<OwnedEventId>> + Send,
87 AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
88 FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
89 ExistsFut: Future<Output = bool> + Send,
90 FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
91 EventFut: Future<Output = Result<Pdu>> + Send,
92 Pdu: Event + Clone,
93{
94 let (unconflicted_state, conflicted_states) = split_conflicted_state(state_maps).await;
96
97 debug!(
98 unconflicted = unconflicted_state.len(),
99 conflicted_states = conflicted_states.len(),
100 conflicted_events = conflicted_states
101 .values()
102 .fold(0_usize, |a, s| a.saturating_add(s.len())),
103 "unresolved states"
104 );
105
106 trace!(
107 ?unconflicted_state,
108 ?conflicted_states,
109 unconflicted = unconflicted_state.len(),
110 conflicted_states = conflicted_states.len(),
111 "unresolved states"
112 );
113
114 if conflicted_states.is_empty() {
115 return Ok(unconflicted_state.into_iter().collect());
116 }
117
118 let full_conflicted_set =
121 full_conflicted_set(rules, conflicted_states, auth_sets, fetch, exists, hydra_backports)
122 .await;
123
124 let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
129 .inspect_ok(|list| debug!(count = list.len(), "sorted power events"))
130 .inspect_ok(|list| trace!(?list, "sorted power events"))
131 .boxed()
132 .await?;
133
134 let power_set_event_ids: Vec<_> = sorted_power_set
135 .iter()
136 .sorted_unstable()
137 .collect();
138
139 let sorted_power_set = sorted_power_set
140 .iter()
141 .stream()
142 .map(AsRef::as_ref);
143
144 let begin_with_empty_state_map = rules
145 .state_res
146 .v2_rules()
147 .is_some_and(|r| r.begin_iterative_auth_checks_with_empty_state_map)
148 || hydra_backports;
149
150 let initial_state = begin_with_empty_state_map
151 .is_false()
152 .then(|| unconflicted_state.clone())
153 .unwrap_or_default();
154
155 let partially_resolved_state =
159 iterative_auth_check(rules, sorted_power_set, initial_state, fetch)
160 .inspect_ok(|map| debug!(count = map.len(), "partially resolved power state"))
161 .inspect_ok(|map| trace!(?map, "partially resolved power state"))
162 .boxed()
163 .await?;
164
165 let power_ty_sk = (StateEventType::RoomPowerLevels, "".into());
167 let power_event = partially_resolved_state.get(&power_ty_sk);
168 debug!(event_id = ?power_event, "epoch power event");
169
170 let remaining_events: Vec<_> = full_conflicted_set
171 .into_iter()
172 .filter(|id| power_set_event_ids.binary_search(&id).is_err())
173 .collect();
174
175 debug!(count = remaining_events.len(), "remaining events");
176 trace!(list = ?remaining_events, "remaining events");
177
178 let have_remaining_events = !remaining_events.is_empty();
179 let remaining_events = remaining_events
180 .iter()
181 .stream()
182 .map(AsRef::as_ref);
183
184 let sorted_remaining_events = have_remaining_events
188 .then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch))
189 .boxed();
190
191 let sorted_remaining_events = sorted_remaining_events
192 .await
193 .unwrap_or(Ok(Vec::new()))?;
194
195 debug!(count = sorted_remaining_events.len(), "sorted remaining events");
196 trace!(list = ?sorted_remaining_events, "sorted remaining events");
197
198 let sorted_remaining_events = sorted_remaining_events
199 .iter()
200 .stream()
201 .map(AsRef::as_ref);
202
203 let mut resolved_state =
206 iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch)
207 .boxed()
208 .await?;
209
210 resolved_state.extend(unconflicted_state);
214
215 debug!(resolved_state = resolved_state.len(), "resolved state");
216 trace!(?resolved_state, "resolved state");
217
218 Ok(resolved_state)
219}
220
221#[tracing::instrument(
222 name = "conflicted",
223 level = "debug",
224 skip_all,
225 fields(
226 states = conflicted_states.len(),
227 events = conflicted_states.values().flatten().count()
228 ),
229)]
230async fn full_conflicted_set<AuthSets, FetchExists, ExistsFut, FetchEvent, EventFut, Pdu>(
231 rules: &RoomVersionRules,
232 conflicted_states: ConflictMap<OwnedEventId>,
233 auth_sets: AuthSets,
234 fetch: &FetchEvent,
235 exists: &FetchExists,
236 hydra_backports: bool,
237) -> HashSet<OwnedEventId>
238where
239 AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
240 FetchExists: Fn(OwnedEventId) -> ExistsFut + Sync,
241 ExistsFut: Future<Output = bool> + Send,
242 FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
243 EventFut: Future<Output = Result<Pdu>> + Send,
244 Pdu: Event,
245{
246 let consider_conflicted_subgraph = rules
247 .state_res
248 .v2_rules()
249 .is_some_and(|rules| rules.consider_conflicted_state_subgraph)
250 || hydra_backports;
251
252 let conflicted_state_set: Vec<_> = conflicted_states
253 .values()
254 .flatten()
255 .sorted_unstable()
256 .dedup()
257 .collect();
258
259 let conflicted_subgraph = consider_conflicted_subgraph
261 .then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
262 .map(Option::into_iter)
263 .map(IterStream::stream)
264 .flatten_stream()
265 .flatten()
266 .boxed();
267
268 let conflicted_state_ids = conflicted_state_set
269 .iter()
270 .map(Deref::deref)
271 .cloned()
272 .stream();
273
274 auth_difference(auth_sets)
275 .chain(conflicted_state_ids)
276 .broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
277 .chain(conflicted_subgraph)
278 .collect::<HashSet<_>>()
279 .inspect(|set| debug!(count = set.len(), "full conflicted set"))
280 .inspect(|set| trace!(?set, "full conflicted set"))
281 .await
282}