Skip to main content

tuwunel_service/rooms/state_res/resolve/
power_sort.rs

1use std::{
2	collections::{HashMap, HashSet},
3	iter::once,
4};
5
6use futures::{StreamExt, TryFutureExt, TryStreamExt, stream::FuturesUnordered};
7use ruma::{
8	EventId, OwnedEventId,
9	events::{TimelineEventType, room::power_levels::UserPowerLevel},
10	room_version_rules::RoomVersionRules,
11};
12use tuwunel_core::{
13	Result, err,
14	matrix::Event,
15	utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
16};
17
18use super::super::{
19	events::{
20		RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
21		power_levels::RoomPowerLevelsEventOptionExt,
22	},
23	topological_sort,
24	topological_sort::ReferencedIds,
25};
26
27/// Enlarge the given list of conflicted power events by adding the events in
28/// their auth chain that are in the full conflicted set, and sort it using
29/// reverse topological power ordering.
30///
31/// ## Arguments
32///
33/// * `conflicted_power_events` - The list of power events in the full
34///   conflicted set.
35///
36/// * `full_conflicted_set` - The full conflicted set.
37///
38/// * `rules` - The authorization rules for the current room version.
39///
40/// * `fetch` - Function to fetch an event in the room given its event ID.
41///
42/// ## Returns
43///
44/// Returns the ordered list of event IDs from earliest to latest.
45#[tracing::instrument(
46	level = "debug",
47	skip_all,
48	fields(
49		conflicted = full_conflicted_set.len(),
50	)
51)]
52pub(super) async fn power_sort<Fetch, Fut, Pdu>(
53	rules: &RoomVersionRules,
54	full_conflicted_set: &HashSet<OwnedEventId>,
55	fetch: &Fetch,
56) -> Result<Vec<OwnedEventId>>
57where
58	Fetch: Fn(OwnedEventId) -> Fut + Sync,
59	Fut: Future<Output = Result<Pdu>> + Send,
60	Pdu: Event,
61{
62	// A representation of the DAG, a map of event ID to its list of auth events
63	// that are in the full conflicted set. Fill the graph.
64	let graph = full_conflicted_set
65		.iter()
66		.stream()
67		.broad_filter_map(async |id| {
68			is_power_event_id(id, fetch)
69				.await
70				.then(|| id.clone())
71		})
72		.enumerate()
73		.fold(HashMap::new(), |graph, (i, event_id)| {
74			add_event_auth_chain(full_conflicted_set, graph, event_id, fetch, i)
75		})
76		.await;
77
78	// The map of event ID to the power level of the sender of the event.
79	// Get the power level of the sender of each event in the graph.
80	let event_to_power_level: HashMap<_, _> = graph
81		.keys()
82		.try_stream()
83		.map_ok(AsRef::as_ref)
84		.broad_and_then(|event_id| {
85			power_level_for_sender(event_id, rules, fetch)
86				.map_ok(move |sender_power| (event_id.to_owned(), sender_power))
87				.map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
88		})
89		.try_collect()
90		.await?;
91
92	let query = async |event_id: OwnedEventId| {
93		let power_level = *event_to_power_level
94			.get(&event_id)
95			.ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
96
97		let event = fetch(event_id).await?;
98		Ok((power_level, event.origin_server_ts()))
99	};
100
101	topological_sort(graph, &query).await
102}
103
104/// Add the event with the given event ID and all the events in its auth chain
105/// that are in the full conflicted set to the graph.
106#[tracing::instrument(
107	name = "auth_chain",
108	level = "trace",
109	skip_all,
110	fields(
111		graph = graph.len(),
112		?event_id,
113		%i,
114	)
115)]
116async fn add_event_auth_chain<Fetch, Fut, Pdu>(
117	full_conflicted_set: &HashSet<OwnedEventId>,
118	mut graph: HashMap<OwnedEventId, ReferencedIds>,
119	event_id: OwnedEventId,
120	fetch: &Fetch,
121	i: usize,
122) -> HashMap<OwnedEventId, ReferencedIds>
123where
124	Fetch: Fn(OwnedEventId) -> Fut + Sync,
125	Fut: Future<Output = Result<Pdu>> + Send,
126	Pdu: Event,
127{
128	let mut todo: FuturesUnordered<Fut> = once(fetch(event_id)).collect();
129
130	while let Some(event) = todo.next().await {
131		let Ok(event) = event else {
132			continue;
133		};
134
135		let event_id = event.event_id().to_owned();
136		graph.entry(event_id.clone()).or_default();
137
138		for auth_event_id in event
139			.auth_events_into()
140			.into_iter()
141			.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
142		{
143			if !graph.contains_key(&auth_event_id) {
144				todo.push(fetch(auth_event_id.clone()));
145			}
146
147			let references = graph
148				.get_mut(&event_id)
149				.expect("event_id present in graph");
150
151			if !references.contains(&auth_event_id) {
152				references.push(auth_event_id);
153			}
154		}
155	}
156
157	graph
158}
159
160/// Find the power level for the sender of the event of the given event ID or
161/// return a default value of zero.
162///
163/// We find the most recent `m.room.power_levels` by walking backwards in the
164/// auth chain of the event.
165///
166/// Do NOT use this anywhere but topological sort.
167///
168/// ## Arguments
169///
170/// * `event_id` - The event ID of the event to get the power level of the
171///   sender of.
172///
173/// * `rules` - The authorization rules for the current room version.
174///
175/// * `fetch` - Function to fetch an event in the room given its event ID.
176///
177/// ## Returns
178///
179/// Returns the power level of the sender of the event or an `Err(_)` if one of
180/// the auth events if malformed.
181#[tracing::instrument(
182	name = "sender_power",
183	level = "trace",
184	skip_all,
185	fields(
186		?event_id,
187	)
188)]
189async fn power_level_for_sender<Fetch, Fut, Pdu>(
190	event_id: &EventId,
191	rules: &RoomVersionRules,
192	fetch: &Fetch,
193) -> Result<UserPowerLevel>
194where
195	Fetch: Fn(OwnedEventId) -> Fut + Sync,
196	Fut: Future<Output = Result<Pdu>> + Send,
197	Pdu: Event,
198{
199	let event = fetch(event_id.into()).await;
200	let hydra_room_id = rules
201		.authorization
202		.room_create_event_id_as_room_id;
203
204	let mut create_event = None;
205	let mut power_levels_event = None;
206	if hydra_room_id && let Ok(event) = event.as_ref() {
207		let create_id = event.room_id().as_event_id()?;
208		let fetched = fetch(create_id).await?;
209
210		_ = create_event.insert(RoomCreateEvent::new(fetched));
211	}
212
213	for auth_event_id in event
214		.as_ref()
215		.map(Event::auth_events)
216		.into_iter()
217		.flatten()
218	{
219		use TimelineEventType::{RoomCreate, RoomPowerLevels};
220
221		let Ok(auth_event) = fetch(auth_event_id.to_owned()).await else {
222			continue;
223		};
224
225		if !hydra_room_id && auth_event.is_type_and_state_key(&RoomCreate, "") {
226			_ = create_event.get_or_insert_with(|| RoomCreateEvent::new(auth_event));
227		} else if auth_event.is_type_and_state_key(&RoomPowerLevels, "") {
228			_ = power_levels_event.get_or_insert_with(|| RoomPowerLevelsEvent::new(auth_event));
229		}
230
231		if power_levels_event.is_some() && create_event.is_some() {
232			break;
233		}
234	}
235
236	let creators = create_event
237		.as_ref()
238		.and_then(|event| event.creators(&rules.authorization).ok());
239
240	if let Some((event, creators)) = event.ok().zip(creators) {
241		power_levels_event.user_power_level(event.sender(), creators, &rules.authorization)
242	} else {
243		power_levels_event
244			.get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, &rules.authorization)
245			.map(Into::into)
246	}
247}
248
249/// Whether the given event ID belongs to a power event.
250///
251/// See the docs of `is_power_event()` for the definition of a power event.
252#[tracing::instrument(
253	name = "is_power_event",
254	level = "trace",
255	skip_all,
256	fields(
257		?event_id,
258	)
259)]
260async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
261where
262	Fetch: Fn(OwnedEventId) -> Fut + Sync,
263	Fut: Future<Output = Result<Pdu>> + Send,
264	Pdu: Event,
265{
266	match fetch(event_id.to_owned()).await {
267		| Ok(state) => is_power_event(&state),
268		| _ => false,
269	}
270}