Skip to main content

tuwunel_service/
migrations.rs

1use std::cmp;
2
3use futures::{FutureExt, StreamExt};
4use ruma::{MxcUri, OwnedUserId, RoomId, UserId, events::room::member::MembershipState};
5use tuwunel_core::{
6	Err, Result, debug, debug_info, debug_warn, err, info,
7	itertools::Itertools,
8	matrix::PduCount,
9	result::NotFound,
10	utils,
11	utils::{
12		BoolExt, IterStream, ReadyExt,
13		stream::{TryExpect, TryIgnore},
14	},
15	warn,
16};
17use tuwunel_database::{Deserialized, SEP};
18
19use crate::{Services, media};
20
/// The current schema version.
/// - If database is opened at greater version we reject with error. The
///   software must be updated for backward-incompatible changes.
/// - If database is opened at lesser version we apply migrations up to this.
///   Note that named-feature migrations may also be performed when opening at
///   equal or lesser version. These are expected to be backward-compatible.
pub(crate) const DATABASE_VERSION: u64 = 17;

/// Key in the "global" map under which the name of the server owning this
/// database is stored. Written by `fresh()` and `backfill_server_name()`,
/// read by `check_server_name()`.
const SERVER_NAME_KEY: &[u8] = b"server_name";
30
31pub(crate) async fn migrations(services: &Services) -> Result {
32	if !services.config.database_migrations {
33		warn!("Skipping database migrations due to configuration...");
34		return Ok(());
35	}
36
37	let users_count = services.users.count().await;
38	if users_count == 0 {
39		return fresh(services).await;
40	}
41
42	check_server_name(services).await?;
43	migrate(services).await
44}
45
46/// Matrix resource ownership is based on the server name; changing it
47/// requires recreating the database from scratch. The marker is stamped
48/// once in fresh(); pre-marker databases are backfilled by probing for
49/// any user from the configured server.
50async fn check_server_name(services: &Services) -> Result {
51	let server_name = &services.server.name;
52
53	let existing = services.db["global"]
54		.get(SERVER_NAME_KEY)
55		.await
56		.deserialized::<String>();
57
58	match existing {
59		| Err(_) => backfill_server_name(services).await,
60		| Ok(existing) if existing.eq(server_name) => Ok(()),
61		| Ok(existing) => Err!(Database(
62			"Database belongs to {existing}; configured server name is {server_name}. Cannot \
63			 reuse."
64		)),
65	}
66}
67
68/// Stamp the marker on a database that pre-dates SERVER_NAME_KEY by probing
69/// for any user from the configured server. If none, the database belongs
70/// to a different server and reuse is refused.
71async fn backfill_server_name(services: &Services) -> Result {
72	let server_name = &services.server.name;
73
74	services
75		.users
76		.stream()
77		.ready_any(|user_id| services.globals.user_is_local(user_id))
78		.await
79		.ok_or_else(|| {
80			err!(Database(
81				"Database has no users from {server_name}; refusing to reuse with this \
82				 server_name."
83			))
84		})?;
85
86	services.db["global"].insert(SERVER_NAME_KEY, server_name.as_str());
87	info!(%server_name, "Stamped server_name marker on upgraded database");
88
89	Ok(())
90}
91
92async fn fresh(services: &Services) -> Result {
93	let db = &services.db;
94
95	services
96		.globals
97		.db
98		.bump_database_version(DATABASE_VERSION);
99
100	db["global"].insert(SERVER_NAME_KEY, services.server.name.as_str());
101	db["global"].insert(b"feat_sha256_media", []);
102	db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
103	db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
104	db["global"].insert(b"fix_referencedevents_missing_sep", []);
105	db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
106	db["global"].insert(b"fix_hashed_sentinel_passwords", []);
107	db["global"].insert(b"upgrade_legacy_mediaid_user", []);
108	db["global"].insert(b"remove_remote_media_userid", []);
109
110	// Create the admin room and server user on first run
111	if services.config.create_admin_room {
112		crate::admin::create_admin_room(services)
113			.boxed()
114			.await?;
115	}
116
117	warn!("Created new RocksDB database with version {DATABASE_VERSION}");
118
119	Ok(())
120}
121
122/// Apply any migrations
123async fn migrate(services: &Services) -> Result {
124	let db = &services.db;
125	let config = &services.server.config;
126
127	let target_version = DATABASE_VERSION;
128	let discovered_version = async || services.globals.db.database_version().await;
129
130	if discovered_version().await < 13 {
131		return Err!(Database(
132			"Database schema version {} is no longer supported",
133			discovered_version().await,
134		));
135	}
136
137	if db["global"]
138		.get(b"feat_sha256_media")
139		.await
140		.is_not_found()
141	{
142		media::migrations::migrate_sha256_media(services).await?;
143	} else if config.media_startup_check {
144		media::migrations::checkup_sha256_media(services).await?;
145	}
146
147	if db["global"]
148		.get(b"fix_bad_double_separator_in_state_cache")
149		.await
150		.is_not_found()
151	{
152		fix_bad_double_separator_in_state_cache(services).await?;
153	}
154
155	if db["global"]
156		.get(b"retroactively_fix_bad_data_from_roomuserid_joined")
157		.await
158		.is_not_found()
159	{
160		retroactively_fix_bad_data_from_roomuserid_joined(services).await?;
161	}
162
163	if db["global"]
164		.get(b"fix_referencedevents_missing_sep")
165		.await
166		.is_not_found()
167	{
168		fix_referencedevents_missing_sep(services).await?;
169	}
170
171	if db["global"]
172		.get(b"fix_readreceiptid_readreceipt_duplicates")
173		.await
174		.is_not_found()
175	{
176		fix_readreceiptid_readreceipt_duplicates(services).await?;
177	}
178
179	if db["global"]
180		.get(b"fix_hashed_sentinel_passwords")
181		.await
182		.is_not_found()
183	{
184		fix_hashed_sentinel_passwords(services).await?;
185	}
186
187	if db["global"]
188		.get(b"upgrade_legacy_mediaid_user")
189		.await
190		.is_not_found()
191	{
192		upgrade_legacy_mediaid_user(services).await?;
193	}
194
195	if db["global"]
196		.get(b"remove_remote_media_userid")
197		.await
198		.is_not_found()
199	{
200		remove_remote_media_userid(services).await?;
201	}
202
203	if discovered_version().await < target_version {
204		services
205			.globals
206			.db
207			.bump_database_version(target_version);
208
209		info!(
210			"Database: Migrated schema version from {} to {target_version}",
211			discovered_version().await
212		);
213	} else if discovered_version().await != target_version && config.force_migration {
214		services
215			.globals
216			.db
217			.bump_database_version(target_version);
218
219		warn!(
220			"Database: Forced migration from schema version {} to {target_version}",
221			discovered_version().await,
222		);
223	}
224
225	assert_eq!(
226		target_version,
227		discovered_version().await,
228		"Failed asserting local database version {} is equal to known latest tuwunel database \
229		 version {target_version}",
230		discovered_version().await,
231	);
232
233	if !services.config.forbidden_usernames.is_empty() {
234		services
235			.users
236			.stream()
237			.filter(|user_id| services.users.is_active_local(user_id))
238			.ready_filter_map(|user_id| {
239				let patterns = &services.config.forbidden_usernames;
240				let matches = patterns.matches(user_id.localpart());
241				let matched = matches
242					.iter()
243					.map(|x| &patterns.patterns()[x])
244					.join(", ");
245
246				matches
247					.matched_any()
248					.then_some((user_id, matched))
249			})
250			.ready_for_each(|(user_id, matched)| {
251				warn!("User {user_id} matches forbidden username patterns: {matched:#?}");
252			})
253			.await;
254	}
255
256	if !services.config.forbidden_alias_names.is_empty() {
257		services
258			.metadata
259			.iter_ids()
260			.map(|room_id| {
261				services
262					.alias
263					.local_aliases_for_room(room_id)
264					.map(move |alias| (room_id, alias))
265			})
266			.flatten()
267			.ready_filter_map(|(room_id, room_alias)| {
268				let patterns = &services.config.forbidden_alias_names;
269				let matches = patterns.matches(room_alias.alias());
270				let matched = matches
271					.iter()
272					.map(|x| &patterns.patterns()[x])
273					.join(", ");
274
275				matches
276					.matched_any()
277					.then_some((room_id, room_alias, matched))
278			})
279			.ready_for_each(|(room_id, room_alias, matched)| {
280				warn!(
281					"Room {room_id} with alias {room_alias} matches the following forbidden \
282					 room name patterns: {matched}"
283				);
284			})
285			.boxed()
286			.await;
287	}
288
289	info!("Loaded RocksDB database with schema version {DATABASE_VERSION}");
290
291	Ok(())
292}
293
294async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result {
295	warn!("Fixing bad double separator in state_cache roomuserid_joined");
296
297	let db = &services.db;
298	let roomuserid_joined = &db["roomuserid_joined"];
299	let _cork = db.cork_and_sync();
300
301	let mut iter_count: usize = 0;
302	roomuserid_joined
303		.raw_stream()
304		.ignore_err()
305		.ready_for_each(|(key, value)| {
306			let mut key = key.to_vec();
307			iter_count = iter_count.saturating_add(1);
308			debug_info!(%iter_count);
309			let first_sep_index = key
310				.iter()
311				.position(|&i| i == 0xFF)
312				.expect("found 0xFF delim");
313
314			if key
315				.iter()
316				.get(first_sep_index..=first_sep_index.saturating_add(1))
317				.copied()
318				.collect_vec()
319				== vec![0xFF, 0xFF]
320			{
321				debug_warn!("Found bad key: {key:?}");
322				roomuserid_joined.remove(&key);
323
324				key.remove(first_sep_index);
325				debug_warn!("Fixed key: {key:?}");
326				roomuserid_joined.insert(&key, value);
327			}
328		})
329		.await;
330
331	db.engine.sort()?;
332	db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
333
334	info!("Finished fixing");
335	Ok(())
336}
337
338async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result {
339	warn!("Retroactively fixing bad data from broken roomuserid_joined");
340
341	let db = &services.db;
342	let _cork = db.cork_and_sync();
343
344	let room_ids = services
345		.metadata
346		.iter_ids()
347		.map(ToOwned::to_owned)
348		.collect::<Vec<_>>()
349		.await;
350
351	for room_id in &room_ids {
352		debug_info!("Fixing room {room_id}");
353
354		let users_in_room: Vec<OwnedUserId> = services
355			.state_cache
356			.room_members(room_id)
357			.map(ToOwned::to_owned)
358			.collect()
359			.await;
360
361		let joined_members = users_in_room
362			.iter()
363			.stream()
364			.filter(|user_id| {
365				services
366					.state_accessor
367					.get_member(room_id, user_id)
368					.map(|member| {
369						member.is_ok_and(|member| member.membership == MembershipState::Join)
370					})
371			})
372			.collect::<Vec<_>>()
373			.await;
374
375		let non_joined_members = users_in_room
376			.iter()
377			.stream()
378			.filter(|user_id| {
379				services
380					.state_accessor
381					.get_member(room_id, user_id)
382					.map(|member| {
383						member.is_ok_and(|member| member.membership == MembershipState::Join)
384					})
385			})
386			.collect::<Vec<_>>()
387			.await;
388
389		for user_id in &joined_members {
390			debug_info!("User is joined, marking as joined");
391			let count = services.globals.next_count();
392			services
393				.state_cache
394				.mark_as_joined(user_id, room_id, PduCount::Normal(*count));
395		}
396
397		for user_id in &non_joined_members {
398			debug_info!("User is left or banned, marking as left");
399			let count = services.globals.next_count();
400			services
401				.state_cache
402				.mark_as_left(user_id, room_id, PduCount::Normal(*count));
403		}
404	}
405
406	for room_id in &room_ids {
407		debug_info!(
408			"Updating joined count for room {room_id} to fix servers in room after correcting \
409			 membership states"
410		);
411
412		services
413			.state_cache
414			.update_joined_count(room_id)
415			.await;
416	}
417
418	db.engine.sort()?;
419	db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
420
421	info!("Finished fixing");
422	Ok(())
423}
424
425async fn fix_referencedevents_missing_sep(services: &Services) -> Result {
426	warn!("Fixing missing record separator between room_id and event_id in referencedevents");
427
428	let db = &services.db;
429	let cork = db.cork_and_sync();
430
431	let referencedevents = db["referencedevents"].clone();
432
433	let totals: (usize, usize) = (0, 0);
434	let (total, fixed) = referencedevents
435		.raw_stream()
436		.expect_ok()
437		.enumerate()
438		.ready_fold(totals, |mut a, (i, (key, val))| {
439			debug_assert!(val.is_empty(), "expected no value");
440
441			let has_sep = key.contains(&SEP);
442
443			if !has_sep {
444				let key_str = std::str::from_utf8(key).expect("key not utf-8");
445				let room_id_len = key_str.find('$').expect("missing '$' in key");
446				let (room_id, event_id) = key_str.split_at(room_id_len);
447				debug!(?a, "fixing {room_id}, {event_id}");
448
449				let new_key = (room_id, event_id);
450				referencedevents.put_raw(new_key, val);
451				referencedevents.remove(key);
452			}
453
454			a.0 = cmp::max(i, a.0);
455			a.1 = a.1.saturating_add((!has_sep).into());
456			a
457		})
458		.await;
459
460	drop(cork);
461	info!(?total, ?fixed, "Fixed missing record separators in 'referencedevents'.");
462
463	db["global"].insert(b"fix_referencedevents_missing_sep", []);
464	db.engine.sort()
465}
466
467async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result {
468	use ruma::identifiers_validation::ID_MAX_BYTES;
469	use tuwunel_core::arrayvec::ArrayString;
470
471	type ArrayId = ArrayString<ID_MAX_BYTES>;
472	type Key<'a> = (&'a RoomId, u64, &'a UserId);
473
474	warn!("Fixing undeleted entries in readreceiptid_readreceipt...");
475
476	let db = &services.db;
477	let cork = db.cork_and_sync();
478	let readreceiptid_readreceipt = db["readreceiptid_readreceipt"].clone();
479
480	let mut cur_room: Option<ArrayId> = None;
481	let mut cur_user: Option<ArrayId> = None;
482	let (mut total, mut fixed): (usize, usize) = (0, 0);
483	readreceiptid_readreceipt
484		.keys()
485		.expect_ok()
486		.ready_for_each(|key: Key<'_>| {
487			let (room_id, _, user_id) = key;
488			let last_room = cur_room.replace(
489				room_id
490					.as_str()
491					.try_into()
492					.expect("invalid room_id in database"),
493			);
494
495			let last_user = cur_user.replace(
496				user_id
497					.as_str()
498					.try_into()
499					.expect("invalid user_id in database"),
500			);
501
502			let is_dup = cur_room == last_room && cur_user == last_user;
503			if is_dup {
504				readreceiptid_readreceipt.del(key);
505			}
506
507			fixed = fixed.saturating_add(is_dup.into());
508			total = total.saturating_add(1);
509		})
510		.await;
511
512	drop(cork);
513	info!(?total, ?fixed, "Fixed undeleted entries in readreceiptid_readreceipt.");
514
515	db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
516	db.engine.sort()
517}
518
519async fn fix_hashed_sentinel_passwords(services: &Services) -> Result {
520	use tuwunel_core::utils::hash::verify_password;
521
522	const PASSWORD_SENTINEL: &str = "*";
523
524	if services.config.identity_provider.is_empty() {
525		debug!("Skipping sentinel password migration since no SSO IdP configured.");
526		return Ok(());
527	}
528
529	let db = &services.db;
530	let cork = db.cork_and_sync();
531	let userid_password = db["userid_password"].clone();
532	let hashed_sentinel = utils::hash::password(PASSWORD_SENTINEL).map_err(|e| {
533		err!("Could not apply migration: failed to hash sentinel password: {e:?}")
534	})?;
535
536	warn!(
537		"Fixing occurrences of password-hash {hashed_sentinel:?} generated from \
538		 {PASSWORD_SENTINEL:?}"
539	);
540
541	let (checked, good, bad) = userid_password
542		.stream()
543		.expect_ok()
544		.ready_fold(
545			(0, 0, 0),
546			|(mut checked, mut good, mut bad): (usize, usize, usize),
547			 (key, val): (&str, &str)| {
548				let good_sentinel = val == PASSWORD_SENTINEL;
549				let bad_sentinel = !val.is_empty()
550					&& !good_sentinel
551					&& verify_password(PASSWORD_SENTINEL, val).is_ok();
552
553				checked = checked.saturating_add(usize::from(true));
554				good = good.saturating_add(usize::from(good_sentinel));
555				bad = bad.saturating_add(usize::from(bad_sentinel));
556
557				if bad_sentinel {
558					userid_password.insert(key, PASSWORD_SENTINEL);
559				}
560
561				(checked, good, bad)
562			},
563		)
564		.await;
565
566	drop(cork);
567	info!(?checked, ?good, ?bad, "Fixed any occurrences of hashed sentinel passwords");
568
569	db["global"].insert(b"fix_hashed_sentinel_passwords", []);
570	db.engine.sort()
571}
572
573async fn upgrade_legacy_mediaid_user(services: &Services) -> Result {
574	let db = &services.db;
575	let cork = db.cork_and_sync();
576	let mediaid_user = db["mediaid_user"].clone();
577
578	warn!("Upgrading legacy mediaid_user keys to composite (mxc, user_id) layout");
579
580	let (checked, upgraded, removed_invalid) = mediaid_user
581		.raw_stream()
582		.ignore_err()
583		.ready_fold(
584			(0_usize, 0_usize, 0_usize),
585			|(mut checked, mut upgraded, mut removed_invalid), (raw_key, raw_val)| {
586				checked = checked.saturating_add(1);
587
588				let has_sep = raw_key.contains(&SEP);
589				let user_id = str::from_utf8(raw_val)
590					.ok()
591					.and_then(|s| <&UserId>::try_from(s).ok());
592
593				match (has_sep, user_id) {
594					| (true, _) => {},
595					| (false, None) => {
596						warn!(
597							?raw_key,
598							?raw_val,
599							"Legacy entry has unparsable user_id, removing"
600						);
601
602						mediaid_user.remove(raw_key);
603						removed_invalid = removed_invalid.saturating_add(1);
604					},
605					| (false, Some(user_id)) => {
606						let mut new_key = raw_key.to_vec();
607						new_key.push(SEP);
608						new_key.extend_from_slice(user_id.as_bytes());
609
610						mediaid_user.put_raw(new_key, user_id.as_str());
611						mediaid_user.remove(raw_key);
612
613						upgraded = upgraded.saturating_add(1);
614					},
615				}
616
617				(checked, upgraded, removed_invalid)
618			},
619		)
620		.await;
621
622	drop(cork);
623	info!(
624		%checked,
625		%upgraded,
626		%removed_invalid,
627		"Upgraded legacy mediaid_user keys"
628	);
629
630	db["global"].insert(b"upgrade_legacy_mediaid_user", []);
631	db.engine.sort()
632}
633
634async fn remove_remote_media_userid(services: &Services) -> Result {
635	let db = &services.db;
636	let cork = db.cork_and_sync();
637	let mediaid_user = db["mediaid_user"].clone();
638
639	warn!("Removing stored user id for remote media");
640
641	let (checked, removed_remote, removed_invalid) = mediaid_user
642		.keys()
643		.expect_ok()
644		.ready_fold(
645			(0, 0, 0),
646			|(mut checked, mut removed_remote, mut removed_invalid): (usize, usize, usize),
647			 (mxc_uri, user_id): (&MxcUri, &UserId)| {
648				checked = checked.saturating_add(1);
649
650				let Ok(mxc) = mxc_uri.parts() else {
651					warn!(?mxc_uri, "Invalid MXC URL, removing it");
652
653					mediaid_user.del((mxc_uri, user_id));
654
655					removed_invalid = removed_invalid.saturating_add(1);
656
657					return (checked, removed_remote, removed_invalid);
658				};
659
660				if !services.globals.server_is_ours(mxc.server_name) {
661					mediaid_user.del((mxc_uri, user_id));
662
663					removed_remote = removed_remote.saturating_add(1);
664
665					return (checked, removed_remote, removed_invalid);
666				}
667
668				(checked, removed_remote, removed_invalid)
669			},
670		)
671		.await;
672
673	drop(cork);
674	info!(
675		%checked,
676		%removed_remote,
677		%removed_invalid,
678		"Removed stored user id for remote media"
679	);
680
681	db["global"].insert(b"remove_remote_media_userid", []);
682	db.engine.sort()
683}