Skip to main content

tuwunel_service/rooms/state_res/resolve/split_conflicted.rs

1use std::{collections::HashMap, hash::Hash, iter::IntoIterator};
2
3use futures::{Stream, StreamExt};
4use tuwunel_core::validated;
5
6use super::{ConflictMap, StateMap};
7
8/// Split the unconflicted state map and the conflicted state set.
9///
10/// Definition in the specification:
11///
12/// If a given key _K_ is present in every _Si_ with the same value _V_ in each
13/// state map, then the pair (_K_, _V_) belongs to the unconflicted state map.
14/// Otherwise, _V_ belongs to the conflicted state set.
15///
16/// It means that, for a given (event type, state key) tuple, if all state maps
17/// have the same event ID, it lands in the unconflicted state map, otherwise
18/// the event IDs land in the conflicted state set.
19///
20/// ## Arguments
21///
22/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents
23///   a possible fork in the state of a room.
24///
25/// ## Returns
26///
27/// Returns an `(unconflicted_state, conflicted_states)` tuple.
28#[tracing::instrument(name = "split", level = "debug", skip_all)]
29pub(super) async fn split_conflicted_state<'a, Maps, Id>(
30	state_maps: Maps,
31) -> (StateMap<Id>, ConflictMap<Id>)
32where
33	Maps: Stream<Item = StateMap<Id>>,
34	Id: Clone + Eq + Hash + Ord + Send + Sync + 'a,
35{
36	let state_maps: Vec<_> = state_maps.collect().await;
37
38	let state_ids_est = state_maps.iter().flatten().count();
39
40	let state_set_count = state_maps
41		.iter()
42		.fold(0_usize, |acc, _| validated!(acc + 1));
43
44	let mut occurrences = HashMap::<_, HashMap<_, usize>>::with_capacity(state_ids_est);
45
46	for (k, v) in state_maps
47		.into_iter()
48		.flat_map(IntoIterator::into_iter)
49	{
50		let acc = occurrences
51			.entry(k.clone())
52			.or_default()
53			.entry(v.clone())
54			.or_default();
55
56		*acc = acc.saturating_add(1);
57	}
58
59	let mut unconflicted_state_map = StateMap::new();
60	let mut conflicted_state_set = ConflictMap::new();
61
62	for (k, v) in occurrences {
63		for (id, occurrence_count) in v {
64			if occurrence_count == state_set_count {
65				unconflicted_state_map.insert((k.0.clone(), k.1.clone()), id.clone());
66			} else {
67				let conflicts = conflicted_state_set
68					.entry((k.0.clone(), k.1.clone()))
69					.or_default();
70
71				debug_assert!(
72					!conflicts.contains(&id),
73					"Unexpected duplicate conflicted state event"
74				);
75
76				conflicts.push(id.clone());
77			}
78		}
79	}
80
81	(unconflicted_state_map, conflicted_state_set)
82}