tuwunel_service/fetcher/
select.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::{Stream, StreamExt, future::Either, stream::empty};
11use ruma::{EventId, OwnedServerName, RoomId, ServerName};
12use tuwunel_core::{
13 arrayvec::ArrayVec,
14 implement,
15 utils::{BoolExt, IterStream, ReadyExt, StreamTools, rand::index},
16};
17
18use super::{Op, Opts};
19use crate::{
20 federation::{Candidates, WhenAllBackedOff},
21 services::OnceServices,
22};
23
/// Upper bound on how many servers a single routing pass draws from the room:
/// it is the capacity of the sample taken in `route_by_popularity` and the
/// window length taken in `route_uniformly`.
const ROUTE_FANOUT: usize = 5;
27
#[async_trait]
/// Strategy for choosing which remote servers a fetch should be sent to.
pub(super) trait Select: Send + Sync {
    /// Returns the candidate servers for the request described by `opts`.
    async fn candidates(&self, opts: &Opts) -> Candidates;
}
34
/// [`Select`] implementation that draws candidates from the room's state
/// cache (members, servers), from the hint/candidates carried in [`Opts`],
/// and from the server names embedded in the room and event IDs.
pub(super) struct RoomCandidates {
    /// Service handles used for the state-cache, federation, globals and
    /// config lookups performed while selecting candidates.
    pub(super) services: Arc<OnceServices>,
}
38
39#[async_trait]
40impl Select for RoomCandidates {
41 #[tracing::instrument(
42 level = "trace",
43 skip_all,
44 fields(
45 room_id = ?opts.room_id,
46 ),
47 )]
48 async fn candidates(&self, opts: &Opts) -> Candidates {
49 if !opts.candidates.is_empty() {
50 return self.ranked_override(opts).await;
51 }
52
53 let authority = self.authority_server(opts).await;
54
55 let mxid_hosts = [
56 opts.event_id
57 .as_deref()
58 .and_then(EventId::server_name),
59 opts.room_id
60 .as_deref()
61 .and_then(RoomId::server_name),
62 ]
63 .into_iter()
64 .flatten()
65 .map(ToOwned::to_owned);
66
67 let popular = match opts.room_id.as_deref() {
68 | None => Either::Right(empty::<OwnedServerName>()),
69 | Some(room_id) => Either::Left(self.route_by_popularity(room_id).await),
70 };
71
72 let eligible = opts
73 .hint
74 .clone()
75 .into_iter()
76 .chain(authority)
77 .stream()
78 .chain(popular)
79 .chain(mxid_hosts.stream())
80 .ready_filter(|server| self.is_eligible(server));
81
82 self.rank_unique(eligible).await
83 }
84}
85
86#[implement(RoomCandidates)]
90#[tracing::instrument(level = "trace", skip_all)]
91async fn ranked_override(&self, opts: &Opts) -> Candidates {
92 let eligible = opts
93 .hint
94 .iter()
95 .chain(opts.candidates.iter())
96 .filter(|&server| self.is_eligible(server))
97 .cloned()
98 .stream();
99
100 self.rank_unique(eligible).await
101}
102
103#[implement(RoomCandidates)]
106async fn rank_unique<S>(&self, eligible: S) -> Candidates
107where
108 S: Stream<Item = OwnedServerName> + Send,
109{
110 let ordered: Candidates = eligible
111 .ready_fold(Candidates::new(), push_unique)
112 .await;
113
114 self.services
115 .federation
116 .rank_candidates(ordered, WhenAllBackedOff::Attempt)
117 .await
118}
119
120fn push_unique(mut ordered: Candidates, server: OwnedServerName) -> Candidates {
122 if !ordered.contains(&server) {
123 ordered.push(server);
124 }
125
126 ordered
127}
128
129#[implement(RoomCandidates)]
132#[tracing::instrument(level = "trace", skip_all)]
133async fn authority_server(&self, opts: &Opts) -> Option<OwnedServerName> {
134 let room_id = opts.room_id.as_deref()?;
135
136 matches!(opts.op, Op::AuthEvent | Op::AuthChain)
137 .then_async(|| {
138 self.services
139 .state_cache
140 .most_powerful_user_server(room_id)
141 })
142 .await
143 .flatten()
144}
145
146#[implement(RoomCandidates)]
152#[tracing::instrument(level = "trace", skip_all)]
153async fn route_by_popularity<'a>(
154 &'a self,
155 room_id: &'a RoomId,
156) -> impl Stream<Item = OwnedServerName> + Send + 'a {
157 let sampled: ArrayVec<OwnedServerName, ROUTE_FANOUT> = self
158 .services
159 .state_cache
160 .room_members(room_id)
161 .sample_by(|user| user.server_name().to_owned())
162 .await;
163
164 if sampled.is_empty() {
165 return Either::Right(
166 self.services
167 .state_cache
168 .room_servers(room_id)
169 .map(ToOwned::to_owned),
170 );
171 }
172
173 Either::Left(sampled.into_iter().stream())
174}
175
176#[implement(RoomCandidates)]
181#[allow(dead_code)]
182async fn route_uniformly<'a>(
183 &'a self,
184 room_id: &'a RoomId,
185) -> impl Stream<Item = OwnedServerName> + Send + 'a {
186 let count = self
187 .services
188 .state_cache
189 .room_servers(room_id)
190 .count()
191 .await;
192
193 let offset = index(count);
194
195 self.services
196 .state_cache
197 .room_servers(room_id)
198 .map(ToOwned::to_owned)
199 .skip(offset)
200 .take(ROUTE_FANOUT)
201}
202
203#[implement(RoomCandidates)]
204fn is_eligible(&self, server: &ServerName) -> bool {
205 !self.services.globals.server_is_ours(server)
206 && !self
207 .services
208 .server
209 .config
210 .is_forbidden_remote_server_name(server)
211}
212
#[cfg(test)]
mod tests {
    use ruma::owned_server_name;

    use super::{Candidates, push_unique};

    /// Feeding repeats through `push_unique` keeps only the first sighting of
    /// each server, in arrival order.
    #[test]
    fn push_unique_keeps_first_occurrence() {
        let arrivals = [
            owned_server_name!("a.test"),
            owned_server_name!("b.test"),
            owned_server_name!("a.test"),
            owned_server_name!("c.test"),
            owned_server_name!("b.test"),
        ];

        let mut deduped = Candidates::new();
        for server in arrivals {
            deduped = push_unique(deduped, server);
        }

        let names: Vec<&str> = deduped
            .iter()
            .map(AsRef::as_ref)
            .collect();

        assert_eq!(names, ["a.test", "b.test", "c.test"]);
    }
}