Skip to main content

tuwunel_api/client/sync/v5/extensions/
e2ee.rs

1use std::collections::HashSet;
2
3use futures::{
4	FutureExt, StreamExt, TryFutureExt,
5	future::{join, join3},
6	stream::once,
7};
8use ruma::{
9	OwnedUserId, RoomId,
10	api::client::sync::sync_events::{DeviceLists, v5::response},
11	events::{
12		StateEventType, TimelineEventType,
13		room::member::{MembershipState, RoomMemberEventContent},
14	},
15};
16use tuwunel_core::{
17	Result, error,
18	matrix::{Event, pdu::PduCount},
19	pair_of,
20	utils::{
21		BoolExt, IterStream, ReadyExt, TryFutureExtExt, future::OptionStream,
22		stream::BroadbandExt,
23	},
24};
25use tuwunel_service::sync::Connection;
26
27use super::{SyncInfo, share_encrypted_room};
28
29#[tracing::instrument(name = "e2ee", level = "trace", skip_all)]
30pub(super) async fn collect(
31	sync_info: SyncInfo<'_>,
32	conn: &Connection,
33) -> Result<response::E2EE> {
34	let SyncInfo { services, sender_user, sender_device, .. } = sync_info;
35	let Some(sender_device) = sender_device else {
36		return Ok(response::E2EE::default());
37	};
38
39	let keys_changed = services
40		.users
41		.keys_changed(sender_user, conn.globalsince, Some(conn.next_batch))
42		.map(ToOwned::to_owned)
43		.collect::<HashSet<_>>()
44		.map(|changed| (changed, HashSet::new()));
45
46	let (changed, left) = (HashSet::new(), HashSet::new());
47	let (changed, left) = services
48		.state_cache
49		.rooms_joined(sender_user)
50		.map(ToOwned::to_owned)
51		.broad_filter_map(async |room_id| collect_room(sync_info, conn, &room_id).await.ok())
52		.chain(once(keys_changed))
53		.ready_fold((changed, left), |(mut changed, mut left), room| {
54			changed.extend(room.0);
55			left.extend(room.1);
56			(changed, left)
57		})
58		.await;
59
60	let left = left
61		.into_iter()
62		.stream()
63		.filter_map(async |user_id| {
64			share_encrypted_room(services, sender_user, &user_id, None)
65				.await
66				.is_false()
67				.then_some(user_id)
68		})
69		.collect();
70
71	let device_one_time_keys_count = services
72		.users
73		.last_one_time_keys_update(sender_user)
74		.then(|since| {
75			since.gt(&conn.globalsince).then_async(|| {
76				services
77					.users
78					.count_one_time_keys(sender_user, sender_device)
79			})
80		})
81		.map(Option::unwrap_or_default);
82
83	let device_unused_fallback_key_types = services
84		.users
85		.unused_fallback_key_algorithms(sender_user, sender_device)
86		.collect::<Vec<_>>()
87		.map(Some);
88
89	let (left, device_one_time_keys_count, device_unused_fallback_key_types) =
90		join3(left, device_one_time_keys_count, device_unused_fallback_key_types)
91			.boxed()
92			.await;
93
94	Ok(response::E2EE {
95		device_one_time_keys_count,
96		device_unused_fallback_key_types,
97		device_lists: DeviceLists {
98			changed: changed.into_iter().collect(),
99			left,
100		},
101	})
102}
103
104#[tracing::instrument(level = "trace", skip_all, fields(room_id), ret)]
105async fn collect_room(
106	SyncInfo { services, sender_user, .. }: SyncInfo<'_>,
107	conn: &Connection,
108	room_id: &RoomId,
109) -> Result<pair_of!(HashSet<OwnedUserId>)> {
110	let current_shortstatehash = services
111		.state
112		.get_room_shortstatehash(room_id)
113		.inspect_err(|e| error!("Room {room_id} has no state: {e}"));
114
115	let room_keys_changed = services
116		.users
117		.room_keys_changed(room_id, conn.globalsince, Some(conn.next_batch))
118		.map(|(user_id, _)| user_id)
119		.map(ToOwned::to_owned)
120		.collect::<HashSet<_>>();
121
122	let (current_shortstatehash, device_list_changed) =
123		join(current_shortstatehash, room_keys_changed)
124			.boxed()
125			.await;
126
127	let lists = (device_list_changed, HashSet::new());
128	let Ok(current_shortstatehash) = current_shortstatehash else {
129		return Ok(lists);
130	};
131
132	if current_shortstatehash <= conn.globalsince {
133		return Ok(lists);
134	}
135
136	let Ok(since_shortstatehash) = services
137		.timeline
138		.prev_shortstatehash(room_id, PduCount::Normal(conn.globalsince).saturating_add(1))
139		.await
140	else {
141		return Ok(lists);
142	};
143
144	if since_shortstatehash == current_shortstatehash {
145		return Ok(lists);
146	}
147
148	let encrypted_room = services
149		.state_accessor
150		.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")
151		.is_ok();
152
153	let since_encryption = services
154		.state_accessor
155		.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")
156		.is_ok();
157
158	let sender_joined_count = services
159		.state_cache
160		.get_joined_count(room_id, sender_user);
161
162	let (encrypted_room, since_encryption, sender_joined_count) =
163		join3(encrypted_room, since_encryption, sender_joined_count).await;
164
165	if !encrypted_room
166		&& services
167			.config
168			.device_key_update_encrypted_rooms_only
169	{
170		return Ok(lists);
171	}
172
173	let encrypted_since_last_sync = !since_encryption;
174	let joined_since_last_sync = sender_joined_count.is_ok_and(|count| count > conn.globalsince);
175	let joined_members_burst =
176		(joined_since_last_sync || encrypted_since_last_sync).then_async(|| {
177			services
178				.state_cache
179				.room_members(room_id)
180				.ready_filter(|&user_id| user_id != sender_user)
181				.map(ToOwned::to_owned)
182				.map(|user_id| (MembershipState::Join, user_id))
183				.boxed()
184				.into_future()
185		});
186
187	services
188		.state_accessor
189		.state_added((since_shortstatehash, current_shortstatehash))
190		.broad_filter_map(async |(_shortstatekey, shorteventid)| {
191			services
192				.timeline
193				.get_pdu_from_shorteventid(shorteventid)
194				.ok()
195				.await
196		})
197		.ready_filter(|event| *event.kind() == TimelineEventType::RoomMember)
198		.ready_filter(|event| {
199			event
200				.state_key()
201				.is_some_and(|state_key| state_key != sender_user)
202		})
203		.ready_filter_map(|event| {
204			let content: RoomMemberEventContent = event.get_content().ok()?;
205			let user_id: OwnedUserId = event.state_key()?.parse().ok()?;
206
207			Some((content.membership, user_id))
208		})
209		.chain(joined_members_burst.stream())
210		.fold(lists, async |(mut changed, mut left), (membership, user_id)| {
211			use MembershipState::*;
212
213			let should_add = async |user_id| {
214				!share_encrypted_room(services, sender_user, user_id, Some(room_id)).await
215			};
216
217			match membership {
218				| Join if should_add(&user_id).await => changed.insert(user_id),
219				| Leave => left.insert(user_id),
220				| _ => false,
221			};
222
223			(changed, left)
224		})
225		.map(Ok)
226		.boxed()
227		.await
228}