Skip to main content

tuwunel_service/rooms/state_compressor/
mod.rs

1use std::{
2	collections::{BTreeSet, HashMap},
3	fmt::{Debug, Write},
4	mem::size_of,
5	sync::{Arc, Mutex},
6};
7
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use lru_cache::LruCache;
11use ruma::{EventId, RoomId};
12use tuwunel_core::{
13	Result,
14	arrayvec::ArrayVec,
15	at, checked, err, expected, implement, utils,
16	utils::{bytes, math::usize_from_f64, stream::IterStream},
17};
18use tuwunel_database::Map;
19
20use crate::rooms::short::{ShortEventId, ShortId, ShortStateHash, ShortStateKey};
21
/// State compressor service: stores room states as layered diffs keyed by
/// shortstatehash and caches the reconstructed layer stacks.
pub struct Service {
	/// LRU cache mapping a shortstatehash to its loaded stack of
	/// [`ShortStateInfo`] layers.
	pub stateinfo_cache: Mutex<StateInfoLruCache>,
	/// Database maps used by this service.
	db: Data,
	/// Handle to sibling services; this module uses `short` and `state`.
	services: Arc<crate::services::OnceServices>,
}
27
/// Database handles owned by the state compressor.
struct Data {
	/// Map from a shortstatehash (big-endian bytes) to its serialized
	/// state diff: parent hash, added entries and, optionally, a zero
	/// separator followed by removed entries.
	shortstatehash_statediff: Arc<Map>,
}
31
/// One stored diff layer as it is read from / written to the database.
#[derive(Clone)]
struct StateDiff {
	/// Shortstatehash of the layer this diff is based on; `None` when the
	/// stored parent value is zero, i.e. this layer has no parent.
	parent: Option<ShortStateHash>,
	/// Entries added relative to the parent.
	added: Arc<CompressedState>,
	/// Entries removed relative to the parent.
	removed: Arc<CompressedState>,
}
38
/// One frame of a loaded layer stack: a layer's diff together with the full
/// state obtained by applying it on top of its parents.
#[derive(Clone, Default)]
pub struct ShortStateInfo {
	/// Shortstatehash identifying this layer.
	pub shortstatehash: ShortStateHash,
	/// Complete state at this layer (parent state plus `added`, minus
	/// `removed`).
	pub full_state: Arc<CompressedState>,
	/// Entries this layer adds relative to its parent.
	pub added: Arc<CompressedState>,
	/// Entries this layer removes relative to its parent.
	pub removed: Arc<CompressedState>,
}
46
/// Result of saving a state: the resulting shortstatehash and the diff
/// against the room's previous state.
#[derive(Clone, Default)]
pub struct HashSetCompressStateEvent {
	/// Shortstatehash of the saved state.
	pub shortstatehash: ShortStateHash,
	/// Entries present in the new state but not in the previous one.
	pub added: Arc<CompressedState>,
	/// Entries present in the previous state but not in the new one.
	pub removed: Arc<CompressedState>,
}
53
/// LRU cache from a shortstatehash to its loaded layer stack.
type StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>;
/// Stack of layers; the last element is the layer that was asked for.
type ShortStateInfoVec = Vec<ShortStateInfo>;
/// Stack of parent layers handed to `save_state_from_diff`.
type ParentStatesVec = Vec<ShortStateInfo>;

/// A set of compressed state entries.
pub type CompressedState = BTreeSet<CompressedStateEvent>;
/// A single state entry: shortstatekey bytes followed by shorteventid bytes,
/// each big-endian.
pub type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
60
61#[async_trait]
62impl crate::Service for Service {
63	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
64		let config = &args.server.config;
65		let cache_capacity =
66			f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
67		Ok(Arc::new(Self {
68			stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(),
69			db: Data {
70				shortstatehash_statediff: args.db["shortstatehash_statediff"].clone(),
71			},
72			services: args.services.clone(),
73		}))
74	}
75
76	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
77		let (cache_len, ents) = {
78			let cache = self.stateinfo_cache.lock().expect("locked");
79			let ents = cache
80				.iter()
81				.map(at!(1))
82				.flat_map(|vec| vec.iter())
83				.fold(HashMap::new(), |mut ents, ssi| {
84					for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
85						ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
86					}
87
88					ents
89				});
90
91			(cache.len(), ents)
92		};
93
94		let ents_len = ents.len();
95		let bytes = ents
96			.values()
97			.copied()
98			.fold(0_usize, usize::saturating_add);
99
100		let bytes = bytes::pretty(bytes);
101		writeln!(out, "stateinfo_cache: {cache_len} {ents_len} ({bytes})")?;
102
103		Ok(())
104	}
105
106	async fn clear_cache(&self) {
107		self.stateinfo_cache
108			.lock()
109			.expect("locked")
110			.clear();
111	}
112
113	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
114}
115
116/// Returns a stack with info on shortstatehash, full state, added diff and
117/// removed diff for the selected shortstatehash and each parent layer.
118#[implement(Service)]
119#[tracing::instrument(name = "load", level = "debug", skip(self))]
120pub async fn load_shortstatehash_info(
121	&self,
122	shortstatehash: ShortStateHash,
123) -> Result<ShortStateInfoVec> {
124	if let Some(r) = self
125		.stateinfo_cache
126		.lock()?
127		.get_mut(&shortstatehash)
128	{
129		return Ok(r.clone());
130	}
131
132	let stack = self
133		.new_shortstatehash_info(shortstatehash)
134		.await?;
135
136	self.cache_shortstatehash_info(shortstatehash, stack.clone())
137		.await?;
138
139	Ok(stack)
140}
141
142/// Returns a stack with info on shortstatehash, full state, added diff and
143/// removed diff for the selected shortstatehash and each parent layer.
144#[implement(Service)]
145#[tracing::instrument(
146		name = "cache",
147		level = "debug",
148		skip_all,
149		fields(
150			?shortstatehash,
151			stack = stack.len(),
152		),
153	)]
154async fn cache_shortstatehash_info(
155	&self,
156	shortstatehash: ShortStateHash,
157	stack: ShortStateInfoVec,
158) -> Result {
159	self.stateinfo_cache
160		.lock()?
161		.insert(shortstatehash, stack);
162
163	Ok(())
164}
165
166#[implement(Service)]
167async fn new_shortstatehash_info(
168	&self,
169	shortstatehash: ShortStateHash,
170) -> Result<ShortStateInfoVec> {
171	let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
172
173	let Some(parent) = parent else {
174		return Ok(vec![ShortStateInfo {
175			shortstatehash,
176			full_state: added.clone(),
177			added,
178			removed,
179		}]);
180	};
181
182	let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?;
183	let top = stack.last().expect("at least one frame");
184
185	let mut full_state = (*top.full_state).clone();
186	full_state.extend(added.iter().copied());
187
188	let removed = (*removed).clone();
189	for r in &removed {
190		full_state.remove(r);
191	}
192
193	stack.push(ShortStateInfo {
194		shortstatehash,
195		added,
196		removed: Arc::new(removed),
197		full_state: Arc::new(full_state),
198	});
199
200	Ok(stack)
201}
202
203#[implement(Service)]
204pub fn compress_state_events<'a, I>(
205	&'a self,
206	state: I,
207) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
208where
209	I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + Send + 'a,
210{
211	let event_ids = state.clone().map(at!(1));
212
213	let short_event_ids = self
214		.services
215		.short
216		.multi_get_or_create_shorteventid(event_ids);
217
218	state
219		.stream()
220		.map(at!(0))
221		.zip(short_event_ids)
222		.map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
223}
224
225#[implement(Service)]
226pub async fn compress_state_event(
227	&self,
228	shortstatekey: ShortStateKey,
229	event_id: &EventId,
230) -> CompressedStateEvent {
231	let shorteventid = self
232		.services
233		.short
234		.get_or_create_shorteventid(event_id)
235		.await;
236
237	compress_state_event(shortstatekey, shorteventid)
238}
239
240/// Creates a new shortstatehash that often is just a diff to an already
241/// existing shortstatehash and therefore very efficient.
242///
243/// There are multiple layers of diffs. The bottom layer 0 always contains
244/// the full state. Layer 1 contains diffs to states of layer 0, layer 2
245/// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be
246/// combined with layer n-1 to create a new diff on layer n-1 that's
247/// based on layer n-2. If that layer is also too big, it will recursively
248/// fix above layers too.
249///
250/// * `shortstatehash` - Shortstatehash of this state
251/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
252/// * `statediffremoved` - Removed from base. Each vec is
253///   shortstatekey+shorteventid
254/// * `diff_to_sibling` - Approximately how much the diff grows each time for
255///   this layer
256/// * `parent_states` - A stack with info on shortstatehash, full state, added
257///   diff and removed diff for each parent layer
258#[implement(Service)]
259pub fn save_state_from_diff(
260	&self,
261	shortstatehash: ShortStateHash,
262	statediffnew: Arc<CompressedState>,
263	statediffremoved: Arc<CompressedState>,
264	diff_to_sibling: usize,
265	mut parent_states: ParentStatesVec,
266) -> Result {
267	let statediffnew_len = statediffnew.len();
268	let statediffremoved_len = statediffremoved.len();
269	let diffsum = checked!(statediffnew_len + statediffremoved_len)?;
270
271	if parent_states.len() > 3 {
272		// Number of layers
273		// To many layers, we have to go deeper
274		let parent = parent_states
275			.pop()
276			.expect("parent must have a state");
277
278		let mut parent_new = (*parent.added).clone();
279		let mut parent_removed = (*parent.removed).clone();
280
281		for removed in statediffremoved.iter() {
282			if !parent_new.remove(removed) {
283				// It was not added in the parent and we removed it
284				parent_removed.insert(*removed);
285			}
286			// Else it was added in the parent and we removed it again. We
287			// can forget this change
288		}
289
290		for new in statediffnew.iter() {
291			if !parent_removed.remove(new) {
292				// It was not touched in the parent and we added it
293				parent_new.insert(*new);
294			}
295			// Else it was removed in the parent and we added it again. We
296			// can forget this change
297		}
298
299		self.save_state_from_diff(
300			shortstatehash,
301			Arc::new(parent_new),
302			Arc::new(parent_removed),
303			diffsum,
304			parent_states,
305		)?;
306
307		return Ok(());
308	}
309
310	if parent_states.is_empty() {
311		// There is no parent layer, create a new state
312		self.save_statediff(shortstatehash, &StateDiff {
313			parent: None,
314			added: statediffnew,
315			removed: statediffremoved,
316		});
317
318		return Ok(());
319	}
320
321	// Else we have two options.
322	// 1. We add the current diff on top of the parent layer.
323	// 2. We replace a layer above
324
325	let parent = parent_states
326		.pop()
327		.expect("parent must have a state");
328	let parent_added_len = parent.added.len();
329	let parent_removed_len = parent.removed.len();
330	let parent_diff = checked!(parent_added_len + parent_removed_len)?;
331
332	if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? {
333		// Diff too big, we replace above layer(s)
334		let mut parent_new = (*parent.added).clone();
335		let mut parent_removed = (*parent.removed).clone();
336
337		for removed in statediffremoved.iter() {
338			if !parent_new.remove(removed) {
339				// It was not added in the parent and we removed it
340				parent_removed.insert(*removed);
341			}
342			// Else it was added in the parent and we removed it again. We
343			// can forget this change
344		}
345
346		for new in statediffnew.iter() {
347			if !parent_removed.remove(new) {
348				// It was not touched in the parent and we added it
349				parent_new.insert(*new);
350			}
351			// Else it was removed in the parent and we added it again. We
352			// can forget this change
353		}
354
355		self.save_state_from_diff(
356			shortstatehash,
357			Arc::new(parent_new),
358			Arc::new(parent_removed),
359			diffsum,
360			parent_states,
361		)?;
362	} else {
363		// Diff small enough, we add diff as layer on top of parent
364		self.save_statediff(shortstatehash, &StateDiff {
365			parent: Some(parent.shortstatehash),
366			added: statediffnew,
367			removed: statediffremoved,
368		});
369	}
370
371	Ok(())
372}
373
374/// Returns the new shortstatehash, and the state diff from the previous
375/// room state
376#[implement(Service)]
377#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
378pub async fn save_state(
379	&self,
380	room_id: &RoomId,
381	new_state_ids_compressed: Arc<CompressedState>,
382) -> Result<HashSetCompressStateEvent> {
383	let previous_shortstatehash = self
384		.services
385		.state
386		.get_room_shortstatehash(room_id)
387		.await
388		.ok();
389
390	let state_hash = utils::calculate_hash(
391		new_state_ids_compressed
392			.iter()
393			.map(|bytes| &bytes[..]),
394	);
395
396	let (new_shortstatehash, already_existed) = self
397		.services
398		.short
399		.get_or_create_shortstatehash(&state_hash)
400		.await;
401
402	if Some(new_shortstatehash) == previous_shortstatehash {
403		return Ok(HashSetCompressStateEvent {
404			shortstatehash: new_shortstatehash,
405			..Default::default()
406		});
407	}
408
409	let states_parents = if let Some(p) = previous_shortstatehash {
410		self.load_shortstatehash_info(p)
411			.await
412			.unwrap_or_default()
413	} else {
414		ShortStateInfoVec::new()
415	};
416
417	let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
418		let statediffnew: CompressedState = new_state_ids_compressed
419			.difference(&parent_stateinfo.full_state)
420			.copied()
421			.collect();
422
423		let statediffremoved: CompressedState = parent_stateinfo
424			.full_state
425			.difference(&new_state_ids_compressed)
426			.copied()
427			.collect();
428
429		(Arc::new(statediffnew), Arc::new(statediffremoved))
430	} else {
431		(new_state_ids_compressed, Arc::new(CompressedState::new()))
432	};
433
434	if !already_existed {
435		self.save_state_from_diff(
436			new_shortstatehash,
437			statediffnew.clone(),
438			statediffremoved.clone(),
439			2, // every state change is 2 event changes on average
440			states_parents,
441		)?;
442	}
443
444	Ok(HashSetCompressStateEvent {
445		shortstatehash: new_shortstatehash,
446		added: statediffnew,
447		removed: statediffremoved,
448	})
449}
450
451#[implement(Service)]
452#[tracing::instrument(skip(self), level = "debug", name = "get")]
453async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result<StateDiff> {
454	const BUFSIZE: usize = size_of::<ShortStateHash>();
455	const STRIDE: usize = size_of::<ShortStateHash>();
456
457	let value = self
458		.db
459		.shortstatehash_statediff
460		.aqry::<BUFSIZE, _>(&shortstatehash)
461		.await
462		.map_err(|e| {
463			err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
464		})?;
465
466	let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
467		.ok()
468		.take_if(|parent| *parent != 0);
469
470	debug_assert!(value.len().is_multiple_of(STRIDE), "value not aligned to stride");
471	let _num_values = value.len() / STRIDE;
472
473	let mut add_mode = true;
474	let mut added = CompressedState::new();
475	let mut removed = CompressedState::new();
476
477	let mut i = STRIDE;
478	while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
479		if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
480			add_mode = false;
481			i = expected!(i + STRIDE);
482			continue;
483		}
484		if add_mode {
485			added.insert(v.try_into()?);
486		} else {
487			removed.insert(v.try_into()?);
488		}
489		i = expected!(i + 2 * STRIDE);
490	}
491
492	Ok(StateDiff {
493		parent,
494		added: Arc::new(added),
495		removed: Arc::new(removed),
496	})
497}
498
499#[implement(Service)]
500fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) {
501	let mut value = Vec::<u8>::with_capacity(
502		2_usize
503			.saturating_add(diff.added.len())
504			.saturating_add(diff.removed.len()),
505	);
506
507	let parent = diff.parent.unwrap_or(0_u64);
508	value.extend_from_slice(&parent.to_be_bytes());
509
510	for new in diff.added.iter() {
511		value.extend_from_slice(&new[..]);
512	}
513
514	if !diff.removed.is_empty() {
515		value.extend_from_slice(&0_u64.to_be_bytes());
516		for removed in diff.removed.iter() {
517			value.extend_from_slice(&removed[..]);
518		}
519	}
520
521	self.db
522		.shortstatehash_statediff
523		.insert(&shortstatehash.to_be_bytes(), &value);
524}
525
526#[inline]
527#[must_use]
528pub(crate) fn compress_state_event(
529	shortstatekey: ShortStateKey,
530	shorteventid: ShortEventId,
531) -> CompressedStateEvent {
532	const SIZE: usize = size_of::<CompressedStateEvent>();
533
534	let mut v = ArrayVec::<u8, SIZE>::new();
535	v.extend(shortstatekey.to_be_bytes());
536	v.extend(shorteventid.to_be_bytes());
537	v.as_ref()
538		.try_into()
539		.expect("failed to create CompressedStateEvent")
540}
541
542#[inline]
543#[must_use]
544pub(crate) fn parse_compressed_state_event(
545	compressed_event: CompressedStateEvent,
546) -> (ShortStateKey, ShortEventId) {
547	use utils::u64_from_u8;
548
549	let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]);
550	let shorteventid = u64_from_u8(&compressed_event[size_of::<ShortStateKey>()..]);
551
552	(shortstatekey, shorteventid)
553}
554
555#[inline]
556fn compressed_state_size(compressed_state: &CompressedState) -> usize {
557	compressed_state
558		.len()
559		.checked_mul(size_of::<CompressedStateEvent>())
560		.expect("CompressedState size overflow")
561}