tuwunel_service/rooms/state_res/resolve/
power_sort.rs1use std::{
2 borrow::Borrow,
3 collections::{HashMap, HashSet},
4 iter::once,
5};
6
7use futures::{StreamExt, TryFutureExt, TryStreamExt, stream::FuturesUnordered};
8use ruma::{
9 EventId, OwnedEventId,
10 events::{TimelineEventType, room::power_levels::UserPowerLevel},
11 room_version_rules::RoomVersionRules,
12};
13use tuwunel_core::{
14 Result, err,
15 matrix::Event,
16 utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
17};
18
19use super::super::{
20 events::{
21 RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
22 power_levels::RoomPowerLevelsEventOptionExt,
23 },
24 topological_sort,
25 topological_sort::ReferencedIds,
26};
27
28#[tracing::instrument(
47 level = "debug",
48 skip_all,
49 fields(
50 conflicted = full_conflicted_set.len(),
51 )
52)]
53pub(super) async fn power_sort<Fetch, Fut, Pdu>(
54 rules: &RoomVersionRules,
55 full_conflicted_set: &HashSet<OwnedEventId>,
56 fetch: &Fetch,
57) -> Result<Vec<OwnedEventId>>
58where
59 Fetch: Fn(OwnedEventId) -> Fut + Sync,
60 Fut: Future<Output = Result<Pdu>> + Send,
61 Pdu: Event,
62{
63 let graph = full_conflicted_set
66 .iter()
67 .stream()
68 .broad_filter_map(async |id| {
69 is_power_event_id(id, fetch)
70 .await
71 .then(|| id.clone())
72 })
73 .enumerate()
74 .fold(HashMap::new(), |graph, (i, event_id)| {
75 add_event_auth_chain(full_conflicted_set, graph, event_id, fetch, i)
76 })
77 .await;
78
79 let event_to_power_level: HashMap<_, _> = graph
82 .keys()
83 .try_stream()
84 .map_ok(AsRef::as_ref)
85 .broad_and_then(|event_id| {
86 power_level_for_sender(event_id, rules, fetch)
87 .map_ok(move |sender_power| (event_id, sender_power))
88 .map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
89 })
90 .try_collect()
91 .await?;
92
93 let query = async |event_id: OwnedEventId| {
94 let power_level = *event_to_power_level
95 .get(&event_id.borrow())
96 .ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
97
98 let event = fetch(event_id).await?;
99 Ok((power_level, event.origin_server_ts()))
100 };
101
102 topological_sort(&graph, &query).await
103}
104
105#[tracing::instrument(
108 name = "auth_chain",
109 level = "trace",
110 skip_all,
111 fields(
112 graph = graph.len(),
113 ?event_id,
114 %i,
115 )
116)]
117async fn add_event_auth_chain<Fetch, Fut, Pdu>(
118 full_conflicted_set: &HashSet<OwnedEventId>,
119 mut graph: HashMap<OwnedEventId, ReferencedIds>,
120 event_id: OwnedEventId,
121 fetch: &Fetch,
122 i: usize,
123) -> HashMap<OwnedEventId, ReferencedIds>
124where
125 Fetch: Fn(OwnedEventId) -> Fut + Sync,
126 Fut: Future<Output = Result<Pdu>> + Send,
127 Pdu: Event,
128{
129 let mut todo: FuturesUnordered<Fut> = once(fetch(event_id)).collect();
130
131 while let Some(event) = todo.next().await {
132 let Ok(event) = event else {
133 continue;
134 };
135
136 let event_id = event.event_id().to_owned();
137 graph.entry(event_id.clone()).or_default();
138
139 for auth_event_id in event
140 .auth_events_into()
141 .into_iter()
142 .filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
143 {
144 if !graph.contains_key(&auth_event_id) {
145 todo.push(fetch(auth_event_id.clone()));
146 }
147
148 let references = graph
149 .get_mut(&event_id)
150 .expect("event_id present in graph");
151
152 if !references.contains(&auth_event_id) {
153 references.push(auth_event_id);
154 }
155 }
156 }
157
158 graph
159}
160
161#[tracing::instrument(
183 name = "sender_power",
184 level = "trace",
185 skip_all,
186 fields(
187 ?event_id,
188 )
189)]
190async fn power_level_for_sender<Fetch, Fut, Pdu>(
191 event_id: &EventId,
192 rules: &RoomVersionRules,
193 fetch: &Fetch,
194) -> Result<UserPowerLevel>
195where
196 Fetch: Fn(OwnedEventId) -> Fut + Sync,
197 Fut: Future<Output = Result<Pdu>> + Send,
198 Pdu: Event,
199{
200 let event = fetch(event_id.into()).await;
201 let hydra_room_id = rules
202 .authorization
203 .room_create_event_id_as_room_id;
204
205 let mut create_event = None;
206 let mut power_levels_event = None;
207 if hydra_room_id && let Ok(event) = event.as_ref() {
208 let create_id = event.room_id().as_event_id()?;
209 let fetched = fetch(create_id).await?;
210
211 _ = create_event.insert(RoomCreateEvent::new(fetched));
212 }
213
214 for auth_event_id in event
215 .as_ref()
216 .map(Event::auth_events)
217 .into_iter()
218 .flatten()
219 {
220 use TimelineEventType::{RoomCreate, RoomPowerLevels};
221
222 let Ok(auth_event) = fetch(auth_event_id.to_owned()).await else {
223 continue;
224 };
225
226 if !hydra_room_id && auth_event.is_type_and_state_key(&RoomCreate, "") {
227 _ = create_event.get_or_insert_with(|| RoomCreateEvent::new(auth_event));
228 } else if auth_event.is_type_and_state_key(&RoomPowerLevels, "") {
229 _ = power_levels_event.get_or_insert_with(|| RoomPowerLevelsEvent::new(auth_event));
230 }
231
232 if power_levels_event.is_some() && create_event.is_some() {
233 break;
234 }
235 }
236
237 let creators = create_event
238 .as_ref()
239 .and_then(|event| event.creators(&rules.authorization).ok());
240
241 if let Some((event, creators)) = event.ok().zip(creators) {
242 power_levels_event.user_power_level(event.sender(), creators, &rules.authorization)
243 } else {
244 power_levels_event
245 .get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, &rules.authorization)
246 .map(Into::into)
247 }
248}
249
250#[tracing::instrument(
254 name = "is_power_event",
255 level = "trace",
256 skip_all,
257 fields(
258 ?event_id,
259 )
260)]
261async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
262where
263 Fetch: Fn(OwnedEventId) -> Fut + Sync,
264 Fut: Future<Output = Result<Pdu>> + Send,
265 Pdu: Event,
266{
267 match fetch(event_id.to_owned()).await {
268 | Ok(state) => is_power_event(&state),
269 | _ => false,
270 }
271}