// tuwunel_service/rooms/state_res/resolve/auth_difference.rs

1use std::{borrow::Borrow, collections::BTreeMap};
2
3use futures::{FutureExt, Stream};
4use ruma::EventId;
5use tuwunel_core::utils::stream::{IterStream, ReadyExt};
6
7use super::AuthSet;
8
9struct Counts<Id: Ord> {
10	by_id: BTreeMap<Id, usize>,
11	total: usize,
12}
13
14impl<Id: Ord> Default for Counts<Id> {
15	fn default() -> Self { Self { by_id: BTreeMap::new(), total: 0 } }
16}
17
18impl<Id: Ord> Counts<Id> {
19	fn merge(mut self, set: AuthSet<Id>) -> Self {
20		self.total = self.total.saturating_add(1);
21		for id in set {
22			let count = self.by_id.entry(id).or_default();
23
24			*count = count.saturating_add(1);
25		}
26
27		self
28	}
29}
30
31/// Get the auth difference for the given auth chains.
32///
33/// Definition in the specification:
34///
35/// The auth difference is calculated by first calculating the full auth chain
36/// for each state _Si_, that is the union of the auth chains for each event in
37/// _Si_, and then taking every event that doesn’t appear in every auth chain.
38/// If _Ci_ is the full auth chain of _Si_, then the auth difference is ∪_Ci_ −
39/// ∩_Ci_.
40///
41/// ## Arguments
42///
43/// * `auth_chains` - The list of full recursive sets of `auth_events`. Inputs
44///   must be sorted.
45///
46/// ## Returns
47///
48/// Outputs the event IDs that are not present in all the auth chains.
49#[tracing::instrument(level = "debug", skip_all)]
50pub(super) fn auth_difference<'a, AuthSets, Id>(auth_sets: AuthSets) -> impl Stream<Item = Id>
51where
52	AuthSets: Stream<Item = AuthSet<Id>>,
53	Id: Borrow<EventId> + Clone + Eq + Ord + Send + 'a,
54{
55	auth_sets
56		.ready_fold_default(Counts::<Id>::merge)
57		.map(|Counts { by_id, total }: Counts<Id>| {
58			by_id
59				.into_iter()
60				.filter_map(move |(id, count)| (count < total).then_some(id))
61				.stream()
62		})
63		.flatten_stream()
64}