tuwunel_api/client/sync/v5/extensions/
e2ee.rs1use std::collections::HashSet;
2
3use futures::{
4 FutureExt, StreamExt, TryFutureExt,
5 future::{join, join3},
6 stream::once,
7};
8use ruma::{
9 OwnedUserId, RoomId,
10 api::client::sync::sync_events::{DeviceLists, v5::response},
11 events::{
12 StateEventType, TimelineEventType,
13 room::member::{MembershipState, RoomMemberEventContent},
14 },
15};
16use tuwunel_core::{
17 Result, error,
18 matrix::{Event, pdu::PduCount},
19 pair_of,
20 utils::{
21 BoolExt, IterStream, ReadyExt, TryFutureExtExt, future::OptionStream,
22 stream::BroadbandExt,
23 },
24};
25use tuwunel_service::sync::Connection;
26
27use super::{SyncInfo, share_encrypted_room};
28
29#[tracing::instrument(name = "e2ee", level = "trace", skip_all)]
30pub(super) async fn collect(
31 sync_info: SyncInfo<'_>,
32 conn: &Connection,
33) -> Result<response::E2EE> {
34 let SyncInfo { services, sender_user, sender_device, .. } = sync_info;
35 let Some(sender_device) = sender_device else {
36 return Ok(response::E2EE::default());
37 };
38
39 let keys_changed = services
40 .users
41 .keys_changed(sender_user, conn.globalsince, Some(conn.next_batch))
42 .map(ToOwned::to_owned)
43 .collect::<HashSet<_>>()
44 .map(|changed| (changed, HashSet::new()));
45
46 let (changed, left) = (HashSet::new(), HashSet::new());
47 let (changed, left) = services
48 .state_cache
49 .rooms_joined(sender_user)
50 .map(ToOwned::to_owned)
51 .broad_filter_map(async |room_id| collect_room(sync_info, conn, &room_id).await.ok())
52 .chain(once(keys_changed))
53 .ready_fold((changed, left), |(mut changed, mut left), room| {
54 changed.extend(room.0);
55 left.extend(room.1);
56 (changed, left)
57 })
58 .await;
59
60 let left = left
61 .into_iter()
62 .stream()
63 .filter_map(async |user_id| {
64 share_encrypted_room(services, sender_user, &user_id, None)
65 .await
66 .is_false()
67 .then_some(user_id)
68 })
69 .collect();
70
71 let device_one_time_keys_count = services
72 .users
73 .last_one_time_keys_update(sender_user)
74 .then(|since| {
75 since.gt(&conn.globalsince).then_async(|| {
76 services
77 .users
78 .count_one_time_keys(sender_user, sender_device)
79 })
80 })
81 .map(Option::unwrap_or_default);
82
83 let device_unused_fallback_key_types = services
84 .users
85 .unused_fallback_key_algorithms(sender_user, sender_device)
86 .collect::<Vec<_>>()
87 .map(Some);
88
89 let (left, device_one_time_keys_count, device_unused_fallback_key_types) =
90 join3(left, device_one_time_keys_count, device_unused_fallback_key_types)
91 .boxed()
92 .await;
93
94 Ok(response::E2EE {
95 device_one_time_keys_count,
96 device_unused_fallback_key_types,
97 device_lists: DeviceLists {
98 changed: changed.into_iter().collect(),
99 left,
100 },
101 })
102}
103
104#[tracing::instrument(level = "trace", skip_all, fields(room_id), ret)]
105async fn collect_room(
106 SyncInfo { services, sender_user, .. }: SyncInfo<'_>,
107 conn: &Connection,
108 room_id: &RoomId,
109) -> Result<pair_of!(HashSet<OwnedUserId>)> {
110 let current_shortstatehash = services
111 .state
112 .get_room_shortstatehash(room_id)
113 .inspect_err(|e| error!("Room {room_id} has no state: {e}"));
114
115 let room_keys_changed = services
116 .users
117 .room_keys_changed(room_id, conn.globalsince, Some(conn.next_batch))
118 .map(|(user_id, _)| user_id)
119 .map(ToOwned::to_owned)
120 .collect::<HashSet<_>>();
121
122 let (current_shortstatehash, device_list_changed) =
123 join(current_shortstatehash, room_keys_changed)
124 .boxed()
125 .await;
126
127 let lists = (device_list_changed, HashSet::new());
128 let Ok(current_shortstatehash) = current_shortstatehash else {
129 return Ok(lists);
130 };
131
132 if current_shortstatehash <= conn.globalsince {
133 return Ok(lists);
134 }
135
136 let Ok(since_shortstatehash) = services
137 .timeline
138 .prev_shortstatehash(room_id, PduCount::Normal(conn.globalsince).saturating_add(1))
139 .await
140 else {
141 return Ok(lists);
142 };
143
144 if since_shortstatehash == current_shortstatehash {
145 return Ok(lists);
146 }
147
148 let encrypted_room = services
149 .state_accessor
150 .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")
151 .is_ok();
152
153 let since_encryption = services
154 .state_accessor
155 .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")
156 .is_ok();
157
158 let sender_joined_count = services
159 .state_cache
160 .get_joined_count(room_id, sender_user);
161
162 let (encrypted_room, since_encryption, sender_joined_count) =
163 join3(encrypted_room, since_encryption, sender_joined_count).await;
164
165 if !encrypted_room
166 && services
167 .config
168 .device_key_update_encrypted_rooms_only
169 {
170 return Ok(lists);
171 }
172
173 let encrypted_since_last_sync = !since_encryption;
174 let joined_since_last_sync = sender_joined_count.is_ok_and(|count| count > conn.globalsince);
175 let joined_members_burst =
176 (joined_since_last_sync || encrypted_since_last_sync).then_async(|| {
177 services
178 .state_cache
179 .room_members(room_id)
180 .ready_filter(|&user_id| user_id != sender_user)
181 .map(ToOwned::to_owned)
182 .map(|user_id| (MembershipState::Join, user_id))
183 .boxed()
184 .into_future()
185 });
186
187 services
188 .state_accessor
189 .state_added((since_shortstatehash, current_shortstatehash))
190 .broad_filter_map(async |(_shortstatekey, shorteventid)| {
191 services
192 .timeline
193 .get_pdu_from_shorteventid(shorteventid)
194 .ok()
195 .await
196 })
197 .ready_filter(|event| *event.kind() == TimelineEventType::RoomMember)
198 .ready_filter(|event| {
199 event
200 .state_key()
201 .is_some_and(|state_key| state_key != sender_user)
202 })
203 .ready_filter_map(|event| {
204 let content: RoomMemberEventContent = event.get_content().ok()?;
205 let user_id: OwnedUserId = event.state_key()?.parse().ok()?;
206
207 Some((content.membership, user_id))
208 })
209 .chain(joined_members_burst.stream())
210 .fold(lists, async |(mut changed, mut left), (membership, user_id)| {
211 use MembershipState::*;
212
213 let should_add = async |user_id| {
214 !share_encrypted_room(services, sender_user, user_id, Some(room_id)).await
215 };
216
217 match membership {
218 | Join if should_add(&user_id).await => changed.insert(user_id),
219 | Leave => left.insert(user_id),
220 | _ => false,
221 };
222
223 (changed, left)
224 })
225 .map(Ok)
226 .boxed()
227 .await
228}