tuwunel_service/rooms/event_handler/
state_at_incoming.rs1use std::{collections::HashMap, iter::Iterator};
2
3use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
4use ruma::{OwnedEventId, RoomId, RoomVersionId};
5use tuwunel_core::{
6 Result, apply, debug, debug_warn, err, implement,
7 matrix::Event,
8 ref_at, trace,
9 utils::{
10 option::OptionExt,
11 stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
12 },
13};
14
15use crate::rooms::{
16 short::ShortStateHash,
17 state_res::{AuthSet, StateMap},
18};
19
20#[implement(super::Service)]
22#[tracing::instrument(name = "state1", level = "debug", skip_all)]
24pub(super) async fn state_at_incoming_degree_one<Pdu>(
25 &self,
26 incoming_pdu: &Pdu,
27) -> Result<Option<HashMap<u64, OwnedEventId>>>
28where
29 Pdu: Event,
30{
31 debug_assert!(
32 incoming_pdu.prev_events().count() == 1,
33 "Incoming PDU must have one prev_event to make this call"
34 );
35
36 let prev_event_id = incoming_pdu
37 .prev_events()
38 .next()
39 .expect("at least one prev_event");
40
41 let Ok(prev_event_sstatehash) = self
42 .services
43 .state
44 .pdu_shortstatehash(prev_event_id)
45 .inspect_err(|e| debug_warn!(?prev_event_id, "Missing state at prev_event: {e}"))
46 .await
47 else {
48 return Ok(None);
49 };
50
51 debug!(?prev_event_id, ?prev_event_sstatehash, "Resolving state at prev_event.");
52
53 let prev_event = self
54 .services
55 .timeline
56 .get_pdu(prev_event_id)
57 .map_err(|e| err!(Database("Could not find prev_event, but we know the state: {e:?}")));
58
59 let state = self
60 .services
61 .state_accessor
62 .state_full_ids(prev_event_sstatehash)
63 .collect::<HashMap<_, _>>()
64 .map(Ok);
65
66 let (prev_event, mut state) = try_join(prev_event, state).await?;
67
68 debug!(
69 ?prev_event_id,
70 ?prev_event_sstatehash,
71 state_ids = state.len(),
72 "Resolved state at prev_event.",
73 );
74
75 if let Some(state_key) = prev_event.state_key() {
76 let prev_event_type = prev_event.event_type().to_cow_str().into();
77
78 let shortstatekey = self
79 .services
80 .short
81 .get_or_create_shortstatekey(&prev_event_type, state_key)
82 .await;
83
84 state.insert(shortstatekey, prev_event.event_id().into());
85 debug!(
87 ?prev_event_id,
88 ?prev_event_type,
89 ?prev_event_sstatehash,
90 ?shortstatekey,
91 state_ids = state.len(),
92 "Added prev_event to state.",
93 );
94 }
95
96 debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result");
97
98 Ok(Some(state))
99}
100
101#[implement(super::Service)]
102#[tracing::instrument(name = "stateN", level = "debug", skip_all)]
103pub(super) async fn state_at_incoming_resolved<Pdu>(
104 &self,
105 incoming_pdu: &Pdu,
106 room_id: &RoomId,
107 room_version: &RoomVersionId,
108) -> Result<Option<HashMap<u64, OwnedEventId>>>
109where
110 Pdu: Event,
111{
112 debug_assert!(
113 incoming_pdu.prev_events().count() > 1,
114 "Incoming PDU should have more than one prev_event for this codepath"
115 );
116
117 trace!("Calculating extremity statehashes...");
118 let Ok(extremity_sstatehashes) = incoming_pdu
119 .prev_events()
120 .try_stream()
121 .broad_and_then(|prev_event_id| {
122 let sstatehash = self
123 .services
124 .state
125 .pdu_shortstatehash(prev_event_id);
126
127 let prev_event = self.services.timeline.get_pdu(prev_event_id);
128
129 try_join(sstatehash, prev_event).inspect_err(move |e| {
130 debug_warn!(?prev_event_id, "Missing state at prev_event: {e}");
131 })
132 })
133 .try_collect::<HashMap<_, _>>()
134 .await
135 else {
136 return Ok(None);
137 };
138
139 trace!("Calculating fork states...");
140 let (fork_states, auth_chain_sets) = extremity_sstatehashes
141 .into_iter()
142 .try_stream()
143 .wide_and_then(|(sstatehash, prev_event)| {
144 self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event)
145 })
146 .try_collect()
147 .map_ok(Vec::into_iter)
148 .map_ok(Iterator::unzip)
149 .map_ok(apply!(2, Vec::into_iter))
150 .map_ok(apply!(2, IterStream::stream))
151 .await?;
152
153 trace!("Resolving state");
154 let Ok(new_state) = self
155 .state_resolution(room_id, room_version, fork_states, auth_chain_sets)
156 .inspect_ok(|_| trace!("State resolution done."))
157 .await
158 else {
159 return Ok(None);
160 };
161
162 new_state
163 .into_iter()
164 .stream()
165 .broad_then(async |((event_type, state_key), event_id)| {
166 self.services
167 .short
168 .get_or_create_shortstatekey(&event_type, &state_key)
169 .map(move |shortstatekey| (shortstatekey, event_id))
170 .await
171 })
172 .collect::<HashMap<_, _>>()
173 .inspect(|state| trace!(state = state.len(), "Created shortstatekeys."))
174 .map(Some)
175 .map(Ok)
176 .await
177}
178
179#[implement(super::Service)]
180#[tracing::instrument(
181 name = "fork",
182 level = "debug",
183 skip_all,
184 fields(
185 ?sstatehash,
186 prev_event = ?prev_event.event_id(),
187 )
188)]
189async fn state_at_incoming_fork<Pdu>(
190 &self,
191 room_id: &RoomId,
192 room_version: &RoomVersionId,
193 sstatehash: ShortStateHash,
194 prev_event: Pdu,
195) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
196where
197 Pdu: Event,
198{
199 let leaf = prev_event
200 .state_key()
201 .map_stream(async |state_key| {
202 let event_id = prev_event.event_id();
203 let event_type = prev_event.kind().to_cow_str().into();
204 let shortstatekey = self
205 .services
206 .short
207 .get_or_create_shortstatekey(&event_type, state_key)
208 .await;
209
210 (shortstatekey, event_id.to_owned())
211 });
212
213 let leaf_state_after_event: Vec<_> = self
214 .services
215 .state_accessor
216 .state_full_ids(sstatehash)
217 .chain(leaf)
218 .collect()
219 .await;
220
221 trace!(
222 prev_event = ?prev_event.event_id(),
223 ?sstatehash,
224 leaf_states = leaf_state_after_event.len(),
225 "leaf state after event"
226 );
227
228 let starting_events = leaf_state_after_event
229 .iter()
230 .map(ref_at!(1))
231 .map(AsRef::as_ref);
232
233 let auth_chain = self
234 .services
235 .auth_chain
236 .event_ids_iter(room_id, room_version, starting_events)
237 .try_collect();
238
239 let fork_state = leaf_state_after_event
240 .iter()
241 .stream()
242 .broad_then(|(k, id)| {
243 self.services
244 .short
245 .get_statekey_from_short(*k)
246 .map_ok(|(ty, sk)| ((ty, sk), id.clone()))
247 })
248 .ready_filter_map(Result::ok)
249 .collect()
250 .map(Ok);
251
252 try_join(fork_state, auth_chain).await
253}