1use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
2
3use async_trait::async_trait;
4use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all};
5use ruma::{
6 CanonicalJsonObject, EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
7 events::{AnyStrippedStateEvent, StateEventType, TimelineEventType},
8 room_version_rules::AuthorizationRules,
9 serde::Raw,
10};
11use serde_json::value::RawValue as RawJsonValue;
12use tuwunel_core::{
13 Event, PduEvent, Result, err,
14 error::inspect_debug_log,
15 implement,
16 matrix::{PduCount, RoomVersionRules, StateKey, TypeStateKey, room_version},
17 result::{AndThenRef, FlatOk},
18 trace,
19 utils::{
20 IterStream, MutexMap, MutexMapGuard, ReadyExt, calculate_hash,
21 mutex_map::Guard,
22 stream::{BroadbandExt, TryIgnore, WidebandExt},
23 },
24 warn,
25};
26use tuwunel_database::{Deserialized, Ignore, Interfix, Map};
27
28use crate::{
29 rooms::{
30 short::{ShortEventId, ShortStateHash, ShortStateKey},
31 state_compressor::{CompressedState, parse_compressed_state_event},
32 state_res::{StateMap, auth_types_for_event},
33 },
34 services::OnceServices,
35};
36
/// Room state service: holds the per-room state mutex map, the shared
/// service handles, and the database columns mapping rooms/events to short
/// state hashes and rooms to their forward extremities.
pub struct Service {
	/// Per-room mutex map. Methods such as `set_room_state`, `force_state`
	/// and `set_forward_extremities` require a guard from this map.
	pub mutex: RoomMutexMap,
	/// Handles to the other services this one calls into.
	services: Arc<OnceServices>,
	/// Database columns owned by this service.
	db: Data,
}
42
/// Database columns used by the room state service.
struct Data {
	/// Short event id -> short state hash recorded for that event.
	shorteventid_shortstatehash: Arc<Map>,
	/// Room id -> short state hash of the room's current state.
	roomid_shortstatehash: Arc<Map>,
	/// Keys of the form (room id, event id) listing the room's forward
	/// extremities; the value is the event id.
	roomid_pduleaves: Arc<Map>,
}
48
/// Map of per-room mutexes keyed by room id (stored as `Service::mutex`).
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
/// Guard obtained by locking one room's entry of the per-room mutex map.
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
51
52#[async_trait]
53impl crate::Service for Service {
54 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
55 Ok(Arc::new(Self {
56 mutex: RoomMutexMap::new(),
57 services: args.services.clone(),
58 db: Data {
59 shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
60 roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(),
61 roomid_pduleaves: args.db["roomid_pduleaves"].clone(),
62 },
63 }))
64 }
65
66 async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
67 let mutex = self.mutex.len();
68 writeln!(out, "- state_mutex: {mutex}")?;
69
70 Ok(())
71 }
72
73 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
74}
75
#[implement(Service)]
/// Forces the room's current state to `shortstatehash`.
///
/// Every entry of `statediffnew` is resolved to its PDU; entries whose event
/// id or PDU cannot be loaded are skipped. Of the resolved PDUs:
/// - `m.room.member` events whose state key parses as a user id and whose
///   content deserializes are applied via `state_cache.update_membership`,
///   each under a freshly allocated count;
/// - `m.space.child` events call `spaces.cache_evict` for the event's room;
/// - all other event types are ignored.
///
/// Afterwards the room's joined counts are refreshed and the room's state
/// hash is written with `set_room_state` under `state_lock`.
///
/// `_statediffremoved` is accepted but not used.
///
/// # Errors
/// Returns the first error produced by `update_membership`.
#[tracing::instrument(
	name = "force",
	level = "debug",
	skip_all,
	fields(
		count = ?self.services.globals.pending_count(),
		%shortstatehash,
	)
)]
pub async fn force_state(
	&self,
	room_id: &RoomId,
	shortstatehash: u64,
	statediffnew: Arc<CompressedState>,
	_statediffremoved: Arc<CompressedState>,
	state_lock: &RoomMutexGuard,
) -> Result {
	statediffnew
		.iter()
		.stream()
		// Keep only the short event id (second field) of each compressed entry.
		.map(|&new| parse_compressed_state_event(new).1)
		.wide_filter_map(async |shorteventid| {
			// Lookup failures are debug-logged and the entry is dropped.
			let event_id: OwnedEventId = self
				.services
				.short
				.get_eventid_from_short(shorteventid)
				.inspect_err(inspect_debug_log)
				.await
				.ok()?;

			self.services
				.timeline
				.get_pdu(&event_id)
				.await
				.ok()
		})
		// Lift into a fallible stream so `try_for_each` can stop on the first error.
		.map(Ok)
		.try_for_each(async |pdu| match pdu.kind {
			| TimelineEventType::RoomMember => {
				// Non-user state keys are skipped rather than failing the whole call.
				let Some(user_id) = pdu
					.state_key
					.as_ref()
					.map(UserId::parse)
					.flat_ok()
				else {
					return Ok(());
				};

				// Undeserializable member content is likewise skipped.
				let Ok(membership_event) = pdu.get_content() else {
					return Ok(());
				};

				let count = self.services.globals.next_count();
				self.services
					.state_cache
					.update_membership(
						room_id,
						&user_id,
						membership_event,
						&pdu.sender,
						None,
						None,
						false,
						PduCount::Normal(*count),
					)
					.await
			},
			| TimelineEventType::SpaceChild => {
				self.services.spaces.cache_evict(pdu.room_id());

				Ok(())
			},
			| _ => Ok(()),
		})
		.boxed()
		.await?;

	self.services
		.state_cache
		.update_joined_count(room_id)
		.await;

	self.set_room_state(room_id, shortstatehash, state_lock);

	Ok(())
}
164
165#[implement(Service)]
170#[tracing::instrument(
171 name = "set",
172 level = "debug",
173 skip(self, state_ids_compressed),
174 fields(
175 count = ?self.services.globals.pending_count(),
176 )
177)]
178pub async fn set_event_state(
179 &self,
180 event_id: &EventId,
181 room_id: &RoomId,
182 state_ids_compressed: Arc<CompressedState>,
183) -> Result<ShortStateHash> {
184 const KEY_LEN: usize = size_of::<ShortEventId>();
185 const VAL_LEN: usize = size_of::<ShortStateHash>();
186
187 let shorteventid = self
188 .services
189 .short
190 .get_or_create_shorteventid(event_id)
191 .await;
192
193 let previous_shortstatehash = self.get_room_shortstatehash(room_id).await;
194
195 let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..]));
196
197 let (shortstatehash, already_existed) = self
198 .services
199 .short
200 .get_or_create_shortstatehash(&state_hash)
201 .await;
202
203 if !already_existed {
204 let states_parents = match previous_shortstatehash {
205 | Ok(p) =>
206 self.services
207 .state_compressor
208 .load_shortstatehash_info(p)
209 .await?,
210 | _ => Vec::new(),
211 };
212
213 let (statediffnew, statediffremoved) =
214 if let Some(parent_stateinfo) = states_parents.last() {
215 let statediffnew: CompressedState = state_ids_compressed
216 .difference(&parent_stateinfo.full_state)
217 .copied()
218 .collect();
219
220 let statediffremoved: CompressedState = parent_stateinfo
221 .full_state
222 .difference(&state_ids_compressed)
223 .copied()
224 .collect();
225
226 (Arc::new(statediffnew), Arc::new(statediffremoved))
227 } else {
228 (state_ids_compressed, Arc::new(CompressedState::new()))
229 };
230
231 self.services
232 .state_compressor
233 .save_state_from_diff(
234 shortstatehash,
235 statediffnew,
236 statediffremoved,
237 1_000_000, states_parents,
239 )?;
240 }
241
242 self.db
243 .shorteventid_shortstatehash
244 .aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, shortstatehash);
245
246 Ok(shortstatehash)
247}
248
249#[implement(Service)]
254#[tracing::instrument(
255 name = "set",
256 level = "debug",
257 skip(self, new_pdu),
258 fields(
259 count = ?self.services.globals.pending_count(),
260 )
261)]
262pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
263 const KEY_LEN: usize = size_of::<ShortEventId>();
264 const VAL_LEN: usize = size_of::<ShortStateHash>();
265
266 let shorteventid = self
267 .services
268 .short
269 .get_or_create_shorteventid(&new_pdu.event_id)
270 .await;
271
272 let previous_shortstatehash = self
273 .get_room_shortstatehash(&new_pdu.room_id)
274 .await;
275
276 if let Ok(p) = previous_shortstatehash {
277 self.db
278 .shorteventid_shortstatehash
279 .aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, p);
280 }
281
282 match &new_pdu.state_key {
283 | Some(state_key) => {
284 let states_parents = match previous_shortstatehash {
285 | Ok(p) =>
286 self.services
287 .state_compressor
288 .load_shortstatehash_info(p)
289 .await?,
290 | _ => Vec::new(),
291 };
292
293 let shortstatekey = self
294 .services
295 .short
296 .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
297 .await;
298
299 let new = self
300 .services
301 .state_compressor
302 .compress_state_event(shortstatekey, &new_pdu.event_id)
303 .await;
304
305 let replaces = states_parents
306 .last()
307 .map(|info| {
308 info.full_state
309 .iter()
310 .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
311 })
312 .unwrap_or_default();
313
314 if Some(&new) == replaces {
315 return Ok(previous_shortstatehash.expect("must exist"));
316 }
317
318 let shortstatehash = self.services.globals.next_count();
320
321 let mut statediffnew = CompressedState::new();
322 statediffnew.insert(new);
323
324 let mut statediffremoved = CompressedState::new();
325 if let Some(replaces) = replaces {
326 statediffremoved.insert(*replaces);
327 }
328
329 self.services
330 .state_compressor
331 .save_state_from_diff(
332 *shortstatehash,
333 Arc::new(statediffnew),
334 Arc::new(statediffremoved),
335 2,
336 states_parents,
337 )?;
338
339 Ok(*shortstatehash)
340 },
341 | _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")),
342 }
343}
344
345#[implement(Service)]
347#[tracing::instrument(skip(self, _mutex_lock), level = "debug")]
348pub fn set_room_state(
349 &self,
350 room_id: &RoomId,
351 shortstatehash: u64,
352 _mutex_lock: &RoomMutexGuard,
354) {
355 const BUFSIZE: usize = size_of::<u64>();
356
357 self.db
358 .roomid_shortstatehash
359 .raw_aput::<BUFSIZE, _, _>(room_id, shortstatehash);
360}
361
#[implement(Service)]
/// Loads, from the room's current state, the auth events for a prospective
/// event described by `kind`, `sender`, `state_key` and `content`.
///
/// The required (event type, state key) pairs come from
/// `auth_types_for_event` under `auth_rules`; `include_create` is passed
/// through to it. The following are silently omitted from the result:
/// - pairs with no known short state key;
/// - pairs not present in the current state;
/// - pairs whose event id or PDU cannot be loaded.
///
/// If the room has no current state hash, an empty map is returned.
///
/// # Errors
/// Propagates errors from `auth_types_for_event`.
#[expect(clippy::too_many_arguments)]
#[tracing::instrument(skip(self, content), level = "debug")]
pub async fn get_auth_events(
	&self,
	room_id: &RoomId,
	kind: &TimelineEventType,
	sender: &UserId,
	state_key: Option<&str>,
	content: &serde_json::value::RawValue,
	auth_rules: &AuthorizationRules,
	include_create: bool,
) -> Result<StateMap<PduEvent>>
where
	StateEventType: Send + Sync,
	StateKey: Send + Sync,
{
	let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
		return Ok(StateMap::new());
	};

	// Index the needed (type, state_key) pairs by short state key so the
	// room's full state can be filtered by short key below.
	let sauthevents: HashMap<ShortStateKey, TypeStateKey> =
		auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)?
			.into_iter()
			.stream()
			.broad_filter_map(async |(event_type, state_key): TypeStateKey| {
				self.services
					.short
					.get_shortstatekey(&event_type, &state_key)
					.await
					.map(move |sstatekey| (sstatekey, (event_type, state_key)))
					.ok()
			})
			.collect()
			.await;

	// Keep only the current-state entries whose short key was requested,
	// split into their (type, state_key) and short event id.
	let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
		.services
		.state_accessor
		.state_full_shortids(shortstatehash)
		.ready_filter_map(Result::ok)
		.ready_filter_map(|(shortstatekey, shorteventid)| {
			sauthevents
				.get(&shortstatekey)
				.map(move |(ty, sk)| ((ty, sk), shorteventid))
		})
		.unzip()
		.await;

	// Resolve short ids to event ids and re-pair them with their keys.
	// NOTE(review): the zip assumes multi_get_eventid_from_short yields results
	// in input order — confirm.
	self.services
		.short
		.multi_get_eventid_from_short(event_ids.into_iter().stream())
		.zip(state_keys.into_iter().stream())
		.ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?)))
		.broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| {
			let pdu = self.services.timeline.get_pdu(&event_id).await;

			Some(((ty.clone(), sk.clone()), pdu.ok()?))
		})
		.collect()
		.map(Ok)
		.await
}
426
427#[implement(Service)]
428#[tracing::instrument(skip_all, level = "debug")]
429pub async fn summary_stripped<Pdu: Event>(&self, event: &Pdu) -> Vec<Raw<AnyStrippedStateEvent>> {
430 let cells = [
431 (&StateEventType::RoomCreate, ""),
432 (&StateEventType::RoomJoinRules, ""),
433 (&StateEventType::RoomCanonicalAlias, ""),
434 (&StateEventType::RoomName, ""),
435 (&StateEventType::RoomAvatar, ""),
436 (&StateEventType::RoomMember, event.sender().as_str()), (&StateEventType::RoomEncryption, ""),
438 (&StateEventType::RoomTopic, ""),
439 ];
440
441 let fetches = cells.into_iter().map(|(event_type, state_key)| {
442 self.services
443 .state_accessor
444 .room_state_get(event.room_id(), event_type, state_key)
445 });
446
447 join_all(fetches)
448 .await
449 .into_iter()
450 .filter_map(Result::ok)
451 .map(Event::into_format)
452 .chain(once(event.to_format()))
453 .collect()
454}
455
456#[implement(Service)]
461#[tracing::instrument(skip_all, level = "debug")]
462pub async fn summary_pdus<Pdu: Event>(
463 &self,
464 event: &Pdu,
465 event_json: &CanonicalJsonObject,
466 room_version: &RoomVersionId,
467) -> Vec<Box<RawJsonValue>> {
468 let cells = [
469 (&StateEventType::RoomCreate, ""),
470 (&StateEventType::RoomJoinRules, ""),
471 (&StateEventType::RoomCanonicalAlias, ""),
472 (&StateEventType::RoomName, ""),
473 (&StateEventType::RoomAvatar, ""),
474 (&StateEventType::RoomMember, event.sender().as_str()),
475 (&StateEventType::RoomEncryption, ""),
476 (&StateEventType::RoomTopic, ""),
477 ];
478
479 let membership = self
480 .services
481 .federation
482 .format_pdu_into(event_json.clone(), Some(room_version))
483 .await;
484
485 cells
486 .into_iter()
487 .stream()
488 .wide_filter_map(async |(event_type, state_key)| {
489 let pdu = self
490 .services
491 .state_accessor
492 .room_state_get(event.room_id(), event_type, state_key)
493 .await
494 .ok()?;
495
496 let pdu_json = self
497 .services
498 .timeline
499 .get_pdu_json(pdu.event_id())
500 .await
501 .ok()?;
502
503 Some(
504 self.services
505 .federation
506 .format_pdu_into(pdu_json, Some(room_version))
507 .await,
508 )
509 })
510 .chain(once(membership).stream())
511 .collect()
512 .await
513}
514
515#[implement(Service)]
517#[inline]
518pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result<RoomVersionRules> {
519 self.get_room_version(room_id)
520 .await
521 .and_then_ref(room_version::rules)
522}
523
524#[implement(Service)]
526#[tracing::instrument(
527 level = "debug"
528 skip(self),
529 ret(level = "trace"),
530)]
531pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
532 self.services
533 .state_accessor
534 .room_state_get_content(room_id, &StateEventType::RoomCreate, "")
535 .await
536 .as_ref()
537 .map(room_version::from_create_content)
538 .cloned()
539 .map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
540}
541
542#[implement(Service)]
543#[tracing::instrument(
544 level = "debug"
545 skip(self),
546 ret(level = "trace"),
547)]
548pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<ShortStateHash> {
549 self.db
550 .roomid_shortstatehash
551 .get(room_id)
552 .await
553 .deserialized()
554}
555
556#[implement(Service)]
558pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
559 self.services
560 .short
561 .get_shorteventid(event_id)
562 .and_then(|shorteventid| self.get_shortstatehash(shorteventid))
563 .await
564}
565
566#[implement(Service)]
568#[tracing::instrument(
569 level = "debug"
570 skip(self),
571 ret(level = "trace"),
572)]
573pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result<ShortStateHash> {
574 const BUFSIZE: usize = size_of::<ShortEventId>();
575
576 self.db
577 .shorteventid_shortstatehash
578 .aqry::<BUFSIZE, _>(&shorteventid)
579 .await
580 .deserialized()
581}
582
583#[implement(Service)]
584pub(super) fn delete_room_shortstatehash(
585 &self,
586 room_id: &RoomId,
587 _mutex_lock: &Guard<OwnedRoomId, ()>,
588) -> Result {
589 self.db.roomid_shortstatehash.remove(room_id);
590
591 Ok(())
592}
593
594#[implement(Service)]
595#[tracing::instrument(
596 level = "trace"
597 skip(self),
598)]
599pub fn get_forward_extremities<'a>(
600 &'a self,
601 room_id: &'a RoomId,
602) -> impl Stream<Item = &EventId> + Send + '_ {
603 let prefix = (room_id, Interfix);
604
605 self.db
606 .roomid_pduleaves
607 .keys_prefix(&prefix)
608 .map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
609 .ignore_err()
610}
611
612#[implement(Service)]
613#[tracing::instrument(
614 level = "debug"
615 skip_all,
616 fields(%room_id),
617)]
618pub async fn set_forward_extremities<'a, I>(
619 &'a self,
620 room_id: &'a RoomId,
621 event_ids: I,
622 _state_lock: &'a RoomMutexGuard,
623) where
624 I: Iterator<Item = &'a EventId> + Send + 'a,
625{
626 let prefix = (room_id, Interfix);
627 self.db
628 .roomid_pduleaves
629 .keys_prefix_raw(&prefix)
630 .ignore_err()
631 .ready_for_each(|key| self.db.roomid_pduleaves.remove(key))
632 .await;
633
634 for event_id in event_ids {
635 let key = (room_id, event_id);
636 self.db.roomid_pduleaves.put_raw(key, event_id);
637 }
638}
639
640#[implement(Service)]
641pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result {
642 let prefix = (room_id, Interfix);
643
644 self.db
645 .roomid_pduleaves
646 .keys_prefix_raw(&prefix)
647 .ignore_err()
648 .ready_for_each(|key| {
649 trace!("Removing key: {key:?}");
650 self.db.roomid_pduleaves.remove(key);
651 })
652 .await;
653
654 Ok(())
655}