tuwunel_service/rooms/state_compressor/
mod.rs1use std::{
2 collections::{BTreeSet, HashMap},
3 fmt::{Debug, Write},
4 mem::size_of,
5 sync::{Arc, Mutex},
6};
7
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use lru_cache::LruCache;
11use ruma::{EventId, RoomId};
12use tuwunel_core::{
13 Result,
14 arrayvec::ArrayVec,
15 at, checked, err, expected, implement, utils,
16 utils::{bytes, math::usize_from_f64, stream::IterStream},
17};
18use tuwunel_database::Map;
19
20use crate::rooms::short::{ShortEventId, ShortId, ShortStateHash, ShortStateKey};
21
22pub struct Service {
23 pub stateinfo_cache: Mutex<StateInfoLruCache>,
24 db: Data,
25 services: Arc<crate::services::OnceServices>,
26}
27
28struct Data {
29 shortstatehash_statediff: Arc<Map>,
30}
31
32#[derive(Clone)]
33struct StateDiff {
34 parent: Option<ShortStateHash>,
35 added: Arc<CompressedState>,
36 removed: Arc<CompressedState>,
37}
38
39#[derive(Clone, Default)]
40pub struct ShortStateInfo {
41 pub shortstatehash: ShortStateHash,
42 pub full_state: Arc<CompressedState>,
43 pub added: Arc<CompressedState>,
44 pub removed: Arc<CompressedState>,
45}
46
47#[derive(Clone, Default)]
48pub struct HashSetCompressStateEvent {
49 pub shortstatehash: ShortStateHash,
50 pub added: Arc<CompressedState>,
51 pub removed: Arc<CompressedState>,
52}
53
54type StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>;
55type ShortStateInfoVec = Vec<ShortStateInfo>;
56type ParentStatesVec = Vec<ShortStateInfo>;
57
58pub type CompressedState = BTreeSet<CompressedStateEvent>;
59pub type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
60
61#[async_trait]
62impl crate::Service for Service {
63 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
64 let config = &args.server.config;
65 let cache_capacity =
66 f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
67 Ok(Arc::new(Self {
68 stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(),
69 db: Data {
70 shortstatehash_statediff: args.db["shortstatehash_statediff"].clone(),
71 },
72 services: args.services.clone(),
73 }))
74 }
75
76 async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
77 let (cache_len, ents) = {
78 let cache = self.stateinfo_cache.lock().expect("locked");
79 let ents = cache
80 .iter()
81 .map(at!(1))
82 .flat_map(|vec| vec.iter())
83 .fold(HashMap::new(), |mut ents, ssi| {
84 for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
85 ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
86 }
87
88 ents
89 });
90
91 (cache.len(), ents)
92 };
93
94 let ents_len = ents.len();
95 let bytes = ents
96 .values()
97 .copied()
98 .fold(0_usize, usize::saturating_add);
99
100 let bytes = bytes::pretty(bytes);
101 writeln!(out, "stateinfo_cache: {cache_len} {ents_len} ({bytes})")?;
102
103 Ok(())
104 }
105
106 async fn clear_cache(&self) {
107 self.stateinfo_cache
108 .lock()
109 .expect("locked")
110 .clear();
111 }
112
113 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
114}
115
116#[implement(Service)]
119#[tracing::instrument(name = "load", level = "debug", skip(self))]
120pub async fn load_shortstatehash_info(
121 &self,
122 shortstatehash: ShortStateHash,
123) -> Result<ShortStateInfoVec> {
124 if let Some(r) = self
125 .stateinfo_cache
126 .lock()?
127 .get_mut(&shortstatehash)
128 {
129 return Ok(r.clone());
130 }
131
132 let stack = self
133 .new_shortstatehash_info(shortstatehash)
134 .await?;
135
136 self.cache_shortstatehash_info(shortstatehash, stack.clone())
137 .await?;
138
139 Ok(stack)
140}
141
142#[implement(Service)]
145#[tracing::instrument(
146 name = "cache",
147 level = "debug",
148 skip_all,
149 fields(
150 ?shortstatehash,
151 stack = stack.len(),
152 ),
153 )]
154async fn cache_shortstatehash_info(
155 &self,
156 shortstatehash: ShortStateHash,
157 stack: ShortStateInfoVec,
158) -> Result {
159 self.stateinfo_cache
160 .lock()?
161 .insert(shortstatehash, stack);
162
163 Ok(())
164}
165
166#[implement(Service)]
167async fn new_shortstatehash_info(
168 &self,
169 shortstatehash: ShortStateHash,
170) -> Result<ShortStateInfoVec> {
171 let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
172
173 let Some(parent) = parent else {
174 return Ok(vec![ShortStateInfo {
175 shortstatehash,
176 full_state: added.clone(),
177 added,
178 removed,
179 }]);
180 };
181
182 let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?;
183 let top = stack.last().expect("at least one frame");
184
185 let mut full_state = (*top.full_state).clone();
186 full_state.extend(added.iter().copied());
187
188 let removed = (*removed).clone();
189 for r in &removed {
190 full_state.remove(r);
191 }
192
193 stack.push(ShortStateInfo {
194 shortstatehash,
195 added,
196 removed: Arc::new(removed),
197 full_state: Arc::new(full_state),
198 });
199
200 Ok(stack)
201}
202
203#[implement(Service)]
204pub fn compress_state_events<'a, I>(
205 &'a self,
206 state: I,
207) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
208where
209 I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + Send + 'a,
210{
211 let event_ids = state.clone().map(at!(1));
212
213 let short_event_ids = self
214 .services
215 .short
216 .multi_get_or_create_shorteventid(event_ids);
217
218 state
219 .stream()
220 .map(at!(0))
221 .zip(short_event_ids)
222 .map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
223}
224
225#[implement(Service)]
226pub async fn compress_state_event(
227 &self,
228 shortstatekey: ShortStateKey,
229 event_id: &EventId,
230) -> CompressedStateEvent {
231 let shorteventid = self
232 .services
233 .short
234 .get_or_create_shorteventid(event_id)
235 .await;
236
237 compress_state_event(shortstatekey, shorteventid)
238}
239
240#[implement(Service)]
259pub fn save_state_from_diff(
260 &self,
261 shortstatehash: ShortStateHash,
262 statediffnew: Arc<CompressedState>,
263 statediffremoved: Arc<CompressedState>,
264 diff_to_sibling: usize,
265 mut parent_states: ParentStatesVec,
266) -> Result {
267 let statediffnew_len = statediffnew.len();
268 let statediffremoved_len = statediffremoved.len();
269 let diffsum = checked!(statediffnew_len + statediffremoved_len)?;
270
271 if parent_states.len() > 3 {
272 let parent = parent_states
275 .pop()
276 .expect("parent must have a state");
277
278 let mut parent_new = (*parent.added).clone();
279 let mut parent_removed = (*parent.removed).clone();
280
281 for removed in statediffremoved.iter() {
282 if !parent_new.remove(removed) {
283 parent_removed.insert(*removed);
285 }
286 }
289
290 for new in statediffnew.iter() {
291 if !parent_removed.remove(new) {
292 parent_new.insert(*new);
294 }
295 }
298
299 self.save_state_from_diff(
300 shortstatehash,
301 Arc::new(parent_new),
302 Arc::new(parent_removed),
303 diffsum,
304 parent_states,
305 )?;
306
307 return Ok(());
308 }
309
310 if parent_states.is_empty() {
311 self.save_statediff(shortstatehash, &StateDiff {
313 parent: None,
314 added: statediffnew,
315 removed: statediffremoved,
316 });
317
318 return Ok(());
319 }
320
321 let parent = parent_states
326 .pop()
327 .expect("parent must have a state");
328 let parent_added_len = parent.added.len();
329 let parent_removed_len = parent.removed.len();
330 let parent_diff = checked!(parent_added_len + parent_removed_len)?;
331
332 if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? {
333 let mut parent_new = (*parent.added).clone();
335 let mut parent_removed = (*parent.removed).clone();
336
337 for removed in statediffremoved.iter() {
338 if !parent_new.remove(removed) {
339 parent_removed.insert(*removed);
341 }
342 }
345
346 for new in statediffnew.iter() {
347 if !parent_removed.remove(new) {
348 parent_new.insert(*new);
350 }
351 }
354
355 self.save_state_from_diff(
356 shortstatehash,
357 Arc::new(parent_new),
358 Arc::new(parent_removed),
359 diffsum,
360 parent_states,
361 )?;
362 } else {
363 self.save_statediff(shortstatehash, &StateDiff {
365 parent: Some(parent.shortstatehash),
366 added: statediffnew,
367 removed: statediffremoved,
368 });
369 }
370
371 Ok(())
372}
373
374#[implement(Service)]
377#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
378pub async fn save_state(
379 &self,
380 room_id: &RoomId,
381 new_state_ids_compressed: Arc<CompressedState>,
382) -> Result<HashSetCompressStateEvent> {
383 let previous_shortstatehash = self
384 .services
385 .state
386 .get_room_shortstatehash(room_id)
387 .await
388 .ok();
389
390 let state_hash = utils::calculate_hash(
391 new_state_ids_compressed
392 .iter()
393 .map(|bytes| &bytes[..]),
394 );
395
396 let (new_shortstatehash, already_existed) = self
397 .services
398 .short
399 .get_or_create_shortstatehash(&state_hash)
400 .await;
401
402 if Some(new_shortstatehash) == previous_shortstatehash {
403 return Ok(HashSetCompressStateEvent {
404 shortstatehash: new_shortstatehash,
405 ..Default::default()
406 });
407 }
408
409 let states_parents = if let Some(p) = previous_shortstatehash {
410 self.load_shortstatehash_info(p)
411 .await
412 .unwrap_or_default()
413 } else {
414 ShortStateInfoVec::new()
415 };
416
417 let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
418 let statediffnew: CompressedState = new_state_ids_compressed
419 .difference(&parent_stateinfo.full_state)
420 .copied()
421 .collect();
422
423 let statediffremoved: CompressedState = parent_stateinfo
424 .full_state
425 .difference(&new_state_ids_compressed)
426 .copied()
427 .collect();
428
429 (Arc::new(statediffnew), Arc::new(statediffremoved))
430 } else {
431 (new_state_ids_compressed, Arc::new(CompressedState::new()))
432 };
433
434 if !already_existed {
435 self.save_state_from_diff(
436 new_shortstatehash,
437 statediffnew.clone(),
438 statediffremoved.clone(),
439 2, states_parents,
441 )?;
442 }
443
444 Ok(HashSetCompressStateEvent {
445 shortstatehash: new_shortstatehash,
446 added: statediffnew,
447 removed: statediffremoved,
448 })
449}
450
451#[implement(Service)]
452#[tracing::instrument(skip(self), level = "debug", name = "get")]
453async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result<StateDiff> {
454 const BUFSIZE: usize = size_of::<ShortStateHash>();
455 const STRIDE: usize = size_of::<ShortStateHash>();
456
457 let value = self
458 .db
459 .shortstatehash_statediff
460 .aqry::<BUFSIZE, _>(&shortstatehash)
461 .await
462 .map_err(|e| {
463 err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
464 })?;
465
466 let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
467 .ok()
468 .take_if(|parent| *parent != 0);
469
470 debug_assert!(value.len().is_multiple_of(STRIDE), "value not aligned to stride");
471 let _num_values = value.len() / STRIDE;
472
473 let mut add_mode = true;
474 let mut added = CompressedState::new();
475 let mut removed = CompressedState::new();
476
477 let mut i = STRIDE;
478 while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
479 if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
480 add_mode = false;
481 i = expected!(i + STRIDE);
482 continue;
483 }
484 if add_mode {
485 added.insert(v.try_into()?);
486 } else {
487 removed.insert(v.try_into()?);
488 }
489 i = expected!(i + 2 * STRIDE);
490 }
491
492 Ok(StateDiff {
493 parent,
494 added: Arc::new(added),
495 removed: Arc::new(removed),
496 })
497}
498
499#[implement(Service)]
500fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) {
501 let mut value = Vec::<u8>::with_capacity(
502 2_usize
503 .saturating_add(diff.added.len())
504 .saturating_add(diff.removed.len()),
505 );
506
507 let parent = diff.parent.unwrap_or(0_u64);
508 value.extend_from_slice(&parent.to_be_bytes());
509
510 for new in diff.added.iter() {
511 value.extend_from_slice(&new[..]);
512 }
513
514 if !diff.removed.is_empty() {
515 value.extend_from_slice(&0_u64.to_be_bytes());
516 for removed in diff.removed.iter() {
517 value.extend_from_slice(&removed[..]);
518 }
519 }
520
521 self.db
522 .shortstatehash_statediff
523 .insert(&shortstatehash.to_be_bytes(), &value);
524}
525
526#[inline]
527#[must_use]
528pub(crate) fn compress_state_event(
529 shortstatekey: ShortStateKey,
530 shorteventid: ShortEventId,
531) -> CompressedStateEvent {
532 const SIZE: usize = size_of::<CompressedStateEvent>();
533
534 let mut v = ArrayVec::<u8, SIZE>::new();
535 v.extend(shortstatekey.to_be_bytes());
536 v.extend(shorteventid.to_be_bytes());
537 v.as_ref()
538 .try_into()
539 .expect("failed to create CompressedStateEvent")
540}
541
542#[inline]
543#[must_use]
544pub(crate) fn parse_compressed_state_event(
545 compressed_event: CompressedStateEvent,
546) -> (ShortStateKey, ShortEventId) {
547 use utils::u64_from_u8;
548
549 let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]);
550 let shorteventid = u64_from_u8(&compressed_event[size_of::<ShortStateKey>()..]);
551
552 (shortstatekey, shorteventid)
553}
554
555#[inline]
556fn compressed_state_size(compressed_state: &CompressedState) -> usize {
557 compressed_state
558 .len()
559 .checked_mul(size_of::<CompressedStateEvent>())
560 .expect("CompressedState size overflow")
561}