Skip to main content

tuwunel_service/rooms/auth_chain/mod.rs

1use std::{
2	collections::{BTreeSet, HashSet},
3	iter::once,
4	sync::Arc,
5	time::Instant,
6};
7
8use async_trait::async_trait;
9use futures::{
10	FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
11	stream::{FuturesUnordered, unfold},
12};
13use ruma::{
14	EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
15	room_version_rules::RoomVersionRules,
16};
17use serde::Deserialize;
18use tuwunel_core::{
19	Err, Result, at, debug, debug_error, err, implement,
20	itertools::Itertools,
21	matrix::room_version,
22	pdu::AuthEvents,
23	smallvec::SmallVec,
24	trace, utils,
25	utils::{
26		IterStream,
27		stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
28	},
29	validated, warn,
30};
31use tuwunel_database::Map;
32
33use crate::rooms::short::ShortEventId;
34
35pub struct Service {
36	services: Arc<crate::services::OnceServices>,
37	db: Data,
38}
39
40struct Data {
41	authchainkey_authchain: Arc<Map>,
42	shorteventid_authchain: Arc<Map>,
43}
44
45type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
46type CacheKey = SmallVec<[ShortEventId; 1]>;
47
48#[async_trait]
49impl crate::Service for Service {
50	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
51		Ok(Arc::new(Self {
52			services: args.services.clone(),
53			db: Data {
54				authchainkey_authchain: args.db["authchainkey_authchain"].clone(),
55				shorteventid_authchain: args.db["shorteventid_authchain"].clone(),
56			},
57		}))
58	}
59
60	async fn clear_cache(&self) { self.db.authchainkey_authchain.clear().await; }
61
62	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
63}
64
65#[implement(Service)]
66pub fn event_ids_iter<'a, I>(
67	&'a self,
68	room_id: &'a RoomId,
69	room_version: &'a RoomVersionId,
70	starting_events: I,
71) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
72where
73	I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
74{
75	self.get_auth_chain(room_id, room_version, starting_events)
76		.map_ok(|chain| {
77			self.services
78				.short
79				.multi_get_eventid_from_short(chain.into_iter().stream())
80				.ready_filter(Result::is_ok)
81		})
82		.try_flatten_stream()
83}
84
85#[implement(Service)]
86#[tracing::instrument(
87	name = "auth_chain",
88	level = "debug",
89	skip_all,
90	fields(
91		%room_id,
92		starting_events = %starting_events.clone().count(),
93	)
94)]
95pub async fn get_auth_chain<'a, I>(
96	&'a self,
97	room_id: &RoomId,
98	room_version: &RoomVersionId,
99	starting_events: I,
100) -> Result<Vec<ShortEventId>>
101where
102	I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
103{
104	const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
105	const BUCKET: Bucket<'_> = BTreeSet::new();
106
107	let started = Instant::now();
108	let room_rules = room_version::rules(room_version)?;
109	let starting_events_count = starting_events.clone().count();
110	let starting_ids = self
111		.services
112		.short
113		.multi_get_or_create_shorteventid(starting_events.clone())
114		.zip(starting_events.stream());
115
116	pin_mut!(starting_ids);
117	let mut buckets = [BUCKET; NUM_BUCKETS];
118	while let Some((short, starting_event)) = starting_ids.next().await {
119		let bucket: usize = short.try_into()?;
120		let bucket: usize = validated!(bucket % NUM_BUCKETS);
121		buckets[bucket].insert((short, starting_event));
122	}
123
124	debug!(
125		starting_events = starting_events_count,
126		elapsed = ?started.elapsed(),
127		"start",
128	);
129
130	let full_auth_chain: Vec<ShortEventId> = buckets
131		.iter()
132		.stream()
133		.flat_map_unordered(automatic_width(), |starting_events| {
134			self.get_chunk_auth_chain(
135				room_id,
136				&started,
137				starting_events.iter().copied(),
138				&room_rules,
139			)
140			.boxed()
141		})
142		.collect::<Vec<_>>()
143		.map(IntoIterator::into_iter)
144		.map(Itertools::sorted_unstable)
145		.map(Itertools::dedup)
146		.map(Iterator::collect)
147		.boxed()
148		.await;
149
150	debug!(
151		chain_length = ?full_auth_chain.len(),
152		elapsed = ?started.elapsed(),
153		"done",
154	);
155
156	Ok(full_auth_chain)
157}
158
159#[implement(Service)]
160#[tracing::instrument(
161	name = "outer",
162	level = "trace",
163	skip_all,
164	fields(
165		starting_events = %starting_events.clone().count(),
166	)
167)]
168pub fn get_chunk_auth_chain<'a, I>(
169	&'a self,
170	room_id: &'a RoomId,
171	started: &'a Instant,
172	starting_events: I,
173	room_rules: &'a RoomVersionRules,
174) -> impl Stream<Item = ShortEventId> + Send + 'a
175where
176	I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
177{
178	let starting_shortids = starting_events.clone().map(at!(0));
179
180	let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
181		if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
182			return cached;
183		}
184
185		let auth_chain: Vec<_> = self
186			.get_event_auth_chain(room_id, event_id, room_rules)
187			.collect()
188			.await;
189
190		self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
191		debug!(
192			?event_id,
193			elapsed = ?started.elapsed(),
194			"Cache missed event"
195		);
196
197		auth_chain
198	};
199
200	let cache_chain = move |chunk_cache: &Vec<_>| {
201		self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
202		debug!(
203			chunk_cache_length = ?chunk_cache.len(),
204			elapsed = ?started.elapsed(),
205			"Cache missed chunk",
206		);
207	};
208
209	self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
210		.map_ok(IntoIterator::into_iter)
211		.map_ok(IterStream::try_stream)
212		.or_else(move |_| async move {
213			starting_events
214				.clone()
215				.stream()
216				.broad_then(build_chain)
217				.collect::<Vec<_>>()
218				.map(IntoIterator::into_iter)
219				.map(Iterator::flatten)
220				.map(Itertools::sorted_unstable)
221				.map(Itertools::dedup)
222				.map(Iterator::collect)
223				.inspect(cache_chain)
224				.map(IntoIterator::into_iter)
225				.map(IterStream::try_stream)
226				.map(Ok)
227				.await
228		})
229		.try_flatten_stream()
230		.map_expect("either cache hit or cache miss yields a chain")
231}
232
233#[implement(Service)]
234#[tracing::instrument(name = "inner", level = "trace", skip_all)]
235pub fn get_event_auth_chain<'a>(
236	&'a self,
237	room_id: &'a RoomId,
238	event_id: &'a EventId,
239	room_rules: &'a RoomVersionRules,
240) -> impl Stream<Item = ShortEventId> + Send + 'a {
241	self.get_event_auth_chain_ids(room_id, event_id, room_rules)
242		.broad_then(async move |auth_event| {
243			self.services
244				.short
245				.get_or_create_shorteventid(&auth_event)
246				.await
247		})
248}
249
250#[implement(Service)]
251#[tracing::instrument(
252	name = "inner_ids",
253	level = "trace",
254	skip_all,
255	fields(%event_id)
256)]
257pub fn get_event_auth_chain_ids<'a>(
258	&'a self,
259	room_id: &'a RoomId,
260	event_id: &'a EventId,
261	room_rules: &'a RoomVersionRules,
262) -> impl Stream<Item = OwnedEventId> + Send + 'a {
263	struct State<Fut> {
264		todo: FuturesUnordered<Fut>,
265		seen: HashSet<OwnedEventId>,
266	}
267
268	let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
269
270	let state = State {
271		todo: once(starting_events).collect(),
272		seen: room_rules
273			.authorization
274			.room_create_event_id_as_room_id
275			.then_some(room_id.as_event_id().ok())
276			.into_iter()
277			.flatten()
278			.collect(),
279	};
280
281	let eval = |auth_events: AuthEvents, mut state: State<_>| {
282		let push = |auth_event: &OwnedEventId| {
283			trace!(todo = state.todo.len(), ?auth_event, "push");
284			state
285				.todo
286				.push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
287		};
288
289		let seen = |auth_event: OwnedEventId| {
290			state
291				.seen
292				.insert(auth_event.clone())
293				.then_some(auth_event)
294		};
295
296		let out = auth_events
297			.into_iter()
298			.filter_map(seen)
299			.inspect(push)
300			.collect::<AuthEvents>()
301			.into_iter()
302			.stream();
303
304		(out, state)
305	};
306
307	unfold(state, move |mut state| async move {
308		match state.todo.next().await {
309			| None => None,
310			| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
311			| Some(Ok(auth_events)) => Some(eval(auth_events, state)),
312		}
313	})
314	.flatten()
315}
316
317#[implement(Service)]
318#[tracing::instrument(
319	name = "cache_put",
320	level = "debug",
321	skip_all,
322	fields(
323		key_len = key.clone().count(),
324		chain_len = auth_chain.len(),
325	)
326)]
327fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
328where
329	I: Iterator<Item = ShortEventId> + Clone + Send,
330{
331	let key = key.collect::<CacheKey>();
332
333	debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
334
335	self.db
336		.authchainkey_authchain
337		.put(key.as_slice(), auth_chain);
338
339	if key.len() == 1 {
340		self.db
341			.shorteventid_authchain
342			.put(key, auth_chain);
343	}
344}
345
346#[implement(Service)]
347#[tracing::instrument(
348	name = "cache_get",
349	level = "trace",
350	err(level = "trace"),
351	skip_all,
352	fields(
353		key_len = %key.clone().count()
354	),
355)]
356async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
357where
358	I: Iterator<Item = ShortEventId> + Clone + Send,
359{
360	let key = key.collect::<CacheKey>();
361
362	if key.is_empty() {
363		return Ok(Vec::new());
364	}
365
366	// Check cache. On miss, check first-order table for single-event keys.
367	let chain = self
368		.db
369		.authchainkey_authchain
370		.qry(key.as_slice())
371		.map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
372		.or_else(async |e| {
373			if key.len() > 1 {
374				return Err(e);
375			}
376
377			self.db
378				.shorteventid_authchain
379				.qry(&key[0])
380				.map_err(|_| err!(Request(NotFound("auth_chain not found"))))
381				.await
382		})
383		.await?
384		.chunks_exact(size_of::<u64>())
385		.map(utils::u64_from_u8)
386		.collect();
387
388	Ok(chain)
389}
390
391#[implement(Service)]
392#[tracing::instrument(
393	name = "auth_events",
394	level = "trace",
395	ret(level = "trace"),
396	err(level = "trace"),
397	skip_all,
398	fields(%event_id)
399)]
400async fn get_event_auth_event_ids<'a>(
401	&'a self,
402	room_id: &'a RoomId,
403	event_id: OwnedEventId,
404) -> Result<AuthEvents> {
405	#[derive(Deserialize)]
406	struct Pdu {
407		auth_events: AuthEvents,
408		room_id: OwnedRoomId,
409	}
410
411	let pdu: Pdu = self
412		.services
413		.timeline
414		.get(&event_id)
415		.inspect_err(|e| {
416			debug_error!(?event_id, ?room_id, "auth chain event: {e}");
417		})
418		.await?;
419
420	if pdu.room_id != room_id {
421		return Err!(Request(Forbidden(error!(
422			?event_id,
423			?room_id,
424			wrong_room_id = ?pdu.room_id,
425			"auth event for incorrect room",
426		))));
427	}
428
429	Ok(pdu.auth_events)
430}