1use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
2
3use async_trait::async_trait;
4use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all};
5use ruma::{
6 EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
7 events::{AnyStrippedStateEvent, StateEventType, TimelineEventType},
8 room_version_rules::AuthorizationRules,
9 serde::Raw,
10};
11use tuwunel_core::{
12 Event, PduEvent, Result, err,
13 error::inspect_debug_log,
14 implement,
15 matrix::{PduCount, RoomVersionRules, StateKey, TypeStateKey, room_version},
16 result::{AndThenRef, FlatOk},
17 trace,
18 utils::{
19 IterStream, MutexMap, MutexMapGuard, ReadyExt, calculate_hash,
20 mutex_map::Guard,
21 stream::{BroadbandExt, TryIgnore, WidebandExt},
22 },
23 warn,
24};
25use tuwunel_database::{Deserialized, Ignore, Interfix, Map};
26
27use crate::{
28 rooms::{
29 short::{ShortEventId, ShortStateHash, ShortStateKey},
30 state_compressor::{CompressedState, parse_compressed_state_event},
31 state_res::{StateMap, auth_types_for_event},
32 },
33 services::OnceServices,
34};
35
36pub struct Service {
37 pub mutex: RoomMutexMap,
38 services: Arc<OnceServices>,
39 db: Data,
40}
41
42struct Data {
43 shorteventid_shortstatehash: Arc<Map>,
44 roomid_shortstatehash: Arc<Map>,
45 roomid_pduleaves: Arc<Map>,
46}
47
48type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
49pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
50
51#[async_trait]
52impl crate::Service for Service {
53 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
54 Ok(Arc::new(Self {
55 mutex: RoomMutexMap::new(),
56 services: args.services.clone(),
57 db: Data {
58 shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
59 roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(),
60 roomid_pduleaves: args.db["roomid_pduleaves"].clone(),
61 },
62 }))
63 }
64
65 async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
66 let mutex = self.mutex.len();
67 writeln!(out, "state_mutex: {mutex}")?;
68
69 Ok(())
70 }
71
72 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
73}
74
#[implement(Service)]
/// Sets `room_id`'s current state to `shortstatehash` and updates the caches
/// derived from the state events added by the change.
///
/// Every entry of `statediffnew` is resolved to its PDU. Entries whose event ID
/// or PDU cannot be loaded are skipped.
/// - `m.room.member` events update the membership cache through
///   `state_cache.update_membership`, each with a freshly allocated count.
/// - `m.space.child` events evict the spaces cache for the PDU's room.
///
/// Afterwards the room's joined counts are refreshed and `set_room_state` points
/// the room at `shortstatehash`. `state_lock` is forwarded to that call.
/// `_statediffremoved` is accepted but not used here.
///
/// # Errors
/// Propagates the first error returned by `update_membership`. In that case the
/// room's shortstatehash is not updated.
#[tracing::instrument(
	name = "force",
	level = "debug",
	skip_all,
	fields(
		count = ?self.services.globals.pending_count(),
		%shortstatehash,
	)
)]
pub async fn force_state(
	&self,
	room_id: &RoomId,
	shortstatehash: u64,
	statediffnew: Arc<CompressedState>,
	_statediffremoved: Arc<CompressedState>,
	state_lock: &RoomMutexGuard,
) -> Result {
	statediffnew
		.iter()
		.stream()
		// Keep only the shorteventid half of each compressed state entry.
		.map(|&new| parse_compressed_state_event(new).1)
		// Resolve each shorteventid to its PDU. Unresolvable entries are dropped:
		// event-id lookup errors are debug-logged, missing PDUs are skipped silently.
		.wide_filter_map(async |shorteventid| {
			let event_id: OwnedEventId = self
				.services
				.short
				.get_eventid_from_short(shorteventid)
				.inspect_err(inspect_debug_log)
				.await
				.ok()?;

			self.services
				.timeline
				.get_pdu(&event_id)
				.await
				.ok()
		})
		.map(Ok)
		.try_for_each(async |pdu| match pdu.kind {
			| TimelineEventType::RoomMember => {
				// Member events whose state_key is not a valid user ID are ignored.
				let Some(user_id) = pdu
					.state_key
					.as_ref()
					.map(UserId::parse)
					.flat_ok()
				else {
					return Ok(());
				};

				// Member events whose content cannot be read (get_content fails)
				// are ignored.
				let Ok(membership_event) = pdu.get_content() else {
					return Ok(());
				};

				let count = self.services.globals.next_count();
				self.services
					.state_cache
					.update_membership(
						room_id,
						&user_id,
						membership_event,
						&pdu.sender,
						None,
						None,
						false,
						PduCount::Normal(*count),
					)
					.await
			},
			| TimelineEventType::SpaceChild => {
				self.services.spaces.cache_evict(pdu.room_id());

				Ok(())
			},
			| _ => Ok(()),
		})
		.boxed()
		.await?;

	self.services
		.state_cache
		.update_joined_count(room_id)
		.await;

	self.set_room_state(room_id, shortstatehash, state_lock);

	Ok(())
}
163
164#[implement(Service)]
169#[tracing::instrument(
170 name = "set",
171 level = "debug",
172 skip(self, state_ids_compressed),
173 fields(
174 count = ?self.services.globals.pending_count(),
175 )
176)]
177pub async fn set_event_state(
178 &self,
179 event_id: &EventId,
180 room_id: &RoomId,
181 state_ids_compressed: Arc<CompressedState>,
182) -> Result<ShortStateHash> {
183 const KEY_LEN: usize = size_of::<ShortEventId>();
184 const VAL_LEN: usize = size_of::<ShortStateHash>();
185
186 let shorteventid = self
187 .services
188 .short
189 .get_or_create_shorteventid(event_id)
190 .await;
191
192 let previous_shortstatehash = self.get_room_shortstatehash(room_id).await;
193
194 let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..]));
195
196 let (shortstatehash, already_existed) = self
197 .services
198 .short
199 .get_or_create_shortstatehash(&state_hash)
200 .await;
201
202 if !already_existed {
203 let states_parents = match previous_shortstatehash {
204 | Ok(p) =>
205 self.services
206 .state_compressor
207 .load_shortstatehash_info(p)
208 .await?,
209 | _ => Vec::new(),
210 };
211
212 let (statediffnew, statediffremoved) =
213 if let Some(parent_stateinfo) = states_parents.last() {
214 let statediffnew: CompressedState = state_ids_compressed
215 .difference(&parent_stateinfo.full_state)
216 .copied()
217 .collect();
218
219 let statediffremoved: CompressedState = parent_stateinfo
220 .full_state
221 .difference(&state_ids_compressed)
222 .copied()
223 .collect();
224
225 (Arc::new(statediffnew), Arc::new(statediffremoved))
226 } else {
227 (state_ids_compressed, Arc::new(CompressedState::new()))
228 };
229
230 self.services
231 .state_compressor
232 .save_state_from_diff(
233 shortstatehash,
234 statediffnew,
235 statediffremoved,
236 1_000_000, states_parents,
238 )?;
239 }
240
241 self.db
242 .shorteventid_shortstatehash
243 .aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, shortstatehash);
244
245 Ok(shortstatehash)
246}
247
248#[implement(Service)]
253#[tracing::instrument(
254 name = "set",
255 level = "debug",
256 skip(self, new_pdu),
257 fields(
258 count = ?self.services.globals.pending_count(),
259 )
260)]
261pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
262 const KEY_LEN: usize = size_of::<ShortEventId>();
263 const VAL_LEN: usize = size_of::<ShortStateHash>();
264
265 let shorteventid = self
266 .services
267 .short
268 .get_or_create_shorteventid(&new_pdu.event_id)
269 .await;
270
271 let previous_shortstatehash = self
272 .get_room_shortstatehash(&new_pdu.room_id)
273 .await;
274
275 if let Ok(p) = previous_shortstatehash {
276 self.db
277 .shorteventid_shortstatehash
278 .aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, p);
279 }
280
281 match &new_pdu.state_key {
282 | Some(state_key) => {
283 let states_parents = match previous_shortstatehash {
284 | Ok(p) =>
285 self.services
286 .state_compressor
287 .load_shortstatehash_info(p)
288 .await?,
289 | _ => Vec::new(),
290 };
291
292 let shortstatekey = self
293 .services
294 .short
295 .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
296 .await;
297
298 let new = self
299 .services
300 .state_compressor
301 .compress_state_event(shortstatekey, &new_pdu.event_id)
302 .await;
303
304 let replaces = states_parents
305 .last()
306 .map(|info| {
307 info.full_state
308 .iter()
309 .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
310 })
311 .unwrap_or_default();
312
313 if Some(&new) == replaces {
314 return Ok(previous_shortstatehash.expect("must exist"));
315 }
316
317 let shortstatehash = self.services.globals.next_count();
319
320 let mut statediffnew = CompressedState::new();
321 statediffnew.insert(new);
322
323 let mut statediffremoved = CompressedState::new();
324 if let Some(replaces) = replaces {
325 statediffremoved.insert(*replaces);
326 }
327
328 self.services
329 .state_compressor
330 .save_state_from_diff(
331 *shortstatehash,
332 Arc::new(statediffnew),
333 Arc::new(statediffremoved),
334 2,
335 states_parents,
336 )?;
337
338 Ok(*shortstatehash)
339 },
340 | _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")),
341 }
342}
343
344#[implement(Service)]
346#[tracing::instrument(skip(self, _mutex_lock), level = "debug")]
347pub fn set_room_state(
348 &self,
349 room_id: &RoomId,
350 shortstatehash: u64,
351 _mutex_lock: &RoomMutexGuard,
353) {
354 const BUFSIZE: usize = size_of::<u64>();
355
356 self.db
357 .roomid_shortstatehash
358 .raw_aput::<BUFSIZE, _, _>(room_id, shortstatehash);
359}
360
#[implement(Service)]
/// Fetches the auth events for a prospective event from the room's current
/// state.
///
/// `auth_types_for_event` computes, under `auth_rules`, the `(type, state_key)`
/// pairs needed to authorize an event with the given `kind`, `sender`,
/// `state_key` and `content`; `include_create` is forwarded to it. Each pair
/// present in the room's current state is resolved to its PDU and returned,
/// keyed by that pair.
///
/// Returns an empty map if the room has no current shortstatehash. Pairs without
/// a shortstatekey, and entries whose event ID or PDU cannot be loaded, are
/// silently omitted.
///
/// # Errors
/// Only errors from `auth_types_for_event` are propagated.
#[expect(clippy::too_many_arguments)]
#[tracing::instrument(skip(self, content), level = "debug")]
pub async fn get_auth_events(
	&self,
	room_id: &RoomId,
	kind: &TimelineEventType,
	sender: &UserId,
	state_key: Option<&str>,
	content: &serde_json::value::RawValue,
	auth_rules: &AuthorizationRules,
	include_create: bool,
) -> Result<StateMap<PduEvent>>
where
	StateEventType: Send + Sync,
	StateKey: Send + Sync,
{
	// No current state: there are no auth events to return.
	let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
		return Ok(StateMap::new());
	};

	// shortstatekey -> (type, state_key) for every needed auth pair. Pairs whose
	// shortstatekey lookup fails are dropped (presumably such a pair cannot occur
	// in any stored state).
	let sauthevents: HashMap<ShortStateKey, TypeStateKey> =
		auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)?
			.into_iter()
			.stream()
			.broad_filter_map(async |(event_type, state_key): TypeStateKey| {
				self.services
					.short
					.get_shortstatekey(&event_type, &state_key)
					.await
					.map(move |sstatekey| (sstatekey, (event_type, state_key)))
					.ok()
			})
			.collect()
			.await;

	self.services
		.state_accessor
		.state_full_shortids(shortstatehash)
		.ready_filter_map(Result::ok)
		// Keep only state entries whose shortstatekey is one of the wanted pairs.
		.ready_filter_map(|(shortstatekey, shorteventid)| {
			sauthevents
				.get(&shortstatekey)
				.map(move |(ty, sk)| ((ty, sk), shorteventid))
		})
		.unzip()
		// Resolve all shorteventids in one batch and pair them back with their keys.
		// NOTE(review): the zip assumes `multi_get_eventid_from_short` yields
		// results in input order.
		.map(|(state_keys, event_ids): (Vec<_>, Vec<_>)| {
			self.services
				.short
				.multi_get_eventid_from_short(event_ids.into_iter().stream())
				.zip(state_keys.into_iter().stream())
		})
		.flatten_stream()
		// Drop entries whose event ID could not be resolved.
		.ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?)))
		// Load each PDU; entries whose PDU cannot be loaded are dropped.
		.broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| {
			let pdu = self.services.timeline.get_pdu(&event_id).await;

			Some(((ty.clone(), sk.clone()), pdu.ok()?))
		})
		.collect()
		.map(Ok)
		.await
}
425
426#[implement(Service)]
427#[tracing::instrument(skip_all, level = "debug")]
428pub async fn summary_stripped<Pdu: Event>(&self, event: &Pdu) -> Vec<Raw<AnyStrippedStateEvent>> {
429 let cells = [
430 (&StateEventType::RoomCreate, ""),
431 (&StateEventType::RoomJoinRules, ""),
432 (&StateEventType::RoomCanonicalAlias, ""),
433 (&StateEventType::RoomName, ""),
434 (&StateEventType::RoomAvatar, ""),
435 (&StateEventType::RoomMember, event.sender().as_str()), (&StateEventType::RoomEncryption, ""),
437 (&StateEventType::RoomTopic, ""),
438 ];
439
440 let fetches = cells.into_iter().map(|(event_type, state_key)| {
441 self.services
442 .state_accessor
443 .room_state_get(event.room_id(), event_type, state_key)
444 });
445
446 join_all(fetches)
447 .await
448 .into_iter()
449 .filter_map(Result::ok)
450 .map(Event::into_format)
451 .chain(once(event.to_format()))
452 .collect()
453}
454
455#[implement(Service)]
457#[inline]
458pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result<RoomVersionRules> {
459 self.get_room_version(room_id)
460 .await
461 .and_then_ref(room_version::rules)
462}
463
464#[implement(Service)]
466#[tracing::instrument(
467 level = "debug"
468 skip(self),
469 ret(level = "trace"),
470)]
471pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
472 self.services
473 .state_accessor
474 .room_state_get_content(room_id, &StateEventType::RoomCreate, "")
475 .await
476 .as_ref()
477 .map(room_version::from_create_content)
478 .cloned()
479 .map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
480}
481
482#[implement(Service)]
483#[tracing::instrument(
484 level = "debug"
485 skip(self),
486 ret(level = "trace"),
487)]
488pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<ShortStateHash> {
489 self.db
490 .roomid_shortstatehash
491 .get(room_id)
492 .await
493 .deserialized()
494}
495
496#[implement(Service)]
498pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
499 self.services
500 .short
501 .get_shorteventid(event_id)
502 .and_then(|shorteventid| self.get_shortstatehash(shorteventid))
503 .await
504}
505
506#[implement(Service)]
508#[tracing::instrument(
509 level = "debug"
510 skip(self),
511 ret(level = "trace"),
512)]
513pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result<ShortStateHash> {
514 const BUFSIZE: usize = size_of::<ShortEventId>();
515
516 self.db
517 .shorteventid_shortstatehash
518 .aqry::<BUFSIZE, _>(&shorteventid)
519 .await
520 .deserialized()
521}
522
523#[implement(Service)]
524pub(super) fn delete_room_shortstatehash(
525 &self,
526 room_id: &RoomId,
527 _mutex_lock: &Guard<OwnedRoomId, ()>,
528) -> Result {
529 self.db.roomid_shortstatehash.remove(room_id);
530
531 Ok(())
532}
533
534#[implement(Service)]
535#[tracing::instrument(
536 level = "trace"
537 skip(self),
538)]
539pub fn get_forward_extremities<'a>(
540 &'a self,
541 room_id: &'a RoomId,
542) -> impl Stream<Item = &EventId> + Send + '_ {
543 let prefix = (room_id, Interfix);
544
545 self.db
546 .roomid_pduleaves
547 .keys_prefix(&prefix)
548 .map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
549 .ignore_err()
550}
551
552#[implement(Service)]
553#[tracing::instrument(
554 level = "debug"
555 skip_all,
556 fields(%room_id),
557)]
558pub async fn set_forward_extremities<'a, I>(
559 &'a self,
560 room_id: &'a RoomId,
561 event_ids: I,
562 _state_lock: &'a RoomMutexGuard,
563) where
564 I: Iterator<Item = &'a EventId> + Send + 'a,
565{
566 let prefix = (room_id, Interfix);
567 self.db
568 .roomid_pduleaves
569 .keys_prefix_raw(&prefix)
570 .ignore_err()
571 .ready_for_each(|key| self.db.roomid_pduleaves.remove(key))
572 .await;
573
574 for event_id in event_ids {
575 let key = (room_id, event_id);
576 self.db.roomid_pduleaves.put_raw(key, event_id);
577 }
578}
579
580#[implement(Service)]
581pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result {
582 let prefix = (room_id, Interfix);
583
584 self.db
585 .roomid_pduleaves
586 .keys_prefix_raw(&prefix)
587 .ignore_err()
588 .ready_for_each(|key| {
589 trace!("Removing key: {key:?}");
590 self.db.roomid_pduleaves.remove(key);
591 })
592 .await;
593
594 Ok(())
595}