Skip to main content

tuwunel_service/rooms/state/
mod.rs

1use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
2
3use async_trait::async_trait;
4use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all};
5use ruma::{
6	EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
7	events::{AnyStrippedStateEvent, StateEventType, TimelineEventType},
8	room_version_rules::AuthorizationRules,
9	serde::Raw,
10};
11use tuwunel_core::{
12	Event, PduEvent, Result, err,
13	error::inspect_debug_log,
14	implement,
15	matrix::{PduCount, RoomVersionRules, StateKey, TypeStateKey, room_version},
16	result::{AndThenRef, FlatOk},
17	trace,
18	utils::{
19		IterStream, MutexMap, MutexMapGuard, ReadyExt, calculate_hash,
20		mutex_map::Guard,
21		stream::{BroadbandExt, TryIgnore, WidebandExt},
22	},
23	warn,
24};
25use tuwunel_database::{Deserialized, Ignore, Interfix, Map};
26
27use crate::{
28	rooms::{
29		short::{ShortEventId, ShortStateHash, ShortStateKey},
30		state_compressor::{CompressedState, parse_compressed_state_event},
31		state_res::{StateMap, auth_types_for_event},
32	},
33	services::OnceServices,
34};
35
36pub struct Service {
37	pub mutex: RoomMutexMap,
38	services: Arc<OnceServices>,
39	db: Data,
40}
41
42struct Data {
43	shorteventid_shortstatehash: Arc<Map>,
44	roomid_shortstatehash: Arc<Map>,
45	roomid_pduleaves: Arc<Map>,
46}
47
48type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
49pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
50
51#[async_trait]
52impl crate::Service for Service {
53	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
54		Ok(Arc::new(Self {
55			mutex: RoomMutexMap::new(),
56			services: args.services.clone(),
57			db: Data {
58				shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
59				roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(),
60				roomid_pduleaves: args.db["roomid_pduleaves"].clone(),
61			},
62		}))
63	}
64
65	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
66		let mutex = self.mutex.len();
67		writeln!(out, "state_mutex: {mutex}")?;
68
69		Ok(())
70	}
71
72	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
73}
74
75/// Set the room to the given statehash and update caches.
76#[implement(Service)]
77#[tracing::instrument(
78	name = "force",
79	level = "debug",
80	skip_all,
81	fields(
82		count = ?self.services.globals.pending_count(),
83		%shortstatehash,
84	)
85)]
86pub async fn force_state(
87	&self,
88	room_id: &RoomId,
89	shortstatehash: u64,
90	statediffnew: Arc<CompressedState>,
91	_statediffremoved: Arc<CompressedState>,
92	state_lock: &RoomMutexGuard,
93) -> Result {
94	statediffnew
95		.iter()
96		.stream()
97		.map(|&new| parse_compressed_state_event(new).1)
98		.wide_filter_map(async |shorteventid| {
99			let event_id: OwnedEventId = self
100				.services
101				.short
102				.get_eventid_from_short(shorteventid)
103				.inspect_err(inspect_debug_log)
104				.await
105				.ok()?;
106
107			self.services
108				.timeline
109				.get_pdu(&event_id)
110				.await
111				.ok()
112		})
113		.map(Ok)
114		.try_for_each(async |pdu| match pdu.kind {
115			| TimelineEventType::RoomMember => {
116				let Some(user_id) = pdu
117					.state_key
118					.as_ref()
119					.map(UserId::parse)
120					.flat_ok()
121				else {
122					return Ok(());
123				};
124
125				let Ok(membership_event) = pdu.get_content() else {
126					return Ok(());
127				};
128
129				let count = self.services.globals.next_count();
130				self.services
131					.state_cache
132					.update_membership(
133						room_id,
134						&user_id,
135						membership_event,
136						&pdu.sender,
137						None,
138						None,
139						false,
140						PduCount::Normal(*count),
141					)
142					.await
143			},
144			| TimelineEventType::SpaceChild => {
145				self.services.spaces.cache_evict(pdu.room_id());
146
147				Ok(())
148			},
149			| _ => Ok(()),
150		})
151		.boxed()
152		.await?;
153
154	self.services
155		.state_cache
156		.update_joined_count(room_id)
157		.await;
158
159	self.set_room_state(room_id, shortstatehash, state_lock);
160
161	Ok(())
162}
163
164/// Generates a new StateHash and associates it with the incoming event.
165///
166/// This adds all current state events (not including the incoming event)
167/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
168#[implement(Service)]
169#[tracing::instrument(
170	name = "set",
171	level = "debug",
172	skip(self, state_ids_compressed),
173	fields(
174		count = ?self.services.globals.pending_count(),
175	)
176)]
177pub async fn set_event_state(
178	&self,
179	event_id: &EventId,
180	room_id: &RoomId,
181	state_ids_compressed: Arc<CompressedState>,
182) -> Result<ShortStateHash> {
183	const KEY_LEN: usize = size_of::<ShortEventId>();
184	const VAL_LEN: usize = size_of::<ShortStateHash>();
185
186	let shorteventid = self
187		.services
188		.short
189		.get_or_create_shorteventid(event_id)
190		.await;
191
192	let previous_shortstatehash = self.get_room_shortstatehash(room_id).await;
193
194	let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..]));
195
196	let (shortstatehash, already_existed) = self
197		.services
198		.short
199		.get_or_create_shortstatehash(&state_hash)
200		.await;
201
202	if !already_existed {
203		let states_parents = match previous_shortstatehash {
204			| Ok(p) =>
205				self.services
206					.state_compressor
207					.load_shortstatehash_info(p)
208					.await?,
209			| _ => Vec::new(),
210		};
211
212		let (statediffnew, statediffremoved) =
213			if let Some(parent_stateinfo) = states_parents.last() {
214				let statediffnew: CompressedState = state_ids_compressed
215					.difference(&parent_stateinfo.full_state)
216					.copied()
217					.collect();
218
219				let statediffremoved: CompressedState = parent_stateinfo
220					.full_state
221					.difference(&state_ids_compressed)
222					.copied()
223					.collect();
224
225				(Arc::new(statediffnew), Arc::new(statediffremoved))
226			} else {
227				(state_ids_compressed, Arc::new(CompressedState::new()))
228			};
229
230		self.services
231			.state_compressor
232			.save_state_from_diff(
233				shortstatehash,
234				statediffnew,
235				statediffremoved,
236				1_000_000, // high number because no state will be based on this one
237				states_parents,
238			)?;
239	}
240
241	self.db
242		.shorteventid_shortstatehash
243		.aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, shortstatehash);
244
245	Ok(shortstatehash)
246}
247
248/// Generates a new StateHash and associates it with the incoming event.
249///
250/// This adds all current state events (not including the incoming event)
251/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
252#[implement(Service)]
253#[tracing::instrument(
254	name = "set",
255	level = "debug",
256	skip(self, new_pdu),
257	fields(
258		count = ?self.services.globals.pending_count(),
259	)
260)]
261pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
262	const KEY_LEN: usize = size_of::<ShortEventId>();
263	const VAL_LEN: usize = size_of::<ShortStateHash>();
264
265	let shorteventid = self
266		.services
267		.short
268		.get_or_create_shorteventid(&new_pdu.event_id)
269		.await;
270
271	let previous_shortstatehash = self
272		.get_room_shortstatehash(&new_pdu.room_id)
273		.await;
274
275	if let Ok(p) = previous_shortstatehash {
276		self.db
277			.shorteventid_shortstatehash
278			.aput::<KEY_LEN, VAL_LEN, _, _>(shorteventid, p);
279	}
280
281	match &new_pdu.state_key {
282		| Some(state_key) => {
283			let states_parents = match previous_shortstatehash {
284				| Ok(p) =>
285					self.services
286						.state_compressor
287						.load_shortstatehash_info(p)
288						.await?,
289				| _ => Vec::new(),
290			};
291
292			let shortstatekey = self
293				.services
294				.short
295				.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
296				.await;
297
298			let new = self
299				.services
300				.state_compressor
301				.compress_state_event(shortstatekey, &new_pdu.event_id)
302				.await;
303
304			let replaces = states_parents
305				.last()
306				.map(|info| {
307					info.full_state
308						.iter()
309						.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
310				})
311				.unwrap_or_default();
312
313			if Some(&new) == replaces {
314				return Ok(previous_shortstatehash.expect("must exist"));
315			}
316
317			// TODO: statehash with deterministic inputs
318			let shortstatehash = self.services.globals.next_count();
319
320			let mut statediffnew = CompressedState::new();
321			statediffnew.insert(new);
322
323			let mut statediffremoved = CompressedState::new();
324			if let Some(replaces) = replaces {
325				statediffremoved.insert(*replaces);
326			}
327
328			self.services
329				.state_compressor
330				.save_state_from_diff(
331					*shortstatehash,
332					Arc::new(statediffnew),
333					Arc::new(statediffremoved),
334					2,
335					states_parents,
336				)?;
337
338			Ok(*shortstatehash)
339		},
340		| _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")),
341	}
342}
343
344/// Set the state hash to a new version, but does not update state_cache.
345#[implement(Service)]
346#[tracing::instrument(skip(self, _mutex_lock), level = "debug")]
347pub fn set_room_state(
348	&self,
349	room_id: &RoomId,
350	shortstatehash: u64,
351	// Take mutex guard to make sure users get the room state mutex
352	_mutex_lock: &RoomMutexGuard,
353) {
354	const BUFSIZE: usize = size_of::<u64>();
355
356	self.db
357		.roomid_shortstatehash
358		.raw_aput::<BUFSIZE, _, _>(room_id, shortstatehash);
359}
360
361/// This fetches auth events from the current state.
362#[implement(Service)]
363#[expect(clippy::too_many_arguments)]
364#[tracing::instrument(skip(self, content), level = "debug")]
365pub async fn get_auth_events(
366	&self,
367	room_id: &RoomId,
368	kind: &TimelineEventType,
369	sender: &UserId,
370	state_key: Option<&str>,
371	content: &serde_json::value::RawValue,
372	auth_rules: &AuthorizationRules,
373	include_create: bool,
374) -> Result<StateMap<PduEvent>>
375where
376	StateEventType: Send + Sync,
377	StateKey: Send + Sync,
378{
379	let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
380		return Ok(StateMap::new());
381	};
382
383	let sauthevents: HashMap<ShortStateKey, TypeStateKey> =
384		auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)?
385			.into_iter()
386			.stream()
387			.broad_filter_map(async |(event_type, state_key): TypeStateKey| {
388				self.services
389					.short
390					.get_shortstatekey(&event_type, &state_key)
391					.await
392					.map(move |sstatekey| (sstatekey, (event_type, state_key)))
393					.ok()
394			})
395			.collect()
396			.await;
397
398	self.services
399		.state_accessor
400		.state_full_shortids(shortstatehash)
401		.ready_filter_map(Result::ok)
402		.ready_filter_map(|(shortstatekey, shorteventid)| {
403			sauthevents
404				.get(&shortstatekey)
405				.map(move |(ty, sk)| ((ty, sk), shorteventid))
406		})
407		.unzip()
408		.map(|(state_keys, event_ids): (Vec<_>, Vec<_>)| {
409			self.services
410				.short
411				.multi_get_eventid_from_short(event_ids.into_iter().stream())
412				.zip(state_keys.into_iter().stream())
413		})
414		.flatten_stream()
415		.ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?)))
416		.broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| {
417			let pdu = self.services.timeline.get_pdu(&event_id).await;
418
419			Some(((ty.clone(), sk.clone()), pdu.ok()?))
420		})
421		.collect()
422		.map(Ok)
423		.await
424}
425
426#[implement(Service)]
427#[tracing::instrument(skip_all, level = "debug")]
428pub async fn summary_stripped<Pdu: Event>(&self, event: &Pdu) -> Vec<Raw<AnyStrippedStateEvent>> {
429	let cells = [
430		(&StateEventType::RoomCreate, ""),
431		(&StateEventType::RoomJoinRules, ""),
432		(&StateEventType::RoomCanonicalAlias, ""),
433		(&StateEventType::RoomName, ""),
434		(&StateEventType::RoomAvatar, ""),
435		(&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events
436		(&StateEventType::RoomEncryption, ""),
437		(&StateEventType::RoomTopic, ""),
438	];
439
440	let fetches = cells.into_iter().map(|(event_type, state_key)| {
441		self.services
442			.state_accessor
443			.room_state_get(event.room_id(), event_type, state_key)
444	});
445
446	join_all(fetches)
447		.await
448		.into_iter()
449		.filter_map(Result::ok)
450		.map(Event::into_format)
451		.chain(once(event.to_format()))
452		.collect()
453}
454
455/// Returns the room's version rules
456#[implement(Service)]
457#[inline]
458pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result<RoomVersionRules> {
459	self.get_room_version(room_id)
460		.await
461		.and_then_ref(room_version::rules)
462}
463
464/// Returns the room's version.
465#[implement(Service)]
466#[tracing::instrument(
467	level = "debug"
468	skip(self),
469	ret(level = "trace"),
470)]
471pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
472	self.services
473		.state_accessor
474		.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
475		.await
476		.as_ref()
477		.map(room_version::from_create_content)
478		.cloned()
479		.map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
480}
481
482#[implement(Service)]
483#[tracing::instrument(
484	level = "debug"
485	skip(self),
486	ret(level = "trace"),
487)]
488pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<ShortStateHash> {
489	self.db
490		.roomid_shortstatehash
491		.get(room_id)
492		.await
493		.deserialized()
494}
495
496/// Returns the state hash at this event.
497#[implement(Service)]
498pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
499	self.services
500		.short
501		.get_shorteventid(event_id)
502		.and_then(|shorteventid| self.get_shortstatehash(shorteventid))
503		.await
504}
505
506/// Returns the state hash at this event.
507#[implement(Service)]
508#[tracing::instrument(
509	level = "debug"
510	skip(self),
511	ret(level = "trace"),
512)]
513pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result<ShortStateHash> {
514	const BUFSIZE: usize = size_of::<ShortEventId>();
515
516	self.db
517		.shorteventid_shortstatehash
518		.aqry::<BUFSIZE, _>(&shorteventid)
519		.await
520		.deserialized()
521}
522
523#[implement(Service)]
524pub(super) fn delete_room_shortstatehash(
525	&self,
526	room_id: &RoomId,
527	_mutex_lock: &Guard<OwnedRoomId, ()>,
528) -> Result {
529	self.db.roomid_shortstatehash.remove(room_id);
530
531	Ok(())
532}
533
534#[implement(Service)]
535#[tracing::instrument(
536	level = "trace"
537	skip(self),
538)]
539pub fn get_forward_extremities<'a>(
540	&'a self,
541	room_id: &'a RoomId,
542) -> impl Stream<Item = &EventId> + Send + '_ {
543	let prefix = (room_id, Interfix);
544
545	self.db
546		.roomid_pduleaves
547		.keys_prefix(&prefix)
548		.map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
549		.ignore_err()
550}
551
552#[implement(Service)]
553#[tracing::instrument(
554	level = "debug"
555	skip_all,
556	fields(%room_id),
557)]
558pub async fn set_forward_extremities<'a, I>(
559	&'a self,
560	room_id: &'a RoomId,
561	event_ids: I,
562	_state_lock: &'a RoomMutexGuard,
563) where
564	I: Iterator<Item = &'a EventId> + Send + 'a,
565{
566	let prefix = (room_id, Interfix);
567	self.db
568		.roomid_pduleaves
569		.keys_prefix_raw(&prefix)
570		.ignore_err()
571		.ready_for_each(|key| self.db.roomid_pduleaves.remove(key))
572		.await;
573
574	for event_id in event_ids {
575		let key = (room_id, event_id);
576		self.db.roomid_pduleaves.put_raw(key, event_id);
577	}
578}
579
580#[implement(Service)]
581pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result {
582	let prefix = (room_id, Interfix);
583
584	self.db
585		.roomid_pduleaves
586		.keys_prefix_raw(&prefix)
587		.ignore_err()
588		.ready_for_each(|key| {
589			trace!("Removing key: {key:?}");
590			self.db.roomid_pduleaves.remove(key);
591		})
592		.await;
593
594	Ok(())
595}