Skip to main content

tuwunel_service/rooms/state/
mod.rs

1use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
2
3use async_trait::async_trait;
4use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all};
5use ruma::{
6	CanonicalJsonObject, EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
7	events::{AnyStrippedStateEvent, StateEventType, TimelineEventType},
8	room_version_rules::AuthorizationRules,
9	serde::Raw,
10};
11use serde_json::value::RawValue as RawJsonValue;
12use tuwunel_core::{
13	Event, PduEvent, Result, err,
14	error::inspect_debug_log,
15	implement,
16	matrix::{PduCount, RoomVersionRules, StateKey, TypeStateKey, room_version},
17	result::{AndThenRef, FlatOk},
18	trace,
19	utils::{
20		IterStream, MutexMap, MutexMapGuard, ReadyExt, calculate_hash,
21		mutex_map::Guard,
22		stream::{BroadbandExt, TryIgnore, WidebandExt},
23	},
24	warn,
25};
26use tuwunel_database::{Deserialized, Ignore, Interfix, Map};
27
28use crate::{
29	rooms::{
30		short::{ShortEventId, ShortStateHash, ShortStateKey},
31		state_compressor::{CompressedState, parse_compressed_state_event},
32		state_res::{StateMap, auth_types_for_event},
33	},
34	services::OnceServices,
35};
36
/// Room state service: bundles the per-room state mutex map, a handle to the
/// other services, and the database maps used by this module.
pub struct Service {
	/// Per-room mutex map. Guards from it are taken as parameters by the
	/// state-mutating methods in this module (e.g. `set_room_state`).
	pub mutex: RoomMutexMap,
	/// Shared handle used to reach the other services.
	services: Arc<OnceServices>,
	/// Database maps used by this service.
	db: Data,
}
42
/// Handles to the database maps this service reads and writes.
struct Data {
	/// Short event id -> short state hash associated with that event.
	shorteventid_shortstatehash: Arc<Map>,
	/// Room id -> the short state hash currently set for the room.
	roomid_shortstatehash: Arc<Map>,
	/// Forward extremities; written with `(room_id, event_id)` keys.
	roomid_pduleaves: Arc<Map>,
}
48
/// Mutex map keyed by owned room id; the guarded value is `()`.
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
/// Guard type matching `RoomMutexMap`; passed by reference to methods that
/// require the caller to hold the room's state mutex.
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
51
52#[async_trait]
53impl crate::Service for Service {
54	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
55		Ok(Arc::new(Self {
56			mutex: RoomMutexMap::new(),
57			services: args.services.clone(),
58			db: Data {
59				shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
60				roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(),
61				roomid_pduleaves: args.db["roomid_pduleaves"].clone(),
62			},
63		}))
64	}
65
66	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
67		let mutex = self.mutex.len();
68		writeln!(out, "- state_mutex: {mutex}")?;
69
70		Ok(())
71	}
72
73	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
74}
75
76/// Set the room to the given statehash and update caches.
77#[implement(Service)]
78#[tracing::instrument(
79	name = "force",
80	level = "debug",
81	skip_all,
82	fields(
83		count = ?self.services.globals.pending_count(),
84		%shortstatehash,
85	)
86)]
87pub async fn force_state(
88	&self,
89	room_id: &RoomId,
90	shortstatehash: u64,
91	statediffnew: Arc<CompressedState>,
92	_statediffremoved: Arc<CompressedState>,
93	state_lock: &RoomMutexGuard,
94) -> Result {
95	statediffnew
96		.iter()
97		.stream()
98		.map(|&new| parse_compressed_state_event(new).1)
99		.wide_filter_map(async |shorteventid| {
100			let event_id: OwnedEventId = self
101				.services
102				.short
103				.get_eventid_from_short(shorteventid)
104				.inspect_err(inspect_debug_log)
105				.await
106				.ok()?;
107
108			self.services
109				.timeline
110				.get_pdu(&event_id)
111				.await
112				.ok()
113		})
114		.map(Ok)
115		.try_for_each(async |pdu| match pdu.kind {
116			| TimelineEventType::RoomMember => {
117				let Some(user_id) = pdu
118					.state_key
119					.as_ref()
120					.map(UserId::parse)
121					.flat_ok()
122				else {
123					return Ok(());
124				};
125
126				let Ok(membership_event) = pdu.get_content() else {
127					return Ok(());
128				};
129
130				let count = self.services.globals.next_count();
131				self.services
132					.state_cache
133					.update_membership(
134						room_id,
135						&user_id,
136						membership_event,
137						&pdu.sender,
138						None,
139						None,
140						false,
141						PduCount::Normal(*count),
142					)
143					.await
144			},
145			| TimelineEventType::SpaceChild => {
146				self.services.spaces.cache_evict(pdu.room_id());
147
148				Ok(())
149			},
150			| _ => Ok(()),
151		})
152		.boxed()
153		.await?;
154
155	self.services
156		.state_cache
157		.update_joined_count(room_id)
158		.await;
159
160	self.set_room_state(room_id, shortstatehash, state_lock);
161
162	Ok(())
163}
164
165/// Generates a new StateHash and associates it with the incoming event.
166///
167/// This adds all current state events (not including the incoming event)
168/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
169#[implement(Service)]
170#[tracing::instrument(
171	name = "set",
172	level = "debug",
173	skip(self, state_ids_compressed),
174	fields(
175		count = ?self.services.globals.pending_count(),
176	)
177)]
178pub async fn set_event_state(
179	&self,
180	event_id: &EventId,
181	room_id: &RoomId,
182	state_ids_compressed: Arc<CompressedState>,
183) -> Result<ShortStateHash> {
184	const KEY_LEN: usize = size_of::<ShortEventId>();
185	const VAL_LEN: usize = size_of::<ShortStateHash>();
186
187	let shorteventid = self
188		.services
189		.short
190		.get_or_create_shorteventid(event_id)
191		.await;
192
193	let previous_shortstatehash = self.get_room_shortstatehash(room_id).await;
194
195	let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..]));
196
197	let (shortstatehash, already_existed) = self
198		.services
199		.short
200		.get_or_create_shortstatehash(&state_hash)
201		.await;
202
203	if !already_existed {
204		let states_parents = match previous_shortstatehash {
205			| Ok(p) =>
206				self.services
207					.state_compressor
208					.load_shortstatehash_info(p)
209					.await?,
210			| _ => Vec::new(),
211		};
212
213		let (statediffnew, statediffremoved) =
214			if let Some(parent_stateinfo) = states_parents.last() {
215				let statediffnew: CompressedState = state_ids_compressed
216					.difference(&parent_stateinfo.full_state)
217					.copied()
218					.collect();
219
220				let statediffremoved: CompressedState = parent_stateinfo
221					.full_state
222					.difference(&state_ids_compressed)
223					.copied()
224					.collect();
225
226				(Arc::new(statediffnew), Arc::new(statediffremoved))
227			} else {
228				(state_ids_compressed, Arc::new(CompressedState::new()))
229			};
230
231		self.services
232			.state_compressor
233			.save_state_from_diff(
234				shortstatehash,
235				statediffnew,
236				statediffremoved,
237				1_000_000, // high number because no state will be based on this one
238				states_parents,
239			)?;
240	}
241
242	self.db
243		.shorteventid_shortstatehash
244		.aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, shortstatehash);
245
246	Ok(shortstatehash)
247}
248
249/// Generates a new StateHash and associates it with the incoming event.
250///
251/// This adds all current state events (not including the incoming event)
252/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
253#[implement(Service)]
254#[tracing::instrument(
255	name = "set",
256	level = "debug",
257	skip(self, new_pdu),
258	fields(
259		count = ?self.services.globals.pending_count(),
260	)
261)]
262pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
263	const KEY_LEN: usize = size_of::<ShortEventId>();
264	const VAL_LEN: usize = size_of::<ShortStateHash>();
265
266	let shorteventid = self
267		.services
268		.short
269		.get_or_create_shorteventid(&new_pdu.event_id)
270		.await;
271
272	let previous_shortstatehash = self
273		.get_room_shortstatehash(&new_pdu.room_id)
274		.await;
275
276	if let Ok(p) = previous_shortstatehash {
277		self.db
278			.shorteventid_shortstatehash
279			.aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, p);
280	}
281
282	match &new_pdu.state_key {
283		| Some(state_key) => {
284			let states_parents = match previous_shortstatehash {
285				| Ok(p) =>
286					self.services
287						.state_compressor
288						.load_shortstatehash_info(p)
289						.await?,
290				| _ => Vec::new(),
291			};
292
293			let shortstatekey = self
294				.services
295				.short
296				.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
297				.await;
298
299			let new = self
300				.services
301				.state_compressor
302				.compress_state_event(shortstatekey, &new_pdu.event_id)
303				.await;
304
305			let replaces = states_parents
306				.last()
307				.map(|info| {
308					info.full_state
309						.iter()
310						.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
311				})
312				.unwrap_or_default();
313
314			if Some(&new) == replaces {
315				return Ok(previous_shortstatehash.expect("must exist"));
316			}
317
318			// TODO: statehash with deterministic inputs
319			let shortstatehash = self.services.globals.next_count();
320
321			let mut statediffnew = CompressedState::new();
322			statediffnew.insert(new);
323
324			let mut statediffremoved = CompressedState::new();
325			if let Some(replaces) = replaces {
326				statediffremoved.insert(*replaces);
327			}
328
329			self.services
330				.state_compressor
331				.save_state_from_diff(
332					*shortstatehash,
333					Arc::new(statediffnew),
334					Arc::new(statediffremoved),
335					2,
336					states_parents,
337				)?;
338
339			Ok(*shortstatehash)
340		},
341		| _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")),
342	}
343}
344
345/// Set the state hash to a new version, but does not update state_cache.
346#[implement(Service)]
347#[tracing::instrument(skip(self, _mutex_lock), level = "debug")]
348pub fn set_room_state(
349	&self,
350	room_id: &RoomId,
351	shortstatehash: u64,
352	// Take mutex guard to make sure users get the room state mutex
353	_mutex_lock: &RoomMutexGuard,
354) {
355	const BUFSIZE: usize = size_of::<u64>();
356
357	self.db
358		.roomid_shortstatehash
359		.raw_aput::<BUFSIZE, _, _>(room_id, shortstatehash);
360}
361
362/// This fetches auth events from the current state.
363#[implement(Service)]
364#[expect(clippy::too_many_arguments)]
365#[tracing::instrument(skip(self, content), level = "debug")]
366pub async fn get_auth_events(
367	&self,
368	room_id: &RoomId,
369	kind: &TimelineEventType,
370	sender: &UserId,
371	state_key: Option<&str>,
372	content: &serde_json::value::RawValue,
373	auth_rules: &AuthorizationRules,
374	include_create: bool,
375) -> Result<StateMap<PduEvent>>
376where
377	StateEventType: Send + Sync,
378	StateKey: Send + Sync,
379{
380	let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
381		return Ok(StateMap::new());
382	};
383
384	let sauthevents: HashMap<ShortStateKey, TypeStateKey> =
385		auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)?
386			.into_iter()
387			.stream()
388			.broad_filter_map(async |(event_type, state_key): TypeStateKey| {
389				self.services
390					.short
391					.get_shortstatekey(&event_type, &state_key)
392					.await
393					.map(move |sstatekey| (sstatekey, (event_type, state_key)))
394					.ok()
395			})
396			.collect()
397			.await;
398
399	let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
400		.services
401		.state_accessor
402		.state_full_shortids(shortstatehash)
403		.ready_filter_map(Result::ok)
404		.ready_filter_map(|(shortstatekey, shorteventid)| {
405			sauthevents
406				.get(&shortstatekey)
407				.map(move |(ty, sk)| ((ty, sk), shorteventid))
408		})
409		.unzip()
410		.await;
411
412	self.services
413		.short
414		.multi_get_eventid_from_short(event_ids.into_iter().stream())
415		.zip(state_keys.into_iter().stream())
416		.ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?)))
417		.broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| {
418			let pdu = self.services.timeline.get_pdu(&event_id).await;
419
420			Some(((ty.clone(), sk.clone()), pdu.ok()?))
421		})
422		.collect()
423		.map(Ok)
424		.await
425}
426
427#[implement(Service)]
428#[tracing::instrument(skip_all, level = "debug")]
429pub async fn summary_stripped<Pdu: Event>(&self, event: &Pdu) -> Vec<Raw<AnyStrippedStateEvent>> {
430	let cells = [
431		(&StateEventType::RoomCreate, ""),
432		(&StateEventType::RoomJoinRules, ""),
433		(&StateEventType::RoomCanonicalAlias, ""),
434		(&StateEventType::RoomName, ""),
435		(&StateEventType::RoomAvatar, ""),
436		(&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events
437		(&StateEventType::RoomEncryption, ""),
438		(&StateEventType::RoomTopic, ""),
439	];
440
441	let fetches = cells.into_iter().map(|(event_type, state_key)| {
442		self.services
443			.state_accessor
444			.room_state_get(event.room_id(), event_type, state_key)
445	});
446
447	join_all(fetches)
448		.await
449		.into_iter()
450		.filter_map(Result::ok)
451		.map(Event::into_format)
452		.chain(once(event.to_format()))
453		.collect()
454}
455
456/// Like `summary_stripped`, but formats each event as a full federation PDU
457/// per the room version's event format (MSC4311). The membership `event` is
458/// formatted from its `event_json`; the recommended state cells are fetched
459/// from stored room state.
460#[implement(Service)]
461#[tracing::instrument(skip_all, level = "debug")]
462pub async fn summary_pdus<Pdu: Event>(
463	&self,
464	event: &Pdu,
465	event_json: &CanonicalJsonObject,
466	room_version: &RoomVersionId,
467) -> Vec<Box<RawJsonValue>> {
468	let cells = [
469		(&StateEventType::RoomCreate, ""),
470		(&StateEventType::RoomJoinRules, ""),
471		(&StateEventType::RoomCanonicalAlias, ""),
472		(&StateEventType::RoomName, ""),
473		(&StateEventType::RoomAvatar, ""),
474		(&StateEventType::RoomMember, event.sender().as_str()),
475		(&StateEventType::RoomEncryption, ""),
476		(&StateEventType::RoomTopic, ""),
477	];
478
479	let membership = self
480		.services
481		.federation
482		.format_pdu_into(event_json.clone(), Some(room_version))
483		.await;
484
485	cells
486		.into_iter()
487		.stream()
488		.wide_filter_map(async |(event_type, state_key)| {
489			let pdu = self
490				.services
491				.state_accessor
492				.room_state_get(event.room_id(), event_type, state_key)
493				.await
494				.ok()?;
495
496			let pdu_json = self
497				.services
498				.timeline
499				.get_pdu_json(pdu.event_id())
500				.await
501				.ok()?;
502
503			Some(
504				self.services
505					.federation
506					.format_pdu_into(pdu_json, Some(room_version))
507					.await,
508			)
509		})
510		.chain(once(membership).stream())
511		.collect()
512		.await
513}
514
515/// Returns the room's version rules
516#[implement(Service)]
517#[inline]
518pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result<RoomVersionRules> {
519	self.get_room_version(room_id)
520		.await
521		.and_then_ref(room_version::rules)
522}
523
524/// Returns the room's version.
525#[implement(Service)]
526#[tracing::instrument(
527	level = "debug"
528	skip(self),
529	ret(level = "trace"),
530)]
531pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
532	self.services
533		.state_accessor
534		.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
535		.await
536		.as_ref()
537		.map(room_version::from_create_content)
538		.cloned()
539		.map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
540}
541
542#[implement(Service)]
543#[tracing::instrument(
544	level = "debug"
545	skip(self),
546	ret(level = "trace"),
547)]
548pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<ShortStateHash> {
549	self.db
550		.roomid_shortstatehash
551		.get(room_id)
552		.await
553		.deserialized()
554}
555
556/// Returns the state hash at this event.
557#[implement(Service)]
558pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
559	self.services
560		.short
561		.get_shorteventid(event_id)
562		.and_then(|shorteventid| self.get_shortstatehash(shorteventid))
563		.await
564}
565
566/// Returns the state hash at this event.
567#[implement(Service)]
568#[tracing::instrument(
569	level = "debug"
570	skip(self),
571	ret(level = "trace"),
572)]
573pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result<ShortStateHash> {
574	const BUFSIZE: usize = size_of::<ShortEventId>();
575
576	self.db
577		.shorteventid_shortstatehash
578		.aqry::<BUFSIZE, _>(&shorteventid)
579		.await
580		.deserialized()
581}
582
583#[implement(Service)]
584pub(super) fn delete_room_shortstatehash(
585	&self,
586	room_id: &RoomId,
587	_mutex_lock: &Guard<OwnedRoomId, ()>,
588) -> Result {
589	self.db.roomid_shortstatehash.remove(room_id);
590
591	Ok(())
592}
593
594#[implement(Service)]
595#[tracing::instrument(
596	level = "trace"
597	skip(self),
598)]
599pub fn get_forward_extremities<'a>(
600	&'a self,
601	room_id: &'a RoomId,
602) -> impl Stream<Item = &EventId> + Send + '_ {
603	let prefix = (room_id, Interfix);
604
605	self.db
606		.roomid_pduleaves
607		.keys_prefix(&prefix)
608		.map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
609		.ignore_err()
610}
611
612#[implement(Service)]
613#[tracing::instrument(
614	level = "debug"
615	skip_all,
616	fields(%room_id),
617)]
618pub async fn set_forward_extremities<'a, I>(
619	&'a self,
620	room_id: &'a RoomId,
621	event_ids: I,
622	_state_lock: &'a RoomMutexGuard,
623) where
624	I: Iterator<Item = &'a EventId> + Send + 'a,
625{
626	let prefix = (room_id, Interfix);
627	self.db
628		.roomid_pduleaves
629		.keys_prefix_raw(&prefix)
630		.ignore_err()
631		.ready_for_each(|key| self.db.roomid_pduleaves.remove(key))
632		.await;
633
634	for event_id in event_ids {
635		let key = (room_id, event_id);
636		self.db.roomid_pduleaves.put_raw(key, event_id);
637	}
638}
639
640#[implement(Service)]
641pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result {
642	let prefix = (room_id, Interfix);
643
644	self.db
645		.roomid_pduleaves
646		.keys_prefix_raw(&prefix)
647		.ignore_err()
648		.ready_for_each(|key| {
649			trace!("Removing key: {key:?}");
650			self.db.roomid_pduleaves.remove(key);
651		})
652		.await;
653
654	Ok(())
655}