Skip to main content

tuwunel_service/rooms/event_handler/state_at_incoming.rs

1use std::{collections::HashMap, iter::Iterator};
2
3use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
4use ruma::{OwnedEventId, RoomId, RoomVersionId};
5use tuwunel_core::{
6	Result, apply, debug, debug_warn, err, implement,
7	matrix::Event,
8	ref_at, trace,
9	utils::{
10		option::OptionExt,
11		stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
12	},
13};
14
15use crate::rooms::{
16	short::ShortStateHash,
17	state_res::{AuthSet, StateMap},
18};
19
20// TODO: if we know the prev_events of the incoming event we can avoid the
21#[implement(super::Service)]
22// request and build the state from a known point and resolve if > 1 prev_event
23#[tracing::instrument(name = "state1", level = "debug", skip_all)]
24pub(super) async fn state_at_incoming_degree_one<Pdu>(
25	&self,
26	incoming_pdu: &Pdu,
27) -> Result<Option<HashMap<u64, OwnedEventId>>>
28where
29	Pdu: Event,
30{
31	debug_assert!(
32		incoming_pdu.prev_events().count() == 1,
33		"Incoming PDU must have one prev_event to make this call"
34	);
35
36	let prev_event_id = incoming_pdu
37		.prev_events()
38		.next()
39		.expect("at least one prev_event");
40
41	let Ok(prev_event_sstatehash) = self
42		.services
43		.state
44		.pdu_shortstatehash(prev_event_id)
45		.inspect_err(|e| debug_warn!(?prev_event_id, "Missing state at prev_event: {e}"))
46		.await
47	else {
48		return Ok(None);
49	};
50
51	debug!(?prev_event_id, ?prev_event_sstatehash, "Resolving state at prev_event.");
52
53	let prev_event = self
54		.services
55		.timeline
56		.get_pdu(prev_event_id)
57		.map_err(|e| err!(Database("Could not find prev_event, but we know the state: {e:?}")));
58
59	let state = self
60		.services
61		.state_accessor
62		.state_full_ids(prev_event_sstatehash)
63		.collect::<HashMap<_, _>>()
64		.map(Ok);
65
66	let (prev_event, mut state) = try_join(prev_event, state).await?;
67
68	debug!(
69		?prev_event_id,
70		?prev_event_sstatehash,
71		state_ids = state.len(),
72		"Resolved state at prev_event.",
73	);
74
75	if let Some(state_key) = prev_event.state_key() {
76		let prev_event_type = prev_event.event_type().to_cow_str().into();
77
78		let shortstatekey = self
79			.services
80			.short
81			.get_or_create_shortstatekey(&prev_event_type, state_key)
82			.await;
83
84		state.insert(shortstatekey, prev_event.event_id().into());
85		// Now it's the state after the pdu
86		debug!(
87			?prev_event_id,
88			?prev_event_type,
89			?prev_event_sstatehash,
90			?shortstatekey,
91			state_ids = state.len(),
92			"Added prev_event to state.",
93		);
94	}
95
96	debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result");
97
98	Ok(Some(state))
99}
100
101#[implement(super::Service)]
102#[tracing::instrument(name = "stateN", level = "debug", skip_all)]
103pub(super) async fn state_at_incoming_resolved<Pdu>(
104	&self,
105	incoming_pdu: &Pdu,
106	room_id: &RoomId,
107	room_version: &RoomVersionId,
108) -> Result<Option<HashMap<u64, OwnedEventId>>>
109where
110	Pdu: Event,
111{
112	debug_assert!(
113		incoming_pdu.prev_events().count() > 1,
114		"Incoming PDU should have more than one prev_event for this codepath"
115	);
116
117	trace!("Calculating extremity statehashes...");
118	let Ok(extremity_sstatehashes) = incoming_pdu
119		.prev_events()
120		.try_stream()
121		.broad_and_then(|prev_event_id| {
122			let sstatehash = self
123				.services
124				.state
125				.pdu_shortstatehash(prev_event_id);
126
127			let prev_event = self.services.timeline.get_pdu(prev_event_id);
128
129			try_join(sstatehash, prev_event).inspect_err(move |e| {
130				debug_warn!(?prev_event_id, "Missing state at prev_event: {e}");
131			})
132		})
133		.try_collect::<HashMap<_, _>>()
134		.await
135	else {
136		return Ok(None);
137	};
138
139	trace!("Calculating fork states...");
140	let (fork_states, auth_chain_sets) = extremity_sstatehashes
141		.into_iter()
142		.try_stream()
143		.wide_and_then(|(sstatehash, prev_event)| {
144			self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event)
145		})
146		.try_collect()
147		.map_ok(Vec::into_iter)
148		.map_ok(Iterator::unzip)
149		.map_ok(apply!(2, Vec::into_iter))
150		.map_ok(apply!(2, IterStream::stream))
151		.await?;
152
153	trace!("Resolving state");
154	let Ok(new_state) = self
155		.state_resolution(room_id, room_version, fork_states, auth_chain_sets)
156		.inspect_ok(|_| trace!("State resolution done."))
157		.await
158	else {
159		return Ok(None);
160	};
161
162	new_state
163		.into_iter()
164		.stream()
165		.broad_then(async |((event_type, state_key), event_id)| {
166			self.services
167				.short
168				.get_or_create_shortstatekey(&event_type, &state_key)
169				.map(move |shortstatekey| (shortstatekey, event_id))
170				.await
171		})
172		.collect::<HashMap<_, _>>()
173		.inspect(|state| trace!(state = state.len(), "Created shortstatekeys."))
174		.map(Some)
175		.map(Ok)
176		.await
177}
178
179#[implement(super::Service)]
180#[tracing::instrument(
181	name = "fork",
182	level = "debug",
183	skip_all,
184	fields(
185		?sstatehash,
186		prev_event = ?prev_event.event_id(),
187	)
188)]
189async fn state_at_incoming_fork<Pdu>(
190	&self,
191	room_id: &RoomId,
192	room_version: &RoomVersionId,
193	sstatehash: ShortStateHash,
194	prev_event: Pdu,
195) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
196where
197	Pdu: Event,
198{
199	let leaf = prev_event
200		.state_key()
201		.map_stream(async |state_key| {
202			let event_id = prev_event.event_id();
203			let event_type = prev_event.kind().to_cow_str().into();
204			let shortstatekey = self
205				.services
206				.short
207				.get_or_create_shortstatekey(&event_type, state_key)
208				.await;
209
210			(shortstatekey, event_id.to_owned())
211		});
212
213	let leaf_state_after_event: Vec<_> = self
214		.services
215		.state_accessor
216		.state_full_ids(sstatehash)
217		.chain(leaf)
218		.collect()
219		.await;
220
221	trace!(
222		prev_event = ?prev_event.event_id(),
223		?sstatehash,
224		leaf_states = leaf_state_after_event.len(),
225		"leaf state after event"
226	);
227
228	let starting_events = leaf_state_after_event
229		.iter()
230		.map(ref_at!(1))
231		.map(AsRef::as_ref);
232
233	let auth_chain = self
234		.services
235		.auth_chain
236		.event_ids_iter(room_id, room_version, starting_events)
237		.try_collect();
238
239	let fork_state = leaf_state_after_event
240		.iter()
241		.stream()
242		.broad_then(|(k, id)| {
243			self.services
244				.short
245				.get_statekey_from_short(*k)
246				.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
247		})
248		.ready_filter_map(Result::ok)
249		.collect()
250		.map(Ok);
251
252	try_join(fork_state, auth_chain).await
253}