tuwunel_service/key_backups/
mod.rs1use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
2
3use futures::StreamExt;
4use ruma::{
5 OwnedRoomId, RoomId, UInt, UserId,
6 api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
7 serde::Raw,
8};
9use tuwunel_core::{
10 Err, Result, err, implement,
11 utils::stream::{ReadyExt, TryIgnore},
12};
13use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
14
15pub struct Service {
16 db: Data,
17 services: Arc<crate::services::OnceServices>,
18}
19
20struct Data {
21 backupid_algorithm: Arc<Map>,
22 backupid_etag: Arc<Map>,
23 backupkeyid_backup: Arc<Map>,
24}
25
26impl crate::Service for Service {
27 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
28 Ok(Arc::new(Self {
29 db: Data {
30 backupid_algorithm: args.db["backupid_algorithm"].clone(),
31 backupid_etag: args.db["backupid_etag"].clone(),
32 backupkeyid_backup: args.db["backupkeyid_backup"].clone(),
33 },
34 services: args.services.clone(),
35 }))
36 }
37
38 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
39}
40
41#[implement(Service)]
42pub fn create_backup(
43 &self,
44 user_id: &UserId,
45 backup_metadata: &Raw<BackupAlgorithm>,
46) -> Result<String> {
47 let version = self.services.globals.next_count();
48 let count = self.services.globals.next_count();
49
50 let version_string = version.to_string();
51 let key = (user_id, &version_string);
52 self.db
53 .backupid_algorithm
54 .put(key, Json(backup_metadata));
55
56 self.db.backupid_etag.put(key, *count);
57
58 Ok(version_string)
59}
60
61#[implement(Service)]
62pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
63 let key = (user_id, version);
64 self.db.backupid_algorithm.del(key);
65 self.db.backupid_etag.del(key);
66
67 let key = (user_id, version, Interfix);
68 self.db
69 .backupkeyid_backup
70 .keys_prefix_raw(&key)
71 .ignore_err()
72 .ready_for_each(|outdated_key| {
73 self.db.backupkeyid_backup.remove(outdated_key);
74 })
75 .await;
76}
77
78#[implement(Service)]
79pub async fn update_backup<'a>(
80 &self,
81 user_id: &UserId,
82 version: &'a str,
83 backup_metadata: &Raw<BackupAlgorithm>,
84) -> Result<&'a str> {
85 let key = (user_id, version);
86 if self
87 .db
88 .backupid_algorithm
89 .qry(&key)
90 .await
91 .is_err()
92 {
93 return Err!(Request(NotFound("Tried to update nonexistent backup.")));
94 }
95
96 let count = self.services.globals.next_count();
97 self.db.backupid_etag.put(key, *count);
98 self.db
99 .backupid_algorithm
100 .put_raw(key, backup_metadata.json().get());
101
102 Ok(version)
103}
104
105#[implement(Service)]
106pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> {
107 type Key<'a> = (&'a UserId, &'a str);
108
109 let key = (user_id, Interfix);
110 let mut versions: Vec<_> = self
111 .db
112 .backupid_algorithm
113 .keys_from(&key)
114 .ignore_err()
115 .ready_take_while(|(user_id_, _): &Key<'_>| *user_id_ == user_id)
116 .ready_filter_map(|(_, version): Key<'_>| version.parse::<u64>().ok())
117 .collect()
118 .await;
119
120 versions.sort_unstable();
121 let Some(latest) = versions.last() else {
122 return Err!(Request(NotFound("No backup versions found")));
123 };
124
125 Ok(latest.to_string())
126}
127
128#[implement(Service)]
129pub async fn get_latest_backup(
130 &self,
131 user_id: &UserId,
132) -> Result<(String, Raw<BackupAlgorithm>)> {
133 let version = self.get_latest_backup_version(user_id).await?;
134
135 let key = (user_id, version.as_str());
136 self.db
137 .backupid_algorithm
138 .qry(&key)
139 .await
140 .deserialized()
141 .map(|algorithm| (version, algorithm))
142 .map_err(|e| err!(Request(NotFound("No backup found: {e}"))))
143}
144
145#[implement(Service)]
146pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> {
147 let key = (user_id, version);
148 self.db
149 .backupid_algorithm
150 .qry(&key)
151 .await
152 .deserialized()
153}
154
155#[implement(Service)]
156pub async fn add_key(
157 &self,
158 user_id: &UserId,
159 version: &str,
160 room_id: &RoomId,
161 session_id: &str,
162 key_data: &Raw<KeyBackupData>,
163) -> Result {
164 let key = (user_id, version);
165 if self
166 .db
167 .backupid_algorithm
168 .qry(&key)
169 .await
170 .is_err()
171 {
172 return Err!(Request(NotFound("Tried to update nonexistent backup.")));
173 }
174
175 let replace = match self
177 .get_session(user_id, version, room_id, session_id)
178 .await
179 {
180 | Ok(old_key) => is_better_key(&old_key, key_data)?,
181 | Err(_) => true,
182 };
183
184 if !replace {
185 return Ok(());
186 }
187
188 let count = self.services.globals.next_count();
189 self.db.backupid_etag.put(key, *count);
190
191 let key = (user_id, version, room_id, session_id);
192 self.db
193 .backupkeyid_backup
194 .put_raw(key, key_data.json().get());
195
196 Ok(())
197}
198
199fn is_better_key(old: &Raw<KeyBackupData>, new: &Raw<KeyBackupData>) -> Result<bool> {
202 let old_verified = old
203 .get_field::<bool>("is_verified")?
204 .unwrap_or_default();
205
206 let new_verified = new
207 .get_field::<bool>("is_verified")?
208 .ok_or_else(|| err!(Request(BadJson("`is_verified` field should exist"))))?;
209
210 if old_verified != new_verified {
211 return Ok(new_verified);
212 }
213
214 let old_first_message_index = old
215 .get_field::<UInt>("first_message_index")?
216 .unwrap_or(UInt::MAX);
217
218 let new_first_message_index = new
219 .get_field::<UInt>("first_message_index")?
220 .ok_or_else(|| err!(Request(BadJson("`first_message_index` field should exist"))))?;
221
222 match new_first_message_index.cmp(&old_first_message_index) {
223 | Ordering::Less => Ok(true),
224 | Ordering::Greater => Ok(false),
225 | Ordering::Equal => {
226 let old_forwarded_count = old
227 .get_field::<UInt>("forwarded_count")?
228 .unwrap_or(UInt::MAX);
229
230 let new_forwarded_count = new
231 .get_field::<UInt>("forwarded_count")?
232 .ok_or_else(|| err!(Request(BadJson("`forwarded_count` field should exist"))))?;
233
234 Ok(new_forwarded_count < old_forwarded_count)
235 },
236 }
237}
238
239#[implement(Service)]
240pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize {
241 let prefix = (user_id, version);
242 self.db
243 .backupkeyid_backup
244 .keys_prefix_raw(&prefix)
245 .count()
246 .await
247}
248
249#[implement(Service)]
250pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
251 let key = (user_id, version);
252 self.db
253 .backupid_etag
254 .qry(&key)
255 .await
256 .deserialized::<u64>()
257 .as_ref()
258 .map(ToString::to_string)
259 .expect("Backup has no etag.")
260}
261
262#[implement(Service)]
263pub async fn get_all(
264 &self,
265 user_id: &UserId,
266 version: &str,
267) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
268 type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
269 type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
270
271 let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
272 let default = || RoomKeyBackup { sessions: BTreeMap::new() };
273
274 let prefix = (user_id, version, Interfix);
275 self.db
276 .backupkeyid_backup
277 .stream_prefix(&prefix)
278 .ignore_err()
279 .ready_for_each(|((_, _, room_id, session_id), key_backup_data): KeyVal<'_>| {
280 rooms
281 .entry(room_id.into())
282 .or_insert_with(default)
283 .sessions
284 .insert(session_id.into(), key_backup_data);
285 })
286 .await;
287
288 rooms
289}
290
291#[implement(Service)]
292pub async fn get_room(
293 &self,
294 user_id: &UserId,
295 version: &str,
296 room_id: &RoomId,
297) -> BTreeMap<String, Raw<KeyBackupData>> {
298 type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>);
299
300 let prefix = (user_id, version, room_id, Interfix);
301 self.db
302 .backupkeyid_backup
303 .stream_prefix(&prefix)
304 .ignore_err()
305 .map(|((.., session_id), key_backup_data): KeyVal<'_>| {
306 (session_id.to_owned(), key_backup_data)
307 })
308 .collect()
309 .await
310}
311
312#[implement(Service)]
313pub async fn get_session(
314 &self,
315 user_id: &UserId,
316 version: &str,
317 room_id: &RoomId,
318 session_id: &str,
319) -> Result<Raw<KeyBackupData>> {
320 let key = (user_id, version, room_id, session_id);
321
322 self.db
323 .backupkeyid_backup
324 .qry(&key)
325 .await
326 .deserialized()
327}
328
329#[implement(Service)]
330pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) {
331 let key = (user_id, version, Interfix);
332 self.db
333 .backupkeyid_backup
334 .keys_prefix_raw(&key)
335 .ignore_err()
336 .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
337 .await;
338}
339
340#[implement(Service)]
341pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) {
342 let key = (user_id, version, room_id, Interfix);
343 self.db
344 .backupkeyid_backup
345 .keys_prefix_raw(&key)
346 .ignore_err()
347 .ready_for_each(|outdated_key| {
348 self.db.backupkeyid_backup.remove(outdated_key);
349 })
350 .await;
351}
352
353#[implement(Service)]
354pub async fn delete_room_key(
355 &self,
356 user_id: &UserId,
357 version: &str,
358 room_id: &RoomId,
359 session_id: &str,
360) {
361 let key = (user_id, version, room_id, session_id);
362 self.db
363 .backupkeyid_backup
364 .keys_prefix_raw(&key)
365 .ignore_err()
366 .ready_for_each(|outdated_key| {
367 self.db.backupkeyid_backup.remove(outdated_key);
368 })
369 .await;
370}