Skip to main content

tuwunel_service/rooms/event_handler/resolve_state.rs

1use std::{borrow::Borrow, collections::HashMap, sync::Arc};
2
3use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
4use ruma::{OwnedEventId, RoomId, RoomVersionId};
5use tuwunel_core::{
6	Result, err, implement,
7	matrix::room_version,
8	trace,
9	utils::stream::{IterStream, ReadyExt, TryWidebandExt, WidebandExt},
10};
11
12use crate::rooms::{
13	state_compressor::CompressedState,
14	state_res::{self, AuthSet, StateMap},
15};
16
17#[implement(super::Service)]
18#[tracing::instrument(
19	name = "state",
20	level = "debug",
21	skip_all,
22	fields(
23		incoming = ?incoming_state.len()
24	),
25)]
26pub async fn resolve_state(
27	&self,
28	room_id: &RoomId,
29	room_version: &RoomVersionId,
30	incoming_state: HashMap<u64, OwnedEventId>,
31) -> Result<Arc<CompressedState>> {
32	trace!("Loading current room state ids");
33	let current_sstatehash = self
34		.services
35		.state
36		.get_room_shortstatehash(room_id)
37		.map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))
38		.await?;
39
40	let current_state_ids: HashMap<_, _> = self
41		.services
42		.state_accessor
43		.state_full_ids(current_sstatehash)
44		.collect()
45		.await;
46
47	trace!("Loading fork states");
48	let fork_states = [current_state_ids, incoming_state];
49	let auth_chain_sets = fork_states
50		.iter()
51		.try_stream()
52		.wide_and_then(|state| {
53			self.services
54				.auth_chain
55				.event_ids_iter(room_id, room_version, state.values().map(Borrow::borrow))
56				.try_collect::<AuthSet<OwnedEventId>>()
57		})
58		.ready_filter_map(Result::ok);
59
60	let fork_states = fork_states
61		.iter()
62		.stream()
63		.wide_then(|fork_state| {
64			let shortstatekeys = fork_state.keys().copied().stream();
65			let event_ids = fork_state.values().cloned().stream();
66			self.services
67				.short
68				.multi_get_statekey_from_short(shortstatekeys)
69				.zip(event_ids)
70				.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
71				.collect::<StateMap<OwnedEventId>>()
72		});
73
74	trace!("Resolving state");
75	let state = self
76		.state_resolution(room_id, room_version, fork_states, auth_chain_sets)
77		.await?;
78
79	trace!("State resolution done.");
80	let state_events: Vec<_> = state
81		.iter()
82		.stream()
83		.wide_then(|((event_type, state_key), event_id)| {
84			self.services
85				.short
86				.get_or_create_shortstatekey(event_type, state_key)
87				.map(move |shortstatekey| (shortstatekey, event_id))
88		})
89		.collect()
90		.await;
91
92	trace!("Compressing state...");
93	let new_room_state: CompressedState = self
94		.services
95		.state_compressor
96		.compress_state_events(
97			state_events
98				.iter()
99				.map(|(ssk, eid)| (ssk, (*eid).borrow())),
100		)
101		.collect()
102		.await;
103
104	Ok(Arc::new(new_room_state))
105}
106
107#[implement(super::Service)]
108#[tracing::instrument(name = "resolve", level = "debug", skip_all)]
109pub(super) async fn state_resolution<StateSets, AuthSets>(
110	&self,
111	_room_id: &RoomId,
112	room_version: &RoomVersionId,
113	state_sets: StateSets,
114	auth_chains: AuthSets,
115) -> Result<StateMap<OwnedEventId>>
116where
117	StateSets: Stream<Item = StateMap<OwnedEventId>> + Send,
118	AuthSets: Stream<Item = AuthSet<OwnedEventId>> + Send,
119{
120	state_res::resolve(
121		&room_version::rules(room_version)?,
122		state_sets,
123		auth_chains,
124		&async |event_id: OwnedEventId| self.event_fetch(&event_id).await,
125		&async |event_id: OwnedEventId| self.event_exists(&event_id).await,
126		self.services.server.config.hydra_backports,
127	)
128	.map_err(|e| err!(error!("State resolution failed: {e:?}")))
129	.await
130}