tuwunel_service/rooms/state_res/resolve/
mainline_sort.rs1use futures::{
2 FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold,
3};
4use ruma::{EventId, OwnedEventId, events::TimelineEventType};
5use tuwunel_core::{
6 Error, Result, at, is_equal_to,
7 matrix::Event,
8 trace,
9 utils::stream::{BroadbandExt, IterStream, TryReadyExt},
10};
11
12#[tracing::instrument(
38 level = "debug",
39 skip_all,
40 fields(
41 power_levels = power_level_event_id
42 .as_deref()
43 .map(EventId::as_str)
44 .unwrap_or_default(),
45 )
46)]
47pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>(
48 power_level_event_id: Option<OwnedEventId>,
49 events: RemainingEvents,
50 fetch: &Fetch,
51) -> Result<Vec<OwnedEventId>>
52where
53 RemainingEvents: Stream<Item = &'a EventId> + Send,
54 Fetch: Fn(OwnedEventId) -> Fut + Sync,
55 Fut: Future<Output = Result<Pdu>> + Send,
56 Pdu: Event,
57{
58 let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| {
60 let Some(power_level_event_id) = power_level_event_id else {
61 return Ok::<_, Error>(None);
62 };
63
64 let power_level_event = fetch(power_level_event_id).await?;
65 let this_event_id = power_level_event.event_id().to_owned();
66 let next_event_id = get_power_levels_auth_event(&power_level_event, fetch)
67 .map_ok(|event| {
68 event
69 .as_ref()
70 .map(Event::event_id)
71 .map(ToOwned::to_owned)
72 })
73 .await?;
74
75 trace!(?this_event_id, ?next_event_id, "mainline descent",);
76
77 Ok(Some((this_event_id, next_event_id)))
78 })
79 .try_collect()
80 .await?;
81
82 let mainline = mainline.iter().rev().map(AsRef::as_ref);
83
84 events
85 .map(ToOwned::to_owned)
86 .broad_filter_map(async |event_id| {
87 let event = fetch(event_id.clone()).await.ok()?;
88 let origin_server_ts = event.origin_server_ts();
89 let position = mainline_position(Some(event), &mainline, fetch)
90 .await
91 .ok()?;
92
93 Some((event_id, (position, origin_server_ts)))
94 })
95 .inspect(|(event_id, (position, origin_server_ts))| {
96 trace!(position, ?origin_server_ts, ?event_id, "mainline position");
97 })
98 .collect()
99 .map(|mut vec: Vec<_>| {
100 vec.sort_by(|a, b| {
101 let (a_pos, a_ots) = &a.1;
102 let (b_pos, b_ots) = &b.1;
103 a_pos
104 .cmp(b_pos)
105 .then(a_ots.cmp(b_ots))
106 .then(a.cmp(b))
107 });
108
109 vec.into_iter().map(at!(0)).collect()
110 })
111 .map(Ok)
112 .await
113}
114
115#[tracing::instrument(
128 name = "position",
129 level = "trace",
130 ret(level = "trace"),
131 skip_all,
132 fields(
133 mainline = mainline.clone().count(),
134 event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned),
135 )
136)]
137async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>(
138 mut current_event: Option<Pdu>,
139 mainline: &Mainline,
140 fetch: &Fetch,
141) -> Result<usize>
142where
143 Mainline: Iterator<Item = &'a EventId> + Clone + Send + Sync,
144 Fetch: Fn(OwnedEventId) -> Fut + Sync,
145 Fut: Future<Output = Result<Pdu>> + Send,
146 Pdu: Event,
147{
148 while let Some(event) = current_event {
149 trace!(
150 event_id = ?event.event_id(),
151 "mainline position search",
152 );
153
154 if let Some(position) = mainline
158 .clone()
159 .position(is_equal_to!(event.event_id()))
160 {
161 return Ok(position.saturating_add(1));
162 }
163
164 current_event = get_power_levels_auth_event(&event, fetch).await?;
166 }
167
168 Ok(0)
171}
172
173#[expect(clippy::redundant_closure)]
174#[tracing::instrument(level = "trace", skip_all)]
175async fn get_power_levels_auth_event<Fetch, Fut, Pdu>(
176 event: &Pdu,
177 fetch: &Fetch,
178) -> Result<Option<Pdu>>
179where
180 Fetch: Fn(OwnedEventId) -> Fut + Sync,
181 Fut: Future<Output = Result<Pdu>> + Send,
182 Pdu: Event,
183{
184 let power_level_event = event
185 .auth_events()
186 .try_stream()
187 .map_ok(ToOwned::to_owned)
188 .and_then(|auth_event_id| fetch(auth_event_id))
189 .ready_try_skip_while(|auth_event| {
190 Ok(!auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, ""))
191 });
192
193 pin_mut!(power_level_event);
194 power_level_event.try_next().await
195}