Skip to main content

tuwunel_core/utils/stream/
tools.rs

//! StreamTools: extension combinators (counting, sampling, folding) for `futures::Stream`.
2
3use std::{collections::HashMap, hash::Hash};
4
5use arrayvec::ArrayVec;
6use futures::{Future, Stream, StreamExt};
7
8use super::ReadyExt;
9use crate::{expected, utils::rand::index};
10
11/// StreamTools
12///
13/// This interface is not necessarily complete; feel free to add as-needed.
14pub trait Tools<Item>
15where
16	Self: Stream<Item = Item> + Send + Sized,
17	<Self as Stream>::Item: Send,
18{
19	fn counts(self) -> impl Future<Output = HashMap<Item, usize>> + Send
20	where
21		<Self as Stream>::Item: Eq + Hash;
22
23	fn counts_by<K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send
24	where
25		F: Fn(Item) -> K + Send,
26		K: Eq + Hash + Send;
27
28	fn counts_by_with_cap<const CAP: usize, K, F>(
29		self,
30		f: F,
31	) -> impl Future<Output = HashMap<K, usize>> + Send
32	where
33		F: Fn(Item) -> K + Send,
34		K: Eq + Hash + Send;
35
36	fn counts_with_cap<const CAP: usize>(
37		self,
38	) -> impl Future<Output = HashMap<Item, usize>> + Send
39	where
40		<Self as Stream>::Item: Eq + Hash;
41
42	/// Reservoir-samples up to `N` items uniformly at random in a single
43	/// pass, applying `f` only to the items retained. Items are drawn
44	/// without replacement; the keys `f` derives may still repeat, so a key
45	/// produced by twice as many items is twice as likely to appear.
46	fn sample_by<const N: usize, K, F>(self, f: F) -> impl Future<Output = ArrayVec<K, N>> + Send
47	where
48		F: Fn(Item) -> K + Send,
49		K: Send;
50
51	fn fold_default<T, F, Fut>(self, f: F) -> impl Future<Output = T> + Send
52	where
53		F: Fn(T, Item) -> Fut + Send,
54		Fut: Future<Output = T> + Send,
55		T: Default + Send;
56}
57
58impl<Item, S> Tools<Item> for S
59where
60	S: Stream<Item = Item> + Send + Sized,
61	<Self as Stream>::Item: Send,
62{
63	#[inline]
64	fn counts(self) -> impl Future<Output = HashMap<Item, usize>> + Send
65	where
66		<Self as Stream>::Item: Eq + Hash,
67	{
68		self.counts_with_cap::<0>()
69	}
70
71	#[inline]
72	fn counts_by<K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send
73	where
74		F: Fn(Item) -> K + Send,
75		K: Eq + Hash + Send,
76	{
77		self.counts_by_with_cap::<0, K, F>(f)
78	}
79
80	#[inline]
81	fn counts_by_with_cap<const CAP: usize, K, F>(
82		self,
83		f: F,
84	) -> impl Future<Output = HashMap<K, usize>> + Send
85	where
86		F: Fn(Item) -> K + Send,
87		K: Eq + Hash + Send,
88	{
89		self.map(f).counts_with_cap::<CAP>()
90	}
91
92	#[inline]
93	fn counts_with_cap<const CAP: usize>(
94		self,
95	) -> impl Future<Output = HashMap<Item, usize>> + Send
96	where
97		<Self as Stream>::Item: Eq + Hash,
98	{
99		self.ready_fold(HashMap::with_capacity(CAP), |mut counts, item| {
100			let entry = counts.entry(item).or_default();
101			let value = *entry;
102			*entry = expected!(value + 1);
103			counts
104		})
105	}
106
107	#[inline]
108	fn sample_by<const N: usize, K, F>(self, f: F) -> impl Future<Output = ArrayVec<K, N>> + Send
109	where
110		F: Fn(Item) -> K + Send,
111		K: Send,
112	{
113		self.enumerate()
114			.ready_fold(ArrayVec::<K, N>::new(), move |mut reservoir, (i, item)| {
115				if reservoir.len() < N {
116					reservoir.push(f(item));
117				} else {
118					let slot = index(expected!(i + 1));
119					if slot < N {
120						reservoir[slot] = f(item);
121					}
122				}
123
124				reservoir
125			})
126	}
127
128	#[inline]
129	fn fold_default<T, F, Fut>(self, f: F) -> impl Future<Output = T> + Send
130	where
131		F: Fn(T, Item) -> Fut + Send,
132		Fut: Future<Output = T> + Send,
133		T: Default + Send,
134	{
135		self.fold(T::default(), f)
136	}
137}