Skip to main content

tuwunel_core/utils/
set.rs

1use std::{
2	cmp::{Eq, Ord},
3	convert::identity,
4	pin::Pin,
5	sync::Arc,
6};
7
8use futures::{
9	Stream, StreamExt,
10	stream::{Peekable, unfold},
11};
12use tokio::sync::Mutex;
13
14use crate::{is_equal_to, is_less_than, utils::stream::ReadyExt};
15
/// Intersection of sets
///
/// Outputs every element of the first input set that is also present in each
/// of the remaining input sets. Inputs do not have to be sorted; when they are,
/// prefer [`intersection_sorted`], which walks each input in a single pass
/// instead of rescanning it.
///
/// For every element of the first set, the remaining sets are re-obtained by
/// cloning `input` and scanned linearly, so the work grows with the size of the
/// first set multiplied by the combined size of the others. An empty `input`
/// yields nothing, and a lone input set is yielded unchanged (duplicates
/// included).
pub fn intersection<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
	Iters: Iterator<Item = Iter> + Clone + Send,
	Iter: Iterator<Item = Item> + Send,
	Item: Eq,
{
	let head = input.next();
	let others = input;

	head.into_iter().flat_map(move |candidates| {
		// `flat_map` requires an `FnMut`, so the inner filter gets its own copy
		// rather than taking ownership of the captured iterator.
		let others = others.clone();
		candidates.filter(move |candidate| {
			// Fresh clone per candidate: each remaining set is scanned from the start.
			let mut sets = others.clone();
			sets.all(|mut set| set.any(|elem| elem == *candidate))
		})
	})
}
36
37/// Intersection of sets
38///
39/// Outputs the set of elements common to all input sets. Inputs must be sorted.
40pub fn intersection_sorted<Item, Iter, Iters>(
41	mut input: Iters,
42) -> impl Iterator<Item = Item> + Send
43where
44	Iters: Iterator<Item = Iter> + Clone + Send,
45	Iter: Iterator<Item = Item> + Send,
46	Item: Eq + Ord,
47{
48	input.next().into_iter().flat_map(move |first| {
49		let mut input = input.clone().collect::<Vec<_>>();
50		first.filter(move |targ| {
51			input.iter_mut().all(|it| {
52				it.by_ref()
53					.skip_while(is_less_than!(targ))
54					.peekable()
55					.peek()
56					.is_some_and(is_equal_to!(targ))
57			})
58		})
59	})
60}
61
62/// Intersection of sets
63///
64/// Outputs the set of elements common to both streams. Streams must be sorted.
65pub fn intersection_sorted_stream2<S, Item>(a: S, b: S) -> impl Stream<Item = Item> + Send
66where
67	S: Stream<Item = Item> + Send + Unpin,
68	Item: Eq + PartialOrd + Send + Sync,
69{
70	struct State<S: Stream> {
71		a: S,
72		b: Peekable<S>,
73	}
74
75	unfold(State { a, b: b.peekable() }, async |mut state| {
76		let ai = state.a.next().await?;
77		while let Some(bi) = Pin::new(&mut state.b)
78			.next_if(|bi| *bi <= ai)
79			.await
80			.as_ref()
81		{
82			if ai == *bi {
83				return Some((Some(ai), state));
84			}
85		}
86
87		Some((None, state))
88	})
89	.ready_filter_map(identity)
90}
91
92/// Difference of sets
93///
94/// Outputs the set of elements found in `a` which are not found in `b`. Streams
95/// must be sorted.
96pub fn difference_sorted_stream2<Item, A, B>(a: A, b: B) -> impl Stream<Item = Item> + Send
97where
98	A: Stream<Item = Item> + Send,
99	B: Stream<Item = Item> + Send + Unpin,
100	Item: Eq + PartialOrd + Send + Sync,
101{
102	let b = Arc::new(Mutex::new(b.peekable()));
103	a.map(move |ai| (ai, b.clone()))
104		.filter_map(async move |(ai, b)| {
105			let mut lock = b.lock().await;
106			let b = &mut Pin::new(&mut *lock);
107			while b.as_mut().next_if(|bi| *bi < ai).await.is_some() {
108				continue;
109			}
110
111			b.as_mut()
112				.next_if_eq(&ai)
113				.await
114				.is_none()
115				.then_some(ai)
116		})
117}