1use std::{
2 collections::{BTreeSet, HashSet},
3 iter::once,
4 sync::Arc,
5 time::Instant,
6};
7
8use async_trait::async_trait;
9use futures::{
10 FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
11 stream::{FuturesUnordered, unfold},
12};
13use ruma::{
14 EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
15 room_version_rules::RoomVersionRules,
16};
17use serde::Deserialize;
18use tuwunel_core::{
19 Err, Result, at, debug, debug_error, err, implement,
20 itertools::Itertools,
21 matrix::room_version,
22 pdu::AuthEvents,
23 smallvec::SmallVec,
24 trace, utils,
25 utils::{
26 IterStream,
27 stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
28 },
29 validated, warn,
30};
31use tuwunel_database::Map;
32
33use crate::rooms::short::ShortEventId;
34
/// Auth-chain service: computes the auth chain of a set of events (the
/// transitive closure of their `auth_events`) and caches the results in the
/// database.
pub struct Service {
	/// Other services this one calls into (`short` for short event IDs,
	/// `timeline` for loading PDUs).
	services: Arc<crate::services::OnceServices>,
	/// Database maps used as the auth-chain cache.
	db: Data,
}

/// Database maps backing the auth-chain cache. Values are auth chains stored
/// as concatenated 8-byte short event IDs.
struct Data {
	/// Keyed by one or more short event IDs (a whole bucket/chunk or a single
	/// event). This is the map emptied by `clear_cache`.
	authchainkey_authchain: Arc<Map>,
	/// Keyed by a single short event ID; only written for single-event keys
	/// and not emptied by `clear_cache`.
	shorteventid_authchain: Arc<Map>,
}
44
/// One bucket of starting events, ordered by short event ID (then event ID).
type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
/// Cache key: the short event IDs a chain was computed from; stored inline
/// for the common single-event key.
type CacheKey = SmallVec<[ShortEventId; 1]>;
47
48#[async_trait]
49impl crate::Service for Service {
50 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
51 Ok(Arc::new(Self {
52 services: args.services.clone(),
53 db: Data {
54 authchainkey_authchain: args.db["authchainkey_authchain"].clone(),
55 shorteventid_authchain: args.db["shorteventid_authchain"].clone(),
56 },
57 }))
58 }
59
60 async fn clear_cache(&self) { self.db.authchainkey_authchain.clear().await; }
61
62 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
63}
64
65#[implement(Service)]
66pub fn event_ids_iter<'a, I>(
67 &'a self,
68 room_id: &'a RoomId,
69 room_version: &'a RoomVersionId,
70 starting_events: I,
71) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
72where
73 I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
74{
75 self.get_auth_chain(room_id, room_version, starting_events)
76 .map_ok(|chain| {
77 self.services
78 .short
79 .multi_get_eventid_from_short(chain.into_iter().stream())
80 .ready_filter(Result::is_ok)
81 })
82 .try_flatten_stream()
83}
84
85#[implement(Service)]
86#[tracing::instrument(
87 name = "auth_chain",
88 level = "debug",
89 skip_all,
90 fields(
91 %room_id,
92 starting_events = %starting_events.clone().count(),
93 )
94)]
95pub async fn get_auth_chain<'a, I>(
96 &'a self,
97 room_id: &RoomId,
98 room_version: &RoomVersionId,
99 starting_events: I,
100) -> Result<Vec<ShortEventId>>
101where
102 I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
103{
104 const NUM_BUCKETS: usize = 50; const BUCKET: Bucket<'_> = BTreeSet::new();
106
107 let started = Instant::now();
108 let room_rules = room_version::rules(room_version)?;
109 let starting_events_count = starting_events.clone().count();
110 let starting_ids = self
111 .services
112 .short
113 .multi_get_or_create_shorteventid(starting_events.clone())
114 .zip(starting_events.stream());
115
116 pin_mut!(starting_ids);
117 let mut buckets = [BUCKET; NUM_BUCKETS];
118 while let Some((short, starting_event)) = starting_ids.next().await {
119 let bucket: usize = short.try_into()?;
120 let bucket: usize = validated!(bucket % NUM_BUCKETS);
121 buckets[bucket].insert((short, starting_event));
122 }
123
124 debug!(
125 starting_events = starting_events_count,
126 elapsed = ?started.elapsed(),
127 "start",
128 );
129
130 let full_auth_chain: Vec<ShortEventId> = buckets
131 .iter()
132 .stream()
133 .flat_map_unordered(automatic_width(), |starting_events| {
134 self.get_chunk_auth_chain(
135 room_id,
136 &started,
137 starting_events.iter().copied(),
138 &room_rules,
139 )
140 .boxed()
141 })
142 .collect::<Vec<_>>()
143 .map(IntoIterator::into_iter)
144 .map(Itertools::sorted_unstable)
145 .map(Itertools::dedup)
146 .map(Iterator::collect)
147 .boxed()
148 .await;
149
150 debug!(
151 chain_length = ?full_auth_chain.len(),
152 elapsed = ?started.elapsed(),
153 "done",
154 );
155
156 Ok(full_auth_chain)
157}
158
/// Resolves the combined auth chain of one bucket ("chunk") of starting
/// events as a stream of short event IDs.
///
/// The chunk-level cache, keyed by all of the chunk's short IDs, is tried
/// first. On a miss, each event's chain is taken from the per-event cache or
/// computed with `get_event_auth_chain` (and then cached per event). The
/// per-event chains are flattened, sorted and de-duplicated, cached under the
/// chunk key, and then streamed.
///
/// An empty chunk yields an empty stream, because `get_cached_auth_chain`
/// returns an empty chain for an empty key.
#[implement(Service)]
#[tracing::instrument(
	name = "outer",
	level = "trace",
	skip_all,
	fields(
		starting_events = %starting_events.clone().count(),
	)
)]
pub fn get_chunk_auth_chain<'a, I>(
	&'a self,
	room_id: &'a RoomId,
	started: &'a Instant,
	starting_events: I,
	room_rules: &'a RoomVersionRules,
) -> impl Stream<Item = ShortEventId> + Send + 'a
where
	I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
{
	// The chunk's cache key: the short IDs of its starting events.
	let starting_shortids = starting_events.clone().map(at!(0));

	// Per-event resolution, used only on a chunk-level miss: per-event cache
	// first, otherwise walk the event's auth events and cache the result.
	let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
		if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
			return cached;
		}

		let auth_chain: Vec<_> = self
			.get_event_auth_chain(room_id, event_id, room_rules)
			.collect()
			.await;

		self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
		debug!(
			?event_id,
			elapsed = ?started.elapsed(),
			"Cache missed event"
		);

		auth_chain
	};

	// Stores the merged chain under the chunk key. It consumes
	// `starting_shortids`, so it runs at most once (on the miss path). For a
	// single-event chunk this key equals the per-event key written in
	// `build_chain`, so that entry is overwritten with the sorted,
	// de-duplicated chain.
	let cache_chain = move |chunk_cache: &Vec<_>| {
		self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
		debug!(
			chunk_cache_length = ?chunk_cache.len(),
			elapsed = ?started.elapsed(),
			"Cache missed chunk",
		);
	};

	// Chunk-level cache hit: stream the cached IDs directly.
	self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
		.map_ok(IntoIterator::into_iter)
		.map_ok(IterStream::try_stream)
		.or_else(move |_| async move {
			// Chunk-level miss: resolve each event via `broad_then`, then
			// flatten, sort and de-dup the per-event chains before caching.
			starting_events
				.clone()
				.stream()
				.broad_then(build_chain)
				.collect::<Vec<_>>()
				.map(IntoIterator::into_iter)
				.map(Iterator::flatten)
				.map(Itertools::sorted_unstable)
				.map(Itertools::dedup)
				.map(Iterator::collect)
				.inspect(cache_chain)
				.map(IntoIterator::into_iter)
				.map(IterStream::try_stream)
				.map(Ok)
				.await
		})
		.try_flatten_stream()
		// The miss path always resolves to `Ok` (`.map(Ok)` above), so only
		// the items produced by `try_stream` from plain IDs reach this point.
		.map_expect("either cache hit or cache miss yields a chain")
}
232
233#[implement(Service)]
234#[tracing::instrument(name = "inner", level = "trace", skip_all)]
235pub fn get_event_auth_chain<'a>(
236 &'a self,
237 room_id: &'a RoomId,
238 event_id: &'a EventId,
239 room_rules: &'a RoomVersionRules,
240) -> impl Stream<Item = ShortEventId> + Send + 'a {
241 self.get_event_auth_chain_ids(room_id, event_id, room_rules)
242 .broad_then(async move |auth_event| {
243 self.services
244 .short
245 .get_or_create_shorteventid(&auth_event)
246 .await
247 })
248}
249
250#[implement(Service)]
251#[tracing::instrument(
252 name = "inner_ids",
253 level = "trace",
254 skip_all,
255 fields(%event_id)
256)]
257pub fn get_event_auth_chain_ids<'a>(
258 &'a self,
259 room_id: &'a RoomId,
260 event_id: &'a EventId,
261 room_rules: &'a RoomVersionRules,
262) -> impl Stream<Item = OwnedEventId> + Send + 'a {
263 struct State<Fut> {
264 todo: FuturesUnordered<Fut>,
265 seen: HashSet<OwnedEventId>,
266 }
267
268 let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
269
270 let state = State {
271 todo: once(starting_events).collect(),
272 seen: room_rules
273 .authorization
274 .room_create_event_id_as_room_id
275 .then_some(room_id.as_event_id().ok())
276 .into_iter()
277 .flatten()
278 .collect(),
279 };
280
281 let eval = |auth_events: AuthEvents, mut state: State<_>| {
282 let push = |auth_event: &OwnedEventId| {
283 trace!(todo = state.todo.len(), ?auth_event, "push");
284 state
285 .todo
286 .push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
287 };
288
289 let seen = |auth_event: OwnedEventId| {
290 state
291 .seen
292 .insert(auth_event.clone())
293 .then_some(auth_event)
294 };
295
296 let out = auth_events
297 .into_iter()
298 .filter_map(seen)
299 .inspect(push)
300 .collect::<AuthEvents>()
301 .into_iter()
302 .stream();
303
304 (out, state)
305 };
306
307 unfold(state, move |mut state| async move {
308 match state.todo.next().await {
309 | None => None,
310 | Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
311 | Some(Ok(auth_events)) => Some(eval(auth_events, state)),
312 }
313 })
314 .flatten()
315}
316
317#[implement(Service)]
318#[tracing::instrument(
319 name = "cache_put",
320 level = "debug",
321 skip_all,
322 fields(
323 key_len = key.clone().count(),
324 chain_len = auth_chain.len(),
325 )
326)]
327fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
328where
329 I: Iterator<Item = ShortEventId> + Clone + Send,
330{
331 let key = key.collect::<CacheKey>();
332
333 debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
334
335 self.db
336 .authchainkey_authchain
337 .put(key.as_slice(), auth_chain);
338
339 if key.len() == 1 {
340 self.db
341 .shorteventid_authchain
342 .put(key, auth_chain);
343 }
344}
345
346#[implement(Service)]
347#[tracing::instrument(
348 name = "cache_get",
349 level = "trace",
350 err(level = "trace"),
351 skip_all,
352 fields(
353 key_len = %key.clone().count()
354 ),
355)]
356async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
357where
358 I: Iterator<Item = ShortEventId> + Clone + Send,
359{
360 let key = key.collect::<CacheKey>();
361
362 if key.is_empty() {
363 return Ok(Vec::new());
364 }
365
366 let chain = self
368 .db
369 .authchainkey_authchain
370 .qry(key.as_slice())
371 .map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
372 .or_else(async |e| {
373 if key.len() > 1 {
374 return Err(e);
375 }
376
377 self.db
378 .shorteventid_authchain
379 .qry(&key[0])
380 .map_err(|_| err!(Request(NotFound("auth_chain not found"))))
381 .await
382 })
383 .await?
384 .chunks_exact(size_of::<u64>())
385 .map(utils::u64_from_u8)
386 .collect();
387
388 Ok(chain)
389}
390
391#[implement(Service)]
392#[tracing::instrument(
393 name = "auth_events",
394 level = "trace",
395 ret(level = "trace"),
396 err(level = "trace"),
397 skip_all,
398 fields(%event_id)
399)]
400async fn get_event_auth_event_ids<'a>(
401 &'a self,
402 room_id: &'a RoomId,
403 event_id: OwnedEventId,
404) -> Result<AuthEvents> {
405 #[derive(Deserialize)]
406 struct Pdu {
407 auth_events: AuthEvents,
408 room_id: OwnedRoomId,
409 }
410
411 let pdu: Pdu = self
412 .services
413 .timeline
414 .get(&event_id)
415 .inspect_err(|e| {
416 debug_error!(?event_id, ?room_id, "auth chain event: {e}");
417 })
418 .await?;
419
420 if pdu.room_id != room_id {
421 return Err!(Request(Forbidden(error!(
422 ?event_id,
423 ?room_id,
424 wrong_room_id = ?pdu.room_id,
425 "auth event for incorrect room",
426 ))));
427 }
428
429 Ok(pdu.auth_events)
430}