Skip to main content

tuwunel_service/key_backups/
mod.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use futures::StreamExt;
4use ruma::{
5	OwnedRoomId, RoomId, UserId,
6	api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
7	serde::Raw,
8};
9use tuwunel_core::{
10	Err, Result, err, implement,
11	utils::stream::{ReadyExt, TryIgnore},
12};
13use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
14
/// Service managing per-user room-key backups: backup algorithm metadata,
/// a per-backup etag counter, and the backed-up session keys.
pub struct Service {
	/// Database column handles used by this service.
	db: Data,
	/// Sibling services; used here for `globals.next_count()`.
	services: Arc<crate::services::OnceServices>,
}
19
/// Column handles backing the key-backup service.
struct Data {
	/// `(user_id, version)` -> backup algorithm metadata (JSON).
	backupid_algorithm: Arc<Map>,
	/// `(user_id, version)` -> etag count (`u64`); rewritten with a fresh
	/// global count by `create_backup`, `update_backup` and `add_key`.
	backupid_etag: Arc<Map>,
	/// `(user_id, version, room_id, session_id)` -> raw `KeyBackupData` JSON.
	backupkeyid_backup: Arc<Map>,
}
25
26impl crate::Service for Service {
27	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
28		Ok(Arc::new(Self {
29			db: Data {
30				backupid_algorithm: args.db["backupid_algorithm"].clone(),
31				backupid_etag: args.db["backupid_etag"].clone(),
32				backupkeyid_backup: args.db["backupkeyid_backup"].clone(),
33			},
34			services: args.services.clone(),
35		}))
36	}
37
38	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
39}
40
41#[implement(Service)]
42pub fn create_backup(
43	&self,
44	user_id: &UserId,
45	backup_metadata: &Raw<BackupAlgorithm>,
46) -> Result<String> {
47	let version = self.services.globals.next_count();
48	let count = self.services.globals.next_count();
49
50	let version_string = version.to_string();
51	let key = (user_id, &version_string);
52	self.db
53		.backupid_algorithm
54		.put(key, Json(backup_metadata));
55
56	self.db.backupid_etag.put(key, *count);
57
58	Ok(version_string)
59}
60
61#[implement(Service)]
62pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
63	let key = (user_id, version);
64	self.db.backupid_algorithm.del(key);
65	self.db.backupid_etag.del(key);
66
67	let key = (user_id, version, Interfix);
68	self.db
69		.backupkeyid_backup
70		.keys_prefix_raw(&key)
71		.ignore_err()
72		.ready_for_each(|outdated_key| {
73			self.db.backupkeyid_backup.remove(outdated_key);
74		})
75		.await;
76}
77
78#[implement(Service)]
79pub async fn update_backup<'a>(
80	&self,
81	user_id: &UserId,
82	version: &'a str,
83	backup_metadata: &Raw<BackupAlgorithm>,
84) -> Result<&'a str> {
85	let key = (user_id, version);
86	if self
87		.db
88		.backupid_algorithm
89		.qry(&key)
90		.await
91		.is_err()
92	{
93		return Err!(Request(NotFound("Tried to update nonexistent backup.")));
94	}
95
96	let count = self.services.globals.next_count();
97	self.db.backupid_etag.put(key, *count);
98	self.db
99		.backupid_algorithm
100		.put_raw(key, backup_metadata.json().get());
101
102	Ok(version)
103}
104
105#[implement(Service)]
106pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> {
107	type Key<'a> = (&'a UserId, &'a str);
108
109	let key = (user_id, Interfix);
110	let mut versions: Vec<_> = self
111		.db
112		.backupid_algorithm
113		.keys_from(&key)
114		.ignore_err()
115		.ready_take_while(|(user_id_, _): &Key<'_>| *user_id_ == user_id)
116		.ready_filter_map(|(_, version): Key<'_>| version.parse::<u64>().ok())
117		.collect()
118		.await;
119
120	versions.sort_unstable();
121	let Some(latest) = versions.last() else {
122		return Err!(Request(NotFound("No backup versions found")));
123	};
124
125	Ok(latest.to_string())
126}
127
128#[implement(Service)]
129pub async fn get_latest_backup(
130	&self,
131	user_id: &UserId,
132) -> Result<(String, Raw<BackupAlgorithm>)> {
133	let version = self.get_latest_backup_version(user_id).await?;
134
135	let key = (user_id, version.as_str());
136	self.db
137		.backupid_algorithm
138		.qry(&key)
139		.await
140		.deserialized()
141		.map(|algorithm| (version, algorithm))
142		.map_err(|e| err!(Request(NotFound("No backup found: {e}"))))
143}
144
145#[implement(Service)]
146pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> {
147	let key = (user_id, version);
148	self.db
149		.backupid_algorithm
150		.qry(&key)
151		.await
152		.deserialized()
153}
154
155#[implement(Service)]
156pub async fn add_key(
157	&self,
158	user_id: &UserId,
159	version: &str,
160	room_id: &RoomId,
161	session_id: &str,
162	key_data: &Raw<KeyBackupData>,
163) -> Result {
164	let key = (user_id, version);
165	if self
166		.db
167		.backupid_algorithm
168		.qry(&key)
169		.await
170		.is_err()
171	{
172		return Err!(Request(NotFound("Tried to update nonexistent backup.")));
173	}
174
175	let count = self.services.globals.next_count();
176	self.db.backupid_etag.put(key, *count);
177
178	let key = (user_id, version, room_id, session_id);
179	self.db
180		.backupkeyid_backup
181		.put_raw(key, key_data.json().get());
182
183	Ok(())
184}
185
186#[implement(Service)]
187pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize {
188	let prefix = (user_id, version);
189	self.db
190		.backupkeyid_backup
191		.keys_prefix_raw(&prefix)
192		.count()
193		.await
194}
195
196#[implement(Service)]
197pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
198	let key = (user_id, version);
199	self.db
200		.backupid_etag
201		.qry(&key)
202		.await
203		.deserialized::<u64>()
204		.as_ref()
205		.map(ToString::to_string)
206		.expect("Backup has no etag.")
207}
208
209#[implement(Service)]
210pub async fn get_all(
211	&self,
212	user_id: &UserId,
213	version: &str,
214) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
215	type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
216	type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
217
218	let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
219	let default = || RoomKeyBackup { sessions: BTreeMap::new() };
220
221	let prefix = (user_id, version, Interfix);
222	self.db
223		.backupkeyid_backup
224		.stream_prefix(&prefix)
225		.ignore_err()
226		.ready_for_each(|((_, _, room_id, session_id), key_backup_data): KeyVal<'_>| {
227			rooms
228				.entry(room_id.into())
229				.or_insert_with(default)
230				.sessions
231				.insert(session_id.into(), key_backup_data);
232		})
233		.await;
234
235	rooms
236}
237
238#[implement(Service)]
239pub async fn get_room(
240	&self,
241	user_id: &UserId,
242	version: &str,
243	room_id: &RoomId,
244) -> BTreeMap<String, Raw<KeyBackupData>> {
245	type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>);
246
247	let prefix = (user_id, version, room_id, Interfix);
248	self.db
249		.backupkeyid_backup
250		.stream_prefix(&prefix)
251		.ignore_err()
252		.map(|((.., session_id), key_backup_data): KeyVal<'_>| {
253			(session_id.to_owned(), key_backup_data)
254		})
255		.collect()
256		.await
257}
258
259#[implement(Service)]
260pub async fn get_session(
261	&self,
262	user_id: &UserId,
263	version: &str,
264	room_id: &RoomId,
265	session_id: &str,
266) -> Result<Raw<KeyBackupData>> {
267	let key = (user_id, version, room_id, session_id);
268
269	self.db
270		.backupkeyid_backup
271		.qry(&key)
272		.await
273		.deserialized()
274}
275
276#[implement(Service)]
277pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) {
278	let key = (user_id, version, Interfix);
279	self.db
280		.backupkeyid_backup
281		.keys_prefix_raw(&key)
282		.ignore_err()
283		.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
284		.await;
285}
286
287#[implement(Service)]
288pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) {
289	let key = (user_id, version, room_id, Interfix);
290	self.db
291		.backupkeyid_backup
292		.keys_prefix_raw(&key)
293		.ignore_err()
294		.ready_for_each(|outdated_key| {
295			self.db.backupkeyid_backup.remove(outdated_key);
296		})
297		.await;
298}
299
300#[implement(Service)]
301pub async fn delete_room_key(
302	&self,
303	user_id: &UserId,
304	version: &str,
305	room_id: &RoomId,
306	session_id: &str,
307) {
308	let key = (user_id, version, room_id, session_id);
309	self.db
310		.backupkeyid_backup
311		.keys_prefix_raw(&key)
312		.ignore_err()
313		.ready_for_each(|outdated_key| {
314			self.db.backupkeyid_backup.remove(outdated_key);
315		})
316		.await;
317}