1use std::{
8 collections::{HashMap, VecDeque},
9 num::NonZeroUsize,
10 sync::{Arc, Weak},
11};
12
13use bytes::Bytes;
14use futures::{FutureExt, StreamExt, future::BoxFuture, stream::FuturesUnordered};
15use ruma::OwnedServerName;
16use tokio::sync::watch::channel;
17use tuwunel_core::{debug_warn, implement, trace};
18
19use super::{
20 Failure, Msg, Opts, Outcome, Service,
21 error::Attempted,
22 inflight::{Inflight, Key, SharedResult},
23};
24
/// A boxed fetch run owned by the worker; it resolves to the [`Key`] it was
/// started for together with that run's [`SharedResult`].
type FetchFuture<'a> = BoxFuture<'a, (Key, SharedResult)>;
/// The set of fetch runs the worker polls concurrently (see `work_loop`).
type FetchFutures<'a> = FuturesUnordered<FetchFuture<'a>>;
29
30#[implement(Service)]
33pub(super) async fn run_worker(self: Arc<Self>) {
34 let mut inflight: HashMap<Key, Inflight> = HashMap::new();
35 let mut pending: VecDeque<Msg> = VecDeque::new();
36 let mut futures: FetchFutures<'_> = FuturesUnordered::new();
37
38 self.work_loop(&mut inflight, &mut pending, &mut futures)
39 .await;
40}
41
42#[implement(Service)]
43async fn work_loop<'a>(
44 &'a self,
45 inflight: &mut HashMap<Key, Inflight>,
46 pending: &mut VecDeque<Msg>,
47 futures: &mut FetchFutures<'a>,
48) {
49 let rx = self.channel.1.clone();
50 while !rx.is_closed() {
51 while let Ok(msg) = rx.try_recv() {
53 self.on_request(msg, inflight, pending, futures);
54 }
55
56 tokio::select! {
57 Some((key, result)) = futures.next() =>
58 self.on_complete(key, result, inflight, pending, futures),
59 msg = rx.recv_async() => match msg {
60 | Ok(msg) => self.on_request(msg, inflight, pending, futures),
61 | Err(_) => break,
62 },
63 }
64 }
65}
66
67#[implement(Service)]
68fn on_request<'a>(
69 &'a self,
70 msg: Msg,
71 inflight: &mut HashMap<Key, Inflight>,
72 pending: &mut VecDeque<Msg>,
73 futures: &FetchFutures<'a>,
74) {
75 let Some(entry) = inflight.get_mut(&msg.key) else {
76 if futures.len() >= self.capacity {
78 pending.push_back(msg);
79 } else {
80 self.dispatch(msg, inflight, futures);
81 }
82
83 return;
84 };
85
86 match entry.interest.upgrade() {
87 | Some(strong) => {
89 msg.reply
90 .send((entry.tx.subscribe(), strong))
91 .ok();
92 },
93 | None => {
96 let interest = Arc::new(());
97 entry.interest = Arc::downgrade(&interest);
98 msg.reply
99 .send((entry.tx.subscribe(), interest))
100 .ok();
101 },
102 }
103}
104
105#[implement(Service)]
106fn dispatch<'a>(
107 &'a self,
108 msg: Msg,
109 inflight: &mut HashMap<Key, Inflight>,
110 futures: &FetchFutures<'a>,
111) {
112 let Msg { key, opts, reply } = msg;
113 let interest = Arc::new(());
114 let (tx, rx) = channel(None);
115
116 if reply.send((rx, interest.clone())).is_err() {
118 return;
119 }
120
121 let opts = Arc::new(opts);
122 let weak = Arc::downgrade(&interest);
123 inflight.insert(key.clone(), Inflight {
124 tx,
125 interest: weak.clone(),
126 opts: opts.clone(),
127 });
128
129 self.push_attempt(futures, key, opts, weak);
130}
131
132#[implement(Service)]
136fn push_attempt<'a>(
137 &'a self,
138 futures: &FetchFutures<'a>,
139 key: Key,
140 opts: Arc<Opts>,
141 weak: Weak<()>,
142) {
143 futures.push(async move { (key, self.run_attempts(&opts, &weak).await) }.boxed());
144}
145
146#[implement(Service)]
147fn on_complete<'a>(
148 &'a self,
149 key: Key,
150 result: SharedResult,
151 inflight: &mut HashMap<Key, Inflight>,
152 pending: &mut VecDeque<Msg>,
153 futures: &FetchFutures<'a>,
154) {
155 let Some(entry) = inflight.get(&key) else {
156 return;
157 };
158
159 if matches!(&result, Err(Failure::Cancelled)) && entry.interest.upgrade().is_some() {
162 let opts = entry.opts.clone();
163 let weak = entry.interest.clone();
164 self.push_attempt(futures, key, opts, weak);
165 return;
166 }
167
168 entry.tx.send(Some(result)).ok();
169 inflight.remove(&key);
170
171 while futures.len() < self.capacity {
174 let Some(msg) = pending.pop_front() else {
175 break;
176 };
177
178 self.on_request(msg, inflight, pending, futures);
179 }
180}
181
182#[implement(Service)]
183#[tracing::instrument(
184 name = "attempts",
185 level = "debug",
186 skip_all,
187 fields(
188 op = ?opts.op,
189 room_id = ?opts.room_id,
190 event_id = ?opts.event_id,
191 ),
192)]
193async fn run_attempts(&self, opts: &Opts, interest: &Weak<()>) -> SharedResult {
194 let candidates = self.select.candidates(opts).await;
195 if candidates.is_empty() {
196 return Err(Failure::NoCandidates);
197 }
198
199 let count = candidates.len();
200 let limit = opts
201 .attempt_limit
202 .map_or(count, |n| n.get().min(count));
203
204 let (config_width, config_rounds) = self
205 .services
206 .try_get()
207 .map_or((0, 0), |services| {
208 let config = &services.server.config;
209
210 (config.fetch_fanout_max_width, config.fetch_fanout_rounds)
211 });
212
213 let max_width = effective_cap(opts.fanout_max_width, config_width);
214 let max_rounds = effective_cap(opts.fanout_rounds, config_rounds);
215
216 let mut attempted: Attempted = Attempted::new();
217 let mut remaining = candidates.into_iter();
218 let mut round: usize = 0;
219
220 while attempted.len() < limit {
221 if interest.strong_count() == 0 {
222 return Err(Failure::Cancelled);
223 }
224
225 if round >= max_rounds {
226 break;
227 }
228
229 let budget = limit.saturating_sub(attempted.len());
230 let width = opts
231 .fanout_growth
232 .round_width(round)
233 .min(max_width)
234 .min(budget);
235
236 let mut racing: FuturesUnordered<_> = remaining
239 .by_ref()
240 .take(width)
241 .map(|server| self.attempt(server, opts))
242 .collect();
243
244 if racing.is_empty() {
245 break;
246 }
247
248 while let Some((server, bytes)) = racing.next().await {
249 let Some(bytes) = bytes else {
250 attempted.push(server);
251
252 if interest.strong_count() == 0 {
253 return Err(Failure::Cancelled);
254 }
255
256 continue;
257 };
258
259 trace!(%server, "fetch satisfied");
260 return Ok(Arc::new(Outcome { bytes, origin: server }));
261 }
262
263 round = round.saturating_add(1);
264 }
265
266 Err(Failure::NotFound { attempted })
267}
268
269pub(super) fn effective_cap(opt: Option<NonZeroUsize>, config: usize) -> usize {
273 opt.map_or(usize::MAX, NonZeroUsize::get)
274 .min(NonZeroUsize::new(config).map_or(usize::MAX, NonZeroUsize::get))
275}
276
277#[implement(Service)]
281#[tracing::instrument(
282 name = "attempt",
283 level = "trace",
284 skip_all,
285 fields(%server),
286)]
287async fn attempt(
288 &self,
289 server: OwnedServerName,
290 opts: &Opts,
291) -> (OwnedServerName, Option<Bytes>) {
292 let Some(bytes) = self
293 .transport
294 .fetch_raw(opts.op, &server, opts)
295 .await
296 .inspect_err(|error| debug_warn!(%server, "fetch attempt failed: {error}"))
297 .ok()
298 else {
299 return (server, None);
300 };
301
302 let valid = self
303 .validate(opts, &bytes)
304 .await
305 .inspect_err(|error| debug_warn!(%server, "rejecting poisoned response: {error}"))
306 .is_ok();
307
308 (server, valid.then_some(bytes))
309}