tuwunel_service/federation/
rank.rs1use futures::StreamExt;
8use ruma::OwnedServerName;
9use tuwunel_core::{debug_warn, implement, smallvec::SmallVec, utils::IterStream};
10
11use super::ShouldAttempt;
12
13pub type Candidates = SmallVec<[OwnedServerName; 3]>;
18
19type Verdicts = SmallVec<[(OwnedServerName, ShouldAttempt); 3]>;
22
/// Policy for `rank_candidates` when every candidate's verdict is
/// `ShouldAttempt::No` (i.e. all are backed off).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WhenAllBackedOff {
	/// Yield every backed-off candidate anyway, in input order, instead of
	/// yielding nothing.
	Attempt,

	/// Yield no candidates at all.
	// Within this file `Fail` is only constructed by the unit tests;
	// presumably no non-test caller uses it yet, hence the lint allowance.
	#[allow(unused)]
	Fail,
}

36#[implement(super::Service)]
39pub async fn rank_candidates(
40 &self,
41 eligible: Candidates,
42 when_all: WhenAllBackedOff,
43) -> Candidates {
44 let verdicts: Verdicts = eligible
45 .into_iter()
46 .stream()
47 .then(async |server| {
48 let verdict = self.should_attempt(&server).await;
49 (server, verdict)
50 })
51 .collect()
52 .await;
53
54 rank_from_verdicts(verdicts, when_all).collect()
55}
56
57fn rank_from_verdicts(
61 mut verdicts: Verdicts,
62 when_all: WhenAllBackedOff,
63) -> impl Iterator<Item = OwnedServerName> {
64 let all_backed_off = verdicts
65 .iter()
66 .all(|(_, verdict)| matches!(verdict, ShouldAttempt::No { .. }));
67
68 let keep_backed_off = all_backed_off && matches!(when_all, WhenAllBackedOff::Attempt);
69
70 if keep_backed_off && !verdicts.is_empty() {
71 debug_warn!(
72 n = verdicts.len(),
73 "All candidates backed off via peer_status; attempting anyway"
74 );
75 }
76
77 verdicts.sort_by_key(|(_, verdict)| verdict.rank());
78
79 verdicts
80 .into_iter()
81 .filter(move |(_, verdict)| {
82 keep_backed_off || !matches!(verdict, ShouldAttempt::No { .. })
83 })
84 .map(|(server, _)| server)
85}
86
87#[implement(ShouldAttempt)]
89#[inline]
90fn rank(self) -> u8 {
91 match self {
92 | ShouldAttempt::Yes => 0,
93 | ShouldAttempt::Deprioritize => 1,
94 | ShouldAttempt::No { .. } => 2,
95 }
96}
97
#[cfg(test)]
mod tests {
	use std::time::SystemTime;

	use ruma::{OwnedServerName, owned_server_name};
	use tuwunel_core::smallvec::smallvec;

	use super::{Verdicts, WhenAllBackedOff, rank_from_verdicts};
	use crate::federation::ShouldAttempt;

	/// A backed-off verdict; the retry time does not affect ranking.
	fn no() -> ShouldAttempt { ShouldAttempt::No { earliest_retry: SystemTime::UNIX_EPOCH } }

	/// Borrows server names as `&str` for concise assertions.
	fn names(servers: &[OwnedServerName]) -> Vec<&str> {
		servers.iter().map(AsRef::as_ref).collect()
	}

	#[test]
	fn all_yes_preserves_order() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("a.test"), ShouldAttempt::Yes),
			(owned_server_name!("b.test"), ShouldAttempt::Yes),
			(owned_server_name!("c.test"), ShouldAttempt::Yes),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Attempt).collect();

		assert_eq!(names(&ranked), ["a.test", "b.test", "c.test"]);
	}

	#[test]
	fn drops_backed_off_when_pool_has_alternatives() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("a.test"), ShouldAttempt::Yes),
			(owned_server_name!("b.test"), no()),
			(owned_server_name!("c.test"), ShouldAttempt::Yes),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Attempt).collect();

		assert_eq!(names(&ranked), ["a.test", "c.test"]);
	}

	#[test]
	fn all_backed_off_attempt_falls_through() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("a.test"), no()),
			(owned_server_name!("b.test"), no()),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Attempt).collect();

		assert_eq!(names(&ranked), ["a.test", "b.test"]);
	}

	#[test]
	fn all_backed_off_fail_returns_empty() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("a.test"), no()),
			(owned_server_name!("b.test"), no()),
		];

		assert!(
			rank_from_verdicts(verdicts, WhenAllBackedOff::Fail)
				.next()
				.is_none()
		);
	}

	#[test]
	fn deprioritize_ranks_after_yes() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("d.test"), ShouldAttempt::Deprioritize),
			(owned_server_name!("y.test"), ShouldAttempt::Yes),
			(owned_server_name!("n.test"), no()),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Attempt).collect();

		assert_eq!(names(&ranked), ["y.test", "d.test"]);
	}

	/// An empty pool yields nothing under either policy (and must not panic).
	#[test]
	fn empty_pool_yields_nothing() {
		for when_all in [WhenAllBackedOff::Attempt, WhenAllBackedOff::Fail] {
			assert!(
				rank_from_verdicts(Verdicts::new(), when_all)
					.next()
					.is_none()
			);
		}
	}

	/// `Fail` only empties the result when *every* candidate is backed off;
	/// a `Deprioritize` candidate still survives.
	#[test]
	fn fail_policy_keeps_non_backed_off() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("n.test"), no()),
			(owned_server_name!("d.test"), ShouldAttempt::Deprioritize),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Fail).collect();

		assert_eq!(names(&ranked), ["d.test"]);
	}

	/// The sort is stable: servers sharing a verdict keep their input order.
	#[test]
	fn equal_ranks_keep_input_order() {
		let verdicts: Verdicts = smallvec![
			(owned_server_name!("y1.test"), ShouldAttempt::Yes),
			(owned_server_name!("d.test"), ShouldAttempt::Deprioritize),
			(owned_server_name!("y2.test"), ShouldAttempt::Yes),
		];

		let ranked: Vec<_> = rank_from_verdicts(verdicts, WhenAllBackedOff::Attempt).collect();

		assert_eq!(names(&ranked), ["y1.test", "y2.test", "d.test"]);
	}
}