// tuwunel_service/rooms/auth_chain/mod.rs

1use std::{
2	collections::{BTreeSet, HashSet},
3	iter::once,
4	sync::Arc,
5	time::Instant,
6};
7
8use async_trait::async_trait;
9use futures::{
10	FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
11	stream::{FuturesUnordered, unfold},
12};
13use ruma::{
14	EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
15	room_version_rules::RoomVersionRules,
16};
17use serde::Deserialize;
18use tuwunel_core::{
19	Err, Result, at, debug, debug_error, err, implement,
20	itertools::Itertools,
21	matrix::room_version,
22	pdu::AuthEvents,
23	smallvec::SmallVec,
24	trace,
25	utils::{
26		IterStream,
27		stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
28	},
29	validated, warn,
30};
31use tuwunel_database::Map;
32
33use crate::rooms::short::ShortEventId;
34
35pub struct Service {
36	services: Arc<crate::services::OnceServices>,
37	db: Data,
38}
39
40struct Data {
41	authchainkey_authchain: Arc<Map>,
42	shorteventid_authchain: Arc<Map>,
43}
44
45type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
46type CacheKey = SmallVec<[ShortEventId; 1]>;
47
48#[async_trait]
49impl crate::Service for Service {
50	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
51		Ok(Arc::new(Self {
52			services: args.services.clone(),
53			db: Data {
54				authchainkey_authchain: args.db["authchainkey_authchain"].clone(),
55				shorteventid_authchain: args.db["shorteventid_authchain"].clone(),
56			},
57		}))
58	}
59
60	async fn clear_cache(&self) {
61		self.db.authchainkey_authchain.clear().await;
62		self.db.shorteventid_authchain.clear().await;
63	}
64
65	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
66}
67
68#[implement(Service)]
69pub fn event_ids_iter<'a, I>(
70	&'a self,
71	room_id: &'a RoomId,
72	room_version: &'a RoomVersionId,
73	starting_events: I,
74) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
75where
76	I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
77{
78	self.get_auth_chain(room_id, room_version, starting_events)
79		.map_ok(|chain| {
80			self.services
81				.short
82				.multi_get_eventid_from_short(chain.into_iter().stream())
83				.ready_filter(Result::is_ok)
84		})
85		.try_flatten_stream()
86}
87
88#[implement(Service)]
89#[tracing::instrument(
90	name = "auth_chain",
91	level = "debug",
92	skip_all,
93	fields(
94		%room_id,
95		starting_events = %starting_events.clone().count(),
96	)
97)]
98pub async fn get_auth_chain<'a, I>(
99	&'a self,
100	room_id: &RoomId,
101	room_version: &RoomVersionId,
102	starting_events: I,
103) -> Result<Vec<ShortEventId>>
104where
105	I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
106{
107	const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
108	const BUCKET: Bucket<'_> = BTreeSet::new();
109
110	let started = Instant::now();
111	let room_rules = room_version::rules(room_version)?;
112	let starting_events_count = starting_events.clone().count();
113	let starting_ids = self
114		.services
115		.short
116		.multi_get_or_create_shorteventid(starting_events.clone())
117		.zip(starting_events.stream());
118
119	pin_mut!(starting_ids);
120	let mut buckets = [BUCKET; NUM_BUCKETS];
121	while let Some((short, starting_event)) = starting_ids.next().await {
122		let bucket: usize = short.try_into()?;
123		let bucket: usize = validated!(bucket % NUM_BUCKETS);
124		buckets[bucket].insert((short, starting_event));
125	}
126
127	debug!(
128		starting_events = starting_events_count,
129		elapsed = ?started.elapsed(),
130		"start",
131	);
132
133	let full_auth_chain: Vec<ShortEventId> = buckets
134		.iter()
135		.stream()
136		.flat_map_unordered(automatic_width(), |starting_events| {
137			self.get_chunk_auth_chain(
138				room_id,
139				&started,
140				starting_events.iter().copied(),
141				&room_rules,
142			)
143			.boxed()
144		})
145		.collect::<Vec<_>>()
146		.map(IntoIterator::into_iter)
147		.map(Itertools::sorted_unstable)
148		.map(Itertools::dedup)
149		.map(Iterator::collect)
150		.boxed()
151		.await;
152
153	debug!(
154		chain_length = ?full_auth_chain.len(),
155		elapsed = ?started.elapsed(),
156		"done",
157	);
158
159	Ok(full_auth_chain)
160}
161
162#[implement(Service)]
163#[tracing::instrument(
164	name = "outer",
165	level = "trace",
166	skip_all,
167	fields(
168		starting_events = %starting_events.clone().count(),
169	)
170)]
171pub fn get_chunk_auth_chain<'a, I>(
172	&'a self,
173	room_id: &'a RoomId,
174	started: &'a Instant,
175	starting_events: I,
176	room_rules: &'a RoomVersionRules,
177) -> impl Stream<Item = ShortEventId> + Send + 'a
178where
179	I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
180{
181	let starting_shortids = starting_events.clone().map(at!(0));
182
183	let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
184		if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
185			return cached;
186		}
187
188		let auth_chain: Vec<_> = self
189			.get_event_auth_chain(room_id, event_id, room_rules)
190			.collect()
191			.await;
192
193		self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
194		debug!(
195			?event_id,
196			elapsed = ?started.elapsed(),
197			"Cache missed event"
198		);
199
200		auth_chain
201	};
202
203	let cache_chain = move |chunk_cache: &Vec<_>| {
204		self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
205		debug!(
206			chunk_cache_length = ?chunk_cache.len(),
207			elapsed = ?started.elapsed(),
208			"Cache missed chunk",
209		);
210	};
211
212	self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
213		.map_ok(IntoIterator::into_iter)
214		.map_ok(IterStream::try_stream)
215		.or_else(move |_| async move {
216			starting_events
217				.clone()
218				.stream()
219				.broad_then(build_chain)
220				.collect::<Vec<_>>()
221				.map(IntoIterator::into_iter)
222				.map(Iterator::flatten)
223				.map(Itertools::sorted_unstable)
224				.map(Itertools::dedup)
225				.map(Iterator::collect)
226				.inspect(cache_chain)
227				.map(IntoIterator::into_iter)
228				.map(IterStream::try_stream)
229				.map(Ok)
230				.await
231		})
232		.try_flatten_stream()
233		.map_expect("either cache hit or cache miss yields a chain")
234}
235
236#[implement(Service)]
237#[tracing::instrument(name = "inner", level = "trace", skip_all)]
238pub fn get_event_auth_chain<'a>(
239	&'a self,
240	room_id: &'a RoomId,
241	event_id: &'a EventId,
242	room_rules: &'a RoomVersionRules,
243) -> impl Stream<Item = ShortEventId> + Send + 'a {
244	self.get_event_auth_chain_ids(room_id, event_id, room_rules)
245		.broad_then(async move |auth_event| {
246			self.services
247				.short
248				.get_or_create_shorteventid(&auth_event)
249				.await
250		})
251}
252
253#[implement(Service)]
254#[tracing::instrument(
255	name = "inner_ids",
256	level = "trace",
257	skip_all,
258	fields(%event_id)
259)]
260pub fn get_event_auth_chain_ids<'a>(
261	&'a self,
262	room_id: &'a RoomId,
263	event_id: &'a EventId,
264	room_rules: &'a RoomVersionRules,
265) -> impl Stream<Item = OwnedEventId> + Send + 'a {
266	struct State<Fut> {
267		todo: FuturesUnordered<Fut>,
268		seen: HashSet<OwnedEventId>,
269	}
270
271	let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
272
273	let state = State {
274		todo: once(starting_events).collect(),
275		seen: room_rules
276			.authorization
277			.room_create_event_id_as_room_id
278			.then_some(room_id.as_event_id().ok())
279			.into_iter()
280			.flatten()
281			.collect(),
282	};
283
284	let eval = |auth_events: AuthEvents, mut state: State<_>| {
285		let push = |auth_event: &OwnedEventId| {
286			trace!(todo = state.todo.len(), ?auth_event, "push");
287			state
288				.todo
289				.push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
290		};
291
292		let seen = |auth_event: OwnedEventId| {
293			state
294				.seen
295				.insert(auth_event.clone())
296				.then_some(auth_event)
297		};
298
299		let out = auth_events
300			.into_iter()
301			.filter_map(seen)
302			.inspect(push)
303			.collect::<AuthEvents>()
304			.into_iter()
305			.stream();
306
307		(out, state)
308	};
309
310	unfold(state, move |mut state| async move {
311		match state.todo.next().await {
312			| None => None,
313			| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
314			| Some(Ok(auth_events)) => Some(eval(auth_events, state)),
315		}
316	})
317	.flatten()
318}
319
320#[implement(Service)]
321#[tracing::instrument(
322	name = "cache_put",
323	level = "debug",
324	skip_all,
325	fields(
326		key_len = key.clone().count(),
327		chain_len = auth_chain.len(),
328	)
329)]
330fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
331where
332	I: Iterator<Item = ShortEventId> + Clone + Send,
333{
334	let key = key.collect::<CacheKey>();
335
336	debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
337
338	self.db
339		.authchainkey_authchain
340		.put(key.as_slice(), auth_chain);
341}
342
343#[implement(Service)]
344#[tracing::instrument(
345	name = "cache_get",
346	level = "trace",
347	err(level = "trace"),
348	skip_all,
349	fields(
350		key_len = %key.clone().count()
351	),
352)]
353async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
354where
355	I: Iterator<Item = ShortEventId> + Clone + Send,
356{
357	let key = key.collect::<CacheKey>();
358
359	if key.is_empty() {
360		return Ok(Vec::new());
361	}
362
363	// On miss, fall back to the single-event legacy table for older entries.
364	let chain = self
365		.db
366		.authchainkey_authchain
367		.qry(key.as_slice())
368		.map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
369		.or_else(async |e| {
370			if key.len() > 1 {
371				return Err(e);
372			}
373
374			self.db
375				.shorteventid_authchain
376				.qry(&key[0])
377				.map_err(|_| err!(Request(NotFound("auth_chain not found"))))
378				.await
379		})
380		.await?
381		.as_chunks::<{ size_of::<u64>() }>()
382		.0
383		.iter()
384		.copied()
385		.map(u64::from_be_bytes)
386		.collect();
387
388	Ok(chain)
389}
390
391#[implement(Service)]
392#[tracing::instrument(
393	name = "auth_events",
394	level = "trace",
395	ret(level = "trace"),
396	err(level = "trace"),
397	skip_all,
398	fields(%event_id)
399)]
400async fn get_event_auth_event_ids<'a>(
401	&'a self,
402	room_id: &'a RoomId,
403	event_id: OwnedEventId,
404) -> Result<AuthEvents> {
405	#[derive(Deserialize)]
406	struct Pdu {
407		auth_events: AuthEvents,
408		room_id: OwnedRoomId,
409	}
410
411	let pdu: Pdu = self
412		.services
413		.timeline
414		.get(&event_id)
415		.inspect_err(|e| {
416			debug_error!(?event_id, ?room_id, "auth chain event: {e}");
417		})
418		.await?;
419
420	if pdu.room_id != room_id {
421		return Err!(Request(Forbidden(error!(
422			?event_id,
423			?room_id,
424			wrong_room_id = ?pdu.room_id,
425			"auth event for incorrect room",
426		))));
427	}
428
429	Ok(pdu.auth_events)
430}