tuwunel_service/rooms/state_res/resolve/
power_sort.rs1use std::{
2 collections::{HashMap, HashSet},
3 iter::once,
4};
5
6use futures::{StreamExt, TryFutureExt, TryStreamExt, stream::FuturesUnordered};
7use ruma::{
8 EventId, OwnedEventId,
9 events::{TimelineEventType, room::power_levels::UserPowerLevel},
10 room_version_rules::RoomVersionRules,
11};
12use tuwunel_core::{
13 Result, err,
14 matrix::Event,
15 utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
16};
17
18use super::super::{
19 events::{
20 RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
21 power_levels::RoomPowerLevelsEventOptionExt,
22 },
23 topological_sort,
24 topological_sort::ReferencedIds,
25};
26
27#[tracing::instrument(
46 level = "debug",
47 skip_all,
48 fields(
49 conflicted = full_conflicted_set.len(),
50 )
51)]
52pub(super) async fn power_sort<Fetch, Fut, Pdu>(
53 rules: &RoomVersionRules,
54 full_conflicted_set: &HashSet<OwnedEventId>,
55 fetch: &Fetch,
56) -> Result<Vec<OwnedEventId>>
57where
58 Fetch: Fn(OwnedEventId) -> Fut + Sync,
59 Fut: Future<Output = Result<Pdu>> + Send,
60 Pdu: Event,
61{
62 let graph = full_conflicted_set
65 .iter()
66 .stream()
67 .broad_filter_map(async |id| {
68 is_power_event_id(id, fetch)
69 .await
70 .then(|| id.clone())
71 })
72 .enumerate()
73 .fold(HashMap::new(), |graph, (i, event_id)| {
74 add_event_auth_chain(full_conflicted_set, graph, event_id, fetch, i)
75 })
76 .await;
77
78 let event_to_power_level: HashMap<_, _> = graph
81 .keys()
82 .try_stream()
83 .map_ok(AsRef::as_ref)
84 .broad_and_then(|event_id| {
85 power_level_for_sender(event_id, rules, fetch)
86 .map_ok(move |sender_power| (event_id.to_owned(), sender_power))
87 .map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
88 })
89 .try_collect()
90 .await?;
91
92 let query = async |event_id: OwnedEventId| {
93 let power_level = *event_to_power_level
94 .get(&event_id)
95 .ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
96
97 let event = fetch(event_id).await?;
98 Ok((power_level, event.origin_server_ts()))
99 };
100
101 topological_sort(graph, &query).await
102}
103
104#[tracing::instrument(
107 name = "auth_chain",
108 level = "trace",
109 skip_all,
110 fields(
111 graph = graph.len(),
112 ?event_id,
113 %i,
114 )
115)]
116async fn add_event_auth_chain<Fetch, Fut, Pdu>(
117 full_conflicted_set: &HashSet<OwnedEventId>,
118 mut graph: HashMap<OwnedEventId, ReferencedIds>,
119 event_id: OwnedEventId,
120 fetch: &Fetch,
121 i: usize,
122) -> HashMap<OwnedEventId, ReferencedIds>
123where
124 Fetch: Fn(OwnedEventId) -> Fut + Sync,
125 Fut: Future<Output = Result<Pdu>> + Send,
126 Pdu: Event,
127{
128 let mut todo: FuturesUnordered<Fut> = once(fetch(event_id)).collect();
129
130 while let Some(event) = todo.next().await {
131 let Ok(event) = event else {
132 continue;
133 };
134
135 let event_id = event.event_id().to_owned();
136 graph.entry(event_id.clone()).or_default();
137
138 for auth_event_id in event
139 .auth_events_into()
140 .into_iter()
141 .filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
142 {
143 if !graph.contains_key(&auth_event_id) {
144 todo.push(fetch(auth_event_id.clone()));
145 }
146
147 let references = graph
148 .get_mut(&event_id)
149 .expect("event_id present in graph");
150
151 if !references.contains(&auth_event_id) {
152 references.push(auth_event_id);
153 }
154 }
155 }
156
157 graph
158}
159
160#[tracing::instrument(
182 name = "sender_power",
183 level = "trace",
184 skip_all,
185 fields(
186 ?event_id,
187 )
188)]
189async fn power_level_for_sender<Fetch, Fut, Pdu>(
190 event_id: &EventId,
191 rules: &RoomVersionRules,
192 fetch: &Fetch,
193) -> Result<UserPowerLevel>
194where
195 Fetch: Fn(OwnedEventId) -> Fut + Sync,
196 Fut: Future<Output = Result<Pdu>> + Send,
197 Pdu: Event,
198{
199 let event = fetch(event_id.into()).await;
200 let hydra_room_id = rules
201 .authorization
202 .room_create_event_id_as_room_id;
203
204 let mut create_event = None;
205 let mut power_levels_event = None;
206 if hydra_room_id && let Ok(event) = event.as_ref() {
207 let create_id = event.room_id().as_event_id()?;
208 let fetched = fetch(create_id).await?;
209
210 _ = create_event.insert(RoomCreateEvent::new(fetched));
211 }
212
213 for auth_event_id in event
214 .as_ref()
215 .map(Event::auth_events)
216 .into_iter()
217 .flatten()
218 {
219 use TimelineEventType::{RoomCreate, RoomPowerLevels};
220
221 let Ok(auth_event) = fetch(auth_event_id.to_owned()).await else {
222 continue;
223 };
224
225 if !hydra_room_id && auth_event.is_type_and_state_key(&RoomCreate, "") {
226 _ = create_event.get_or_insert_with(|| RoomCreateEvent::new(auth_event));
227 } else if auth_event.is_type_and_state_key(&RoomPowerLevels, "") {
228 _ = power_levels_event.get_or_insert_with(|| RoomPowerLevelsEvent::new(auth_event));
229 }
230
231 if power_levels_event.is_some() && create_event.is_some() {
232 break;
233 }
234 }
235
236 let creators = create_event
237 .as_ref()
238 .and_then(|event| event.creators(&rules.authorization).ok());
239
240 if let Some((event, creators)) = event.ok().zip(creators) {
241 power_levels_event.user_power_level(event.sender(), creators, &rules.authorization)
242 } else {
243 power_levels_event
244 .get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, &rules.authorization)
245 .map(Into::into)
246 }
247}
248
249#[tracing::instrument(
253 name = "is_power_event",
254 level = "trace",
255 skip_all,
256 fields(
257 ?event_id,
258 )
259)]
260async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
261where
262 Fetch: Fn(OwnedEventId) -> Fut + Sync,
263 Fut: Future<Output = Result<Pdu>> + Send,
264 Pdu: Event,
265{
266 match fetch(event_id.to_owned()).await {
267 | Ok(state) => is_power_event(&state),
268 | _ => false,
269 }
270}