Skip to main content

tuwunel_service/key_backups/
mod.rs

1use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
2
3use futures::StreamExt;
4use ruma::{
5	OwnedRoomId, RoomId, UInt, UserId,
6	api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
7	serde::Raw,
8};
9use tuwunel_core::{
10	Err, Result, err, implement,
11	utils::stream::{ReadyExt, TryIgnore},
12};
13use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
14
15pub struct Service {
16	db: Data,
17	services: Arc<crate::services::OnceServices>,
18}
19
20struct Data {
21	backupid_algorithm: Arc<Map>,
22	backupid_etag: Arc<Map>,
23	backupkeyid_backup: Arc<Map>,
24}
25
26impl crate::Service for Service {
27	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
28		Ok(Arc::new(Self {
29			db: Data {
30				backupid_algorithm: args.db["backupid_algorithm"].clone(),
31				backupid_etag: args.db["backupid_etag"].clone(),
32				backupkeyid_backup: args.db["backupkeyid_backup"].clone(),
33			},
34			services: args.services.clone(),
35		}))
36	}
37
38	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
39}
40
41#[implement(Service)]
42pub fn create_backup(
43	&self,
44	user_id: &UserId,
45	backup_metadata: &Raw<BackupAlgorithm>,
46) -> Result<String> {
47	let version = self.services.globals.next_count();
48	let count = self.services.globals.next_count();
49
50	let version_string = version.to_string();
51	let key = (user_id, &version_string);
52	self.db
53		.backupid_algorithm
54		.put(key, Json(backup_metadata));
55
56	self.db.backupid_etag.put(key, *count);
57
58	Ok(version_string)
59}
60
61#[implement(Service)]
62pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
63	let key = (user_id, version);
64	self.db.backupid_algorithm.del(key);
65	self.db.backupid_etag.del(key);
66
67	let key = (user_id, version, Interfix);
68	self.db
69		.backupkeyid_backup
70		.keys_prefix_raw(&key)
71		.ignore_err()
72		.ready_for_each(|outdated_key| {
73			self.db.backupkeyid_backup.remove(outdated_key);
74		})
75		.await;
76}
77
78#[implement(Service)]
79pub async fn update_backup<'a>(
80	&self,
81	user_id: &UserId,
82	version: &'a str,
83	backup_metadata: &Raw<BackupAlgorithm>,
84) -> Result<&'a str> {
85	let key = (user_id, version);
86	if self
87		.db
88		.backupid_algorithm
89		.qry(&key)
90		.await
91		.is_err()
92	{
93		return Err!(Request(NotFound("Tried to update nonexistent backup.")));
94	}
95
96	let count = self.services.globals.next_count();
97	self.db.backupid_etag.put(key, *count);
98	self.db
99		.backupid_algorithm
100		.put_raw(key, backup_metadata.json().get());
101
102	Ok(version)
103}
104
105#[implement(Service)]
106pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> {
107	type Key<'a> = (&'a UserId, &'a str);
108
109	let key = (user_id, Interfix);
110	let mut versions: Vec<_> = self
111		.db
112		.backupid_algorithm
113		.keys_from(&key)
114		.ignore_err()
115		.ready_take_while(|(user_id_, _): &Key<'_>| *user_id_ == user_id)
116		.ready_filter_map(|(_, version): Key<'_>| version.parse::<u64>().ok())
117		.collect()
118		.await;
119
120	versions.sort_unstable();
121	let Some(latest) = versions.last() else {
122		return Err!(Request(NotFound("No backup versions found")));
123	};
124
125	Ok(latest.to_string())
126}
127
128#[implement(Service)]
129pub async fn get_latest_backup(
130	&self,
131	user_id: &UserId,
132) -> Result<(String, Raw<BackupAlgorithm>)> {
133	let version = self.get_latest_backup_version(user_id).await?;
134
135	let key = (user_id, version.as_str());
136	self.db
137		.backupid_algorithm
138		.qry(&key)
139		.await
140		.deserialized()
141		.map(|algorithm| (version, algorithm))
142		.map_err(|e| err!(Request(NotFound("No backup found: {e}"))))
143}
144
145#[implement(Service)]
146pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> {
147	let key = (user_id, version);
148	self.db
149		.backupid_algorithm
150		.qry(&key)
151		.await
152		.deserialized()
153}
154
155#[implement(Service)]
156pub async fn add_key(
157	&self,
158	user_id: &UserId,
159	version: &str,
160	room_id: &RoomId,
161	session_id: &str,
162	key_data: &Raw<KeyBackupData>,
163) -> Result {
164	let key = (user_id, version);
165	if self
166		.db
167		.backupid_algorithm
168		.qry(&key)
169		.await
170		.is_err()
171	{
172		return Err!(Request(NotFound("Tried to update nonexistent backup.")));
173	}
174
175	// Keep the existing key unless the incoming one is preferable per MSC1219.
176	let replace = match self
177		.get_session(user_id, version, room_id, session_id)
178		.await
179	{
180		| Ok(old_key) => is_better_key(&old_key, key_data)?,
181		| Err(_) => true,
182	};
183
184	if !replace {
185		return Ok(());
186	}
187
188	let count = self.services.globals.next_count();
189	self.db.backupid_etag.put(key, *count);
190
191	let key = (user_id, version, room_id, session_id);
192	self.db
193		.backupkeyid_backup
194		.put_raw(key, key_data.json().get());
195
196	Ok(())
197}
198
199/// Per MSC1219: prefer verified, then lower `first_message_index`, then lower
200/// `forwarded_count`; equal on all three keeps the existing key.
201fn is_better_key(old: &Raw<KeyBackupData>, new: &Raw<KeyBackupData>) -> Result<bool> {
202	let old_verified = old
203		.get_field::<bool>("is_verified")?
204		.unwrap_or_default();
205
206	let new_verified = new
207		.get_field::<bool>("is_verified")?
208		.ok_or_else(|| err!(Request(BadJson("`is_verified` field should exist"))))?;
209
210	if old_verified != new_verified {
211		return Ok(new_verified);
212	}
213
214	let old_first_message_index = old
215		.get_field::<UInt>("first_message_index")?
216		.unwrap_or(UInt::MAX);
217
218	let new_first_message_index = new
219		.get_field::<UInt>("first_message_index")?
220		.ok_or_else(|| err!(Request(BadJson("`first_message_index` field should exist"))))?;
221
222	match new_first_message_index.cmp(&old_first_message_index) {
223		| Ordering::Less => Ok(true),
224		| Ordering::Greater => Ok(false),
225		| Ordering::Equal => {
226			let old_forwarded_count = old
227				.get_field::<UInt>("forwarded_count")?
228				.unwrap_or(UInt::MAX);
229
230			let new_forwarded_count = new
231				.get_field::<UInt>("forwarded_count")?
232				.ok_or_else(|| err!(Request(BadJson("`forwarded_count` field should exist"))))?;
233
234			Ok(new_forwarded_count < old_forwarded_count)
235		},
236	}
237}
238
239#[implement(Service)]
240pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize {
241	let prefix = (user_id, version);
242	self.db
243		.backupkeyid_backup
244		.keys_prefix_raw(&prefix)
245		.count()
246		.await
247}
248
249#[implement(Service)]
250pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
251	let key = (user_id, version);
252	self.db
253		.backupid_etag
254		.qry(&key)
255		.await
256		.deserialized::<u64>()
257		.as_ref()
258		.map(ToString::to_string)
259		.expect("Backup has no etag.")
260}
261
262#[implement(Service)]
263pub async fn get_all(
264	&self,
265	user_id: &UserId,
266	version: &str,
267) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
268	type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
269	type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
270
271	let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
272	let default = || RoomKeyBackup { sessions: BTreeMap::new() };
273
274	let prefix = (user_id, version, Interfix);
275	self.db
276		.backupkeyid_backup
277		.stream_prefix(&prefix)
278		.ignore_err()
279		.ready_for_each(|((_, _, room_id, session_id), key_backup_data): KeyVal<'_>| {
280			rooms
281				.entry(room_id.into())
282				.or_insert_with(default)
283				.sessions
284				.insert(session_id.into(), key_backup_data);
285		})
286		.await;
287
288	rooms
289}
290
291#[implement(Service)]
292pub async fn get_room(
293	&self,
294	user_id: &UserId,
295	version: &str,
296	room_id: &RoomId,
297) -> BTreeMap<String, Raw<KeyBackupData>> {
298	type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>);
299
300	let prefix = (user_id, version, room_id, Interfix);
301	self.db
302		.backupkeyid_backup
303		.stream_prefix(&prefix)
304		.ignore_err()
305		.map(|((.., session_id), key_backup_data): KeyVal<'_>| {
306			(session_id.to_owned(), key_backup_data)
307		})
308		.collect()
309		.await
310}
311
312#[implement(Service)]
313pub async fn get_session(
314	&self,
315	user_id: &UserId,
316	version: &str,
317	room_id: &RoomId,
318	session_id: &str,
319) -> Result<Raw<KeyBackupData>> {
320	let key = (user_id, version, room_id, session_id);
321
322	self.db
323		.backupkeyid_backup
324		.qry(&key)
325		.await
326		.deserialized()
327}
328
329#[implement(Service)]
330pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) {
331	let key = (user_id, version, Interfix);
332	self.db
333		.backupkeyid_backup
334		.keys_prefix_raw(&key)
335		.ignore_err()
336		.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
337		.await;
338}
339
340#[implement(Service)]
341pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) {
342	let key = (user_id, version, room_id, Interfix);
343	self.db
344		.backupkeyid_backup
345		.keys_prefix_raw(&key)
346		.ignore_err()
347		.ready_for_each(|outdated_key| {
348			self.db.backupkeyid_backup.remove(outdated_key);
349		})
350		.await;
351}
352
353#[implement(Service)]
354pub async fn delete_room_key(
355	&self,
356	user_id: &UserId,
357	version: &str,
358	room_id: &RoomId,
359	session_id: &str,
360) {
361	let key = (user_id, version, room_id, session_id);
362	self.db
363		.backupkeyid_backup
364		.keys_prefix_raw(&key)
365		.ignore_err()
366		.ready_for_each(|outdated_key| {
367			self.db.backupkeyid_backup.remove(outdated_key);
368		})
369		.await;
370}