1use std::{
2 collections::{BTreeSet, HashSet},
3 iter::once,
4 sync::Arc,
5 time::Instant,
6};
7
8use async_trait::async_trait;
9use futures::{
10 FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
11 stream::{FuturesUnordered, unfold},
12};
13use ruma::{
14 EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
15 room_version_rules::RoomVersionRules,
16};
17use serde::Deserialize;
18use tuwunel_core::{
19 Err, Result, at, debug, debug_error, err, implement,
20 itertools::Itertools,
21 matrix::room_version,
22 pdu::AuthEvents,
23 smallvec::SmallVec,
24 trace,
25 utils::{
26 IterStream,
27 stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
28 },
29 validated, warn,
30};
31use tuwunel_database::Map;
32
33use crate::rooms::short::ShortEventId;
34
35pub struct Service {
36 services: Arc<crate::services::OnceServices>,
37 db: Data,
38}
39
40struct Data {
41 authchainkey_authchain: Arc<Map>,
42 shorteventid_authchain: Arc<Map>,
43}
44
45type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
46type CacheKey = SmallVec<[ShortEventId; 1]>;
47
48#[async_trait]
49impl crate::Service for Service {
50 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
51 Ok(Arc::new(Self {
52 services: args.services.clone(),
53 db: Data {
54 authchainkey_authchain: args.db["authchainkey_authchain"].clone(),
55 shorteventid_authchain: args.db["shorteventid_authchain"].clone(),
56 },
57 }))
58 }
59
60 async fn clear_cache(&self) {
61 self.db.authchainkey_authchain.clear().await;
62 self.db.shorteventid_authchain.clear().await;
63 }
64
65 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
66}
67
68#[implement(Service)]
69pub fn event_ids_iter<'a, I>(
70 &'a self,
71 room_id: &'a RoomId,
72 room_version: &'a RoomVersionId,
73 starting_events: I,
74) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
75where
76 I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
77{
78 self.get_auth_chain(room_id, room_version, starting_events)
79 .map_ok(|chain| {
80 self.services
81 .short
82 .multi_get_eventid_from_short(chain.into_iter().stream())
83 .ready_filter(Result::is_ok)
84 })
85 .try_flatten_stream()
86}
87
88#[implement(Service)]
89#[tracing::instrument(
90 name = "auth_chain",
91 level = "debug",
92 skip_all,
93 fields(
94 %room_id,
95 starting_events = %starting_events.clone().count(),
96 )
97)]
98pub async fn get_auth_chain<'a, I>(
99 &'a self,
100 room_id: &RoomId,
101 room_version: &RoomVersionId,
102 starting_events: I,
103) -> Result<Vec<ShortEventId>>
104where
105 I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
106{
107 const NUM_BUCKETS: usize = 50; const BUCKET: Bucket<'_> = BTreeSet::new();
109
110 let started = Instant::now();
111 let room_rules = room_version::rules(room_version)?;
112 let starting_events_count = starting_events.clone().count();
113 let starting_ids = self
114 .services
115 .short
116 .multi_get_or_create_shorteventid(starting_events.clone())
117 .zip(starting_events.stream());
118
119 pin_mut!(starting_ids);
120 let mut buckets = [BUCKET; NUM_BUCKETS];
121 while let Some((short, starting_event)) = starting_ids.next().await {
122 let bucket: usize = short.try_into()?;
123 let bucket: usize = validated!(bucket % NUM_BUCKETS);
124 buckets[bucket].insert((short, starting_event));
125 }
126
127 debug!(
128 starting_events = starting_events_count,
129 elapsed = ?started.elapsed(),
130 "start",
131 );
132
133 let full_auth_chain: Vec<ShortEventId> = buckets
134 .iter()
135 .stream()
136 .flat_map_unordered(automatic_width(), |starting_events| {
137 self.get_chunk_auth_chain(
138 room_id,
139 &started,
140 starting_events.iter().copied(),
141 &room_rules,
142 )
143 .boxed()
144 })
145 .collect::<Vec<_>>()
146 .map(IntoIterator::into_iter)
147 .map(Itertools::sorted_unstable)
148 .map(Itertools::dedup)
149 .map(Iterator::collect)
150 .boxed()
151 .await;
152
153 debug!(
154 chain_length = ?full_auth_chain.len(),
155 elapsed = ?started.elapsed(),
156 "done",
157 );
158
159 Ok(full_auth_chain)
160}
161
162#[implement(Service)]
163#[tracing::instrument(
164 name = "outer",
165 level = "trace",
166 skip_all,
167 fields(
168 starting_events = %starting_events.clone().count(),
169 )
170)]
171pub fn get_chunk_auth_chain<'a, I>(
172 &'a self,
173 room_id: &'a RoomId,
174 started: &'a Instant,
175 starting_events: I,
176 room_rules: &'a RoomVersionRules,
177) -> impl Stream<Item = ShortEventId> + Send + 'a
178where
179 I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
180{
181 let starting_shortids = starting_events.clone().map(at!(0));
182
183 let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
184 if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
185 return cached;
186 }
187
188 let auth_chain: Vec<_> = self
189 .get_event_auth_chain(room_id, event_id, room_rules)
190 .collect()
191 .await;
192
193 self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
194 debug!(
195 ?event_id,
196 elapsed = ?started.elapsed(),
197 "Cache missed event"
198 );
199
200 auth_chain
201 };
202
203 let cache_chain = move |chunk_cache: &Vec<_>| {
204 self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
205 debug!(
206 chunk_cache_length = ?chunk_cache.len(),
207 elapsed = ?started.elapsed(),
208 "Cache missed chunk",
209 );
210 };
211
212 self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
213 .map_ok(IntoIterator::into_iter)
214 .map_ok(IterStream::try_stream)
215 .or_else(move |_| async move {
216 starting_events
217 .clone()
218 .stream()
219 .broad_then(build_chain)
220 .collect::<Vec<_>>()
221 .map(IntoIterator::into_iter)
222 .map(Iterator::flatten)
223 .map(Itertools::sorted_unstable)
224 .map(Itertools::dedup)
225 .map(Iterator::collect)
226 .inspect(cache_chain)
227 .map(IntoIterator::into_iter)
228 .map(IterStream::try_stream)
229 .map(Ok)
230 .await
231 })
232 .try_flatten_stream()
233 .map_expect("either cache hit or cache miss yields a chain")
234}
235
236#[implement(Service)]
237#[tracing::instrument(name = "inner", level = "trace", skip_all)]
238pub fn get_event_auth_chain<'a>(
239 &'a self,
240 room_id: &'a RoomId,
241 event_id: &'a EventId,
242 room_rules: &'a RoomVersionRules,
243) -> impl Stream<Item = ShortEventId> + Send + 'a {
244 self.get_event_auth_chain_ids(room_id, event_id, room_rules)
245 .broad_then(async move |auth_event| {
246 self.services
247 .short
248 .get_or_create_shorteventid(&auth_event)
249 .await
250 })
251}
252
253#[implement(Service)]
254#[tracing::instrument(
255 name = "inner_ids",
256 level = "trace",
257 skip_all,
258 fields(%event_id)
259)]
260pub fn get_event_auth_chain_ids<'a>(
261 &'a self,
262 room_id: &'a RoomId,
263 event_id: &'a EventId,
264 room_rules: &'a RoomVersionRules,
265) -> impl Stream<Item = OwnedEventId> + Send + 'a {
266 struct State<Fut> {
267 todo: FuturesUnordered<Fut>,
268 seen: HashSet<OwnedEventId>,
269 }
270
271 let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
272
273 let state = State {
274 todo: once(starting_events).collect(),
275 seen: room_rules
276 .authorization
277 .room_create_event_id_as_room_id
278 .then_some(room_id.as_event_id().ok())
279 .into_iter()
280 .flatten()
281 .collect(),
282 };
283
284 let eval = |auth_events: AuthEvents, mut state: State<_>| {
285 let push = |auth_event: &OwnedEventId| {
286 trace!(todo = state.todo.len(), ?auth_event, "push");
287 state
288 .todo
289 .push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
290 };
291
292 let seen = |auth_event: OwnedEventId| {
293 state
294 .seen
295 .insert(auth_event.clone())
296 .then_some(auth_event)
297 };
298
299 let out = auth_events
300 .into_iter()
301 .filter_map(seen)
302 .inspect(push)
303 .collect::<AuthEvents>()
304 .into_iter()
305 .stream();
306
307 (out, state)
308 };
309
310 unfold(state, move |mut state| async move {
311 match state.todo.next().await {
312 | None => None,
313 | Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
314 | Some(Ok(auth_events)) => Some(eval(auth_events, state)),
315 }
316 })
317 .flatten()
318}
319
320#[implement(Service)]
321#[tracing::instrument(
322 name = "cache_put",
323 level = "debug",
324 skip_all,
325 fields(
326 key_len = key.clone().count(),
327 chain_len = auth_chain.len(),
328 )
329)]
330fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
331where
332 I: Iterator<Item = ShortEventId> + Clone + Send,
333{
334 let key = key.collect::<CacheKey>();
335
336 debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
337
338 self.db
339 .authchainkey_authchain
340 .put(key.as_slice(), auth_chain);
341}
342
343#[implement(Service)]
344#[tracing::instrument(
345 name = "cache_get",
346 level = "trace",
347 err(level = "trace"),
348 skip_all,
349 fields(
350 key_len = %key.clone().count()
351 ),
352)]
353async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
354where
355 I: Iterator<Item = ShortEventId> + Clone + Send,
356{
357 let key = key.collect::<CacheKey>();
358
359 if key.is_empty() {
360 return Ok(Vec::new());
361 }
362
363 let chain = self
365 .db
366 .authchainkey_authchain
367 .qry(key.as_slice())
368 .map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
369 .or_else(async |e| {
370 if key.len() > 1 {
371 return Err(e);
372 }
373
374 self.db
375 .shorteventid_authchain
376 .qry(&key[0])
377 .map_err(|_| err!(Request(NotFound("auth_chain not found"))))
378 .await
379 })
380 .await?
381 .as_chunks::<{ size_of::<u64>() }>()
382 .0
383 .iter()
384 .copied()
385 .map(u64::from_be_bytes)
386 .collect();
387
388 Ok(chain)
389}
390
391#[implement(Service)]
392#[tracing::instrument(
393 name = "auth_events",
394 level = "trace",
395 ret(level = "trace"),
396 err(level = "trace"),
397 skip_all,
398 fields(%event_id)
399)]
400async fn get_event_auth_event_ids<'a>(
401 &'a self,
402 room_id: &'a RoomId,
403 event_id: OwnedEventId,
404) -> Result<AuthEvents> {
405 #[derive(Deserialize)]
406 struct Pdu {
407 auth_events: AuthEvents,
408 room_id: OwnedRoomId,
409 }
410
411 let pdu: Pdu = self
412 .services
413 .timeline
414 .get(&event_id)
415 .inspect_err(|e| {
416 debug_error!(?event_id, ?room_id, "auth chain event: {e}");
417 })
418 .await?;
419
420 if pdu.room_id != room_id {
421 return Err!(Request(Forbidden(error!(
422 ?event_id,
423 ?room_id,
424 wrong_room_id = ?pdu.room_id,
425 "auth event for incorrect room",
426 ))));
427 }
428
429 Ok(pdu.auth_events)
430}