Skip to main content

tuwunel_service/rooms/state_res/resolve/
power_sort.rs

1use std::{
2	borrow::Borrow,
3	collections::{HashMap, HashSet},
4	iter::once,
5};
6
7use futures::{StreamExt, TryFutureExt, TryStreamExt, stream::FuturesUnordered};
8use ruma::{
9	EventId, OwnedEventId,
10	events::{TimelineEventType, room::power_levels::UserPowerLevel},
11	room_version_rules::RoomVersionRules,
12};
13use tuwunel_core::{
14	Result, err,
15	matrix::Event,
16	utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
17};
18
19use super::super::{
20	events::{
21		RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
22		power_levels::RoomPowerLevelsEventOptionExt,
23	},
24	topological_sort,
25	topological_sort::ReferencedIds,
26};
27
28/// Enlarge the given list of conflicted power events by adding the events in
29/// their auth chain that are in the full conflicted set, and sort it using
30/// reverse topological power ordering.
31///
32/// ## Arguments
33///
34/// * `conflicted_power_events` - The list of power events in the full
35///   conflicted set.
36///
37/// * `full_conflicted_set` - The full conflicted set.
38///
39/// * `rules` - The authorization rules for the current room version.
40///
41/// * `fetch` - Function to fetch an event in the room given its event ID.
42///
43/// ## Returns
44///
45/// Returns the ordered list of event IDs from earliest to latest.
46#[tracing::instrument(
47	level = "debug",
48	skip_all,
49	fields(
50		conflicted = full_conflicted_set.len(),
51	)
52)]
53pub(super) async fn power_sort<Fetch, Fut, Pdu>(
54	rules: &RoomVersionRules,
55	full_conflicted_set: &HashSet<OwnedEventId>,
56	fetch: &Fetch,
57) -> Result<Vec<OwnedEventId>>
58where
59	Fetch: Fn(OwnedEventId) -> Fut + Sync,
60	Fut: Future<Output = Result<Pdu>> + Send,
61	Pdu: Event,
62{
63	// A representation of the DAG, a map of event ID to its list of auth events
64	// that are in the full conflicted set. Fill the graph.
65	let graph = full_conflicted_set
66		.iter()
67		.stream()
68		.broad_filter_map(async |id| {
69			is_power_event_id(id, fetch)
70				.await
71				.then(|| id.clone())
72		})
73		.enumerate()
74		.fold(HashMap::new(), |graph, (i, event_id)| {
75			add_event_auth_chain(full_conflicted_set, graph, event_id, fetch, i)
76		})
77		.await;
78
79	// The map of event ID to the power level of the sender of the event.
80	// Get the power level of the sender of each event in the graph.
81	let event_to_power_level: HashMap<_, _> = graph
82		.keys()
83		.try_stream()
84		.map_ok(AsRef::as_ref)
85		.broad_and_then(|event_id| {
86			power_level_for_sender(event_id, rules, fetch)
87				.map_ok(move |sender_power| (event_id, sender_power))
88				.map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
89		})
90		.try_collect()
91		.await?;
92
93	let query = async |event_id: OwnedEventId| {
94		let power_level = *event_to_power_level
95			.get(&event_id.borrow())
96			.ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
97
98		let event = fetch(event_id).await?;
99		Ok((power_level, event.origin_server_ts()))
100	};
101
102	topological_sort(&graph, &query).await
103}
104
105/// Add the event with the given event ID and all the events in its auth chain
106/// that are in the full conflicted set to the graph.
107#[tracing::instrument(
108	name = "auth_chain",
109	level = "trace",
110	skip_all,
111	fields(
112		graph = graph.len(),
113		?event_id,
114		%i,
115	)
116)]
117async fn add_event_auth_chain<Fetch, Fut, Pdu>(
118	full_conflicted_set: &HashSet<OwnedEventId>,
119	mut graph: HashMap<OwnedEventId, ReferencedIds>,
120	event_id: OwnedEventId,
121	fetch: &Fetch,
122	i: usize,
123) -> HashMap<OwnedEventId, ReferencedIds>
124where
125	Fetch: Fn(OwnedEventId) -> Fut + Sync,
126	Fut: Future<Output = Result<Pdu>> + Send,
127	Pdu: Event,
128{
129	let mut todo: FuturesUnordered<Fut> = once(fetch(event_id)).collect();
130
131	while let Some(event) = todo.next().await {
132		let Ok(event) = event else {
133			continue;
134		};
135
136		let event_id = event.event_id().to_owned();
137		graph.entry(event_id.clone()).or_default();
138
139		for auth_event_id in event
140			.auth_events_into()
141			.into_iter()
142			.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
143		{
144			if !graph.contains_key(&auth_event_id) {
145				todo.push(fetch(auth_event_id.clone()));
146			}
147
148			let references = graph
149				.get_mut(&event_id)
150				.expect("event_id present in graph");
151
152			if !references.contains(&auth_event_id) {
153				references.push(auth_event_id);
154			}
155		}
156	}
157
158	graph
159}
160
161/// Find the power level for the sender of the event of the given event ID or
162/// return a default value of zero.
163///
164/// We find the most recent `m.room.power_levels` by walking backwards in the
165/// auth chain of the event.
166///
167/// Do NOT use this anywhere but topological sort.
168///
169/// ## Arguments
170///
171/// * `event_id` - The event ID of the event to get the power level of the
172///   sender of.
173///
174/// * `rules` - The authorization rules for the current room version.
175///
176/// * `fetch` - Function to fetch an event in the room given its event ID.
177///
178/// ## Returns
179///
180/// Returns the power level of the sender of the event or an `Err(_)` if one of
181/// the auth events if malformed.
182#[tracing::instrument(
183	name = "sender_power",
184	level = "trace",
185	skip_all,
186	fields(
187		?event_id,
188	)
189)]
190async fn power_level_for_sender<Fetch, Fut, Pdu>(
191	event_id: &EventId,
192	rules: &RoomVersionRules,
193	fetch: &Fetch,
194) -> Result<UserPowerLevel>
195where
196	Fetch: Fn(OwnedEventId) -> Fut + Sync,
197	Fut: Future<Output = Result<Pdu>> + Send,
198	Pdu: Event,
199{
200	let event = fetch(event_id.into()).await;
201	let hydra_room_id = rules
202		.authorization
203		.room_create_event_id_as_room_id;
204
205	let mut create_event = None;
206	let mut power_levels_event = None;
207	if hydra_room_id && let Ok(event) = event.as_ref() {
208		let create_id = event.room_id().as_event_id()?;
209		let fetched = fetch(create_id).await?;
210
211		_ = create_event.insert(RoomCreateEvent::new(fetched));
212	}
213
214	for auth_event_id in event
215		.as_ref()
216		.map(Event::auth_events)
217		.into_iter()
218		.flatten()
219	{
220		use TimelineEventType::{RoomCreate, RoomPowerLevels};
221
222		let Ok(auth_event) = fetch(auth_event_id.to_owned()).await else {
223			continue;
224		};
225
226		if !hydra_room_id && auth_event.is_type_and_state_key(&RoomCreate, "") {
227			_ = create_event.get_or_insert_with(|| RoomCreateEvent::new(auth_event));
228		} else if auth_event.is_type_and_state_key(&RoomPowerLevels, "") {
229			_ = power_levels_event.get_or_insert_with(|| RoomPowerLevelsEvent::new(auth_event));
230		}
231
232		if power_levels_event.is_some() && create_event.is_some() {
233			break;
234		}
235	}
236
237	let creators = create_event
238		.as_ref()
239		.and_then(|event| event.creators(&rules.authorization).ok());
240
241	if let Some((event, creators)) = event.ok().zip(creators) {
242		power_levels_event.user_power_level(event.sender(), creators, &rules.authorization)
243	} else {
244		power_levels_event
245			.get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, &rules.authorization)
246			.map(Into::into)
247	}
248}
249
250/// Whether the given event ID belongs to a power event.
251///
252/// See the docs of `is_power_event()` for the definition of a power event.
253#[tracing::instrument(
254	name = "is_power_event",
255	level = "trace",
256	skip_all,
257	fields(
258		?event_id,
259	)
260)]
261async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
262where
263	Fetch: Fn(OwnedEventId) -> Fut + Sync,
264	Fut: Future<Output = Result<Pdu>> + Send,
265	Pdu: Event,
266{
267	match fetch(event_id.to_owned()).await {
268		| Ok(state) => is_power_event(&state),
269		| _ => false,
270	}
271}