1use std::cmp;
2
3use futures::{FutureExt, StreamExt};
4use ruma::{MxcUri, OwnedUserId, RoomId, UserId, events::room::member::MembershipState};
5use tuwunel_core::{
6 Err, Result, debug, debug_info, debug_warn, err, info,
7 itertools::Itertools,
8 matrix::PduCount,
9 result::NotFound,
10 utils,
11 utils::{
12 BoolExt, IterStream, ReadyExt,
13 stream::{TryExpect, TryIgnore},
14 },
15 warn,
16};
17use tuwunel_database::{Deserialized, SEP};
18
19use crate::{Services, media};
20
/// Schema version written to freshly created databases and used as the
/// target version when migrating an existing one.
pub(crate) const DATABASE_VERSION: u64 = 17;

/// Key in the `global` map under which the owning server's name is stored.
const SERVER_NAME_KEY: &[u8] = b"server_name";
30
31pub(crate) async fn migrations(services: &Services) -> Result {
32 if !services.config.database_migrations {
33 warn!("Skipping database migrations due to configuration...");
34 return Ok(());
35 }
36
37 let users_count = services.users.count().await;
38 if users_count == 0 {
39 return fresh(services).await;
40 }
41
42 check_server_name(services).await?;
43 migrate(services).await
44}
45
46async fn check_server_name(services: &Services) -> Result {
51 let server_name = &services.server.name;
52
53 let existing = services.db["global"]
54 .get(SERVER_NAME_KEY)
55 .await
56 .deserialized::<String>();
57
58 match existing {
59 | Err(_) => backfill_server_name(services).await,
60 | Ok(existing) if existing.eq(server_name) => Ok(()),
61 | Ok(existing) => Err!(Database(
62 "Database belongs to {existing}; configured server name is {server_name}. Cannot \
63 reuse."
64 )),
65 }
66}
67
68async fn backfill_server_name(services: &Services) -> Result {
72 let server_name = &services.server.name;
73
74 services
75 .users
76 .stream()
77 .ready_any(|user_id| services.globals.user_is_local(user_id))
78 .await
79 .ok_or_else(|| {
80 err!(Database(
81 "Database has no users from {server_name}; refusing to reuse with this \
82 server_name."
83 ))
84 })?;
85
86 services.db["global"].insert(SERVER_NAME_KEY, server_name.as_str());
87 info!(%server_name, "Stamped server_name marker on upgraded database");
88
89 Ok(())
90}
91
92async fn fresh(services: &Services) -> Result {
93 let db = &services.db;
94
95 services
96 .globals
97 .db
98 .bump_database_version(DATABASE_VERSION);
99
100 db["global"].insert(SERVER_NAME_KEY, services.server.name.as_str());
101 db["global"].insert(b"feat_sha256_media", []);
102 db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
103 db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
104 db["global"].insert(b"fix_referencedevents_missing_sep", []);
105 db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
106 db["global"].insert(b"fix_hashed_sentinel_passwords", []);
107 db["global"].insert(b"upgrade_legacy_mediaid_user", []);
108 db["global"].insert(b"remove_remote_media_userid", []);
109
110 if services.config.create_admin_room {
112 crate::admin::create_admin_room(services)
113 .boxed()
114 .await?;
115 }
116
117 warn!("Created new RocksDB database with version {DATABASE_VERSION}");
118
119 Ok(())
120}
121
122async fn migrate(services: &Services) -> Result {
124 let db = &services.db;
125 let config = &services.server.config;
126
127 let target_version = DATABASE_VERSION;
128 let discovered_version = async || services.globals.db.database_version().await;
129
130 if discovered_version().await < 13 {
131 return Err!(Database(
132 "Database schema version {} is no longer supported",
133 discovered_version().await,
134 ));
135 }
136
137 if db["global"]
138 .get(b"feat_sha256_media")
139 .await
140 .is_not_found()
141 {
142 media::migrations::migrate_sha256_media(services).await?;
143 } else if config.media_startup_check {
144 media::migrations::checkup_sha256_media(services).await?;
145 }
146
147 if db["global"]
148 .get(b"fix_bad_double_separator_in_state_cache")
149 .await
150 .is_not_found()
151 {
152 fix_bad_double_separator_in_state_cache(services).await?;
153 }
154
155 if db["global"]
156 .get(b"retroactively_fix_bad_data_from_roomuserid_joined")
157 .await
158 .is_not_found()
159 {
160 retroactively_fix_bad_data_from_roomuserid_joined(services).await?;
161 }
162
163 if db["global"]
164 .get(b"fix_referencedevents_missing_sep")
165 .await
166 .is_not_found()
167 {
168 fix_referencedevents_missing_sep(services).await?;
169 }
170
171 if db["global"]
172 .get(b"fix_readreceiptid_readreceipt_duplicates")
173 .await
174 .is_not_found()
175 {
176 fix_readreceiptid_readreceipt_duplicates(services).await?;
177 }
178
179 if db["global"]
180 .get(b"fix_hashed_sentinel_passwords")
181 .await
182 .is_not_found()
183 {
184 fix_hashed_sentinel_passwords(services).await?;
185 }
186
187 if db["global"]
188 .get(b"upgrade_legacy_mediaid_user")
189 .await
190 .is_not_found()
191 {
192 upgrade_legacy_mediaid_user(services).await?;
193 }
194
195 if db["global"]
196 .get(b"remove_remote_media_userid")
197 .await
198 .is_not_found()
199 {
200 remove_remote_media_userid(services).await?;
201 }
202
203 if discovered_version().await < target_version {
204 services
205 .globals
206 .db
207 .bump_database_version(target_version);
208
209 info!(
210 "Database: Migrated schema version from {} to {target_version}",
211 discovered_version().await
212 );
213 } else if discovered_version().await != target_version && config.force_migration {
214 services
215 .globals
216 .db
217 .bump_database_version(target_version);
218
219 warn!(
220 "Database: Forced migration from schema version {} to {target_version}",
221 discovered_version().await,
222 );
223 }
224
225 assert_eq!(
226 target_version,
227 discovered_version().await,
228 "Failed asserting local database version {} is equal to known latest tuwunel database \
229 version {target_version}",
230 discovered_version().await,
231 );
232
233 if !services.config.forbidden_usernames.is_empty() {
234 services
235 .users
236 .stream()
237 .filter(|user_id| services.users.is_active_local(user_id))
238 .ready_filter_map(|user_id| {
239 let patterns = &services.config.forbidden_usernames;
240 let matches = patterns.matches(user_id.localpart());
241 let matched = matches
242 .iter()
243 .map(|x| &patterns.patterns()[x])
244 .join(", ");
245
246 matches
247 .matched_any()
248 .then_some((user_id, matched))
249 })
250 .ready_for_each(|(user_id, matched)| {
251 warn!("User {user_id} matches forbidden username patterns: {matched:#?}");
252 })
253 .await;
254 }
255
256 if !services.config.forbidden_alias_names.is_empty() {
257 services
258 .metadata
259 .iter_ids()
260 .map(|room_id| {
261 services
262 .alias
263 .local_aliases_for_room(room_id)
264 .map(move |alias| (room_id, alias))
265 })
266 .flatten()
267 .ready_filter_map(|(room_id, room_alias)| {
268 let patterns = &services.config.forbidden_alias_names;
269 let matches = patterns.matches(room_alias.alias());
270 let matched = matches
271 .iter()
272 .map(|x| &patterns.patterns()[x])
273 .join(", ");
274
275 matches
276 .matched_any()
277 .then_some((room_id, room_alias, matched))
278 })
279 .ready_for_each(|(room_id, room_alias, matched)| {
280 warn!(
281 "Room {room_id} with alias {room_alias} matches the following forbidden \
282 room name patterns: {matched}"
283 );
284 })
285 .boxed()
286 .await;
287 }
288
289 info!("Loaded RocksDB database with schema version {DATABASE_VERSION}");
290
291 Ok(())
292}
293
294async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result {
295 warn!("Fixing bad double separator in state_cache roomuserid_joined");
296
297 let db = &services.db;
298 let roomuserid_joined = &db["roomuserid_joined"];
299 let _cork = db.cork_and_sync();
300
301 let mut iter_count: usize = 0;
302 roomuserid_joined
303 .raw_stream()
304 .ignore_err()
305 .ready_for_each(|(key, value)| {
306 let mut key = key.to_vec();
307 iter_count = iter_count.saturating_add(1);
308 debug_info!(%iter_count);
309 let first_sep_index = key
310 .iter()
311 .position(|&i| i == 0xFF)
312 .expect("found 0xFF delim");
313
314 if key
315 .iter()
316 .get(first_sep_index..=first_sep_index.saturating_add(1))
317 .copied()
318 .collect_vec()
319 == vec![0xFF, 0xFF]
320 {
321 debug_warn!("Found bad key: {key:?}");
322 roomuserid_joined.remove(&key);
323
324 key.remove(first_sep_index);
325 debug_warn!("Fixed key: {key:?}");
326 roomuserid_joined.insert(&key, value);
327 }
328 })
329 .await;
330
331 db.engine.sort()?;
332 db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
333
334 info!("Finished fixing");
335 Ok(())
336}
337
338async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result {
339 warn!("Retroactively fixing bad data from broken roomuserid_joined");
340
341 let db = &services.db;
342 let _cork = db.cork_and_sync();
343
344 let room_ids = services
345 .metadata
346 .iter_ids()
347 .map(ToOwned::to_owned)
348 .collect::<Vec<_>>()
349 .await;
350
351 for room_id in &room_ids {
352 debug_info!("Fixing room {room_id}");
353
354 let users_in_room: Vec<OwnedUserId> = services
355 .state_cache
356 .room_members(room_id)
357 .map(ToOwned::to_owned)
358 .collect()
359 .await;
360
361 let joined_members = users_in_room
362 .iter()
363 .stream()
364 .filter(|user_id| {
365 services
366 .state_accessor
367 .get_member(room_id, user_id)
368 .map(|member| {
369 member.is_ok_and(|member| member.membership == MembershipState::Join)
370 })
371 })
372 .collect::<Vec<_>>()
373 .await;
374
375 let non_joined_members = users_in_room
376 .iter()
377 .stream()
378 .filter(|user_id| {
379 services
380 .state_accessor
381 .get_member(room_id, user_id)
382 .map(|member| {
383 member.is_ok_and(|member| member.membership == MembershipState::Join)
384 })
385 })
386 .collect::<Vec<_>>()
387 .await;
388
389 for user_id in &joined_members {
390 debug_info!("User is joined, marking as joined");
391 let count = services.globals.next_count();
392 services
393 .state_cache
394 .mark_as_joined(user_id, room_id, PduCount::Normal(*count));
395 }
396
397 for user_id in &non_joined_members {
398 debug_info!("User is left or banned, marking as left");
399 let count = services.globals.next_count();
400 services
401 .state_cache
402 .mark_as_left(user_id, room_id, PduCount::Normal(*count));
403 }
404 }
405
406 for room_id in &room_ids {
407 debug_info!(
408 "Updating joined count for room {room_id} to fix servers in room after correcting \
409 membership states"
410 );
411
412 services
413 .state_cache
414 .update_joined_count(room_id)
415 .await;
416 }
417
418 db.engine.sort()?;
419 db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
420
421 info!("Finished fixing");
422 Ok(())
423}
424
425async fn fix_referencedevents_missing_sep(services: &Services) -> Result {
426 warn!("Fixing missing record separator between room_id and event_id in referencedevents");
427
428 let db = &services.db;
429 let cork = db.cork_and_sync();
430
431 let referencedevents = db["referencedevents"].clone();
432
433 let totals: (usize, usize) = (0, 0);
434 let (total, fixed) = referencedevents
435 .raw_stream()
436 .expect_ok()
437 .enumerate()
438 .ready_fold(totals, |mut a, (i, (key, val))| {
439 debug_assert!(val.is_empty(), "expected no value");
440
441 let has_sep = key.contains(&SEP);
442
443 if !has_sep {
444 let key_str = std::str::from_utf8(key).expect("key not utf-8");
445 let room_id_len = key_str.find('$').expect("missing '$' in key");
446 let (room_id, event_id) = key_str.split_at(room_id_len);
447 debug!(?a, "fixing {room_id}, {event_id}");
448
449 let new_key = (room_id, event_id);
450 referencedevents.put_raw(new_key, val);
451 referencedevents.remove(key);
452 }
453
454 a.0 = cmp::max(i, a.0);
455 a.1 = a.1.saturating_add((!has_sep).into());
456 a
457 })
458 .await;
459
460 drop(cork);
461 info!(?total, ?fixed, "Fixed missing record separators in 'referencedevents'.");
462
463 db["global"].insert(b"fix_referencedevents_missing_sep", []);
464 db.engine.sort()
465}
466
467async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result {
468 use ruma::identifiers_validation::ID_MAX_BYTES;
469 use tuwunel_core::arrayvec::ArrayString;
470
471 type ArrayId = ArrayString<ID_MAX_BYTES>;
472 type Key<'a> = (&'a RoomId, u64, &'a UserId);
473
474 warn!("Fixing undeleted entries in readreceiptid_readreceipt...");
475
476 let db = &services.db;
477 let cork = db.cork_and_sync();
478 let readreceiptid_readreceipt = db["readreceiptid_readreceipt"].clone();
479
480 let mut cur_room: Option<ArrayId> = None;
481 let mut cur_user: Option<ArrayId> = None;
482 let (mut total, mut fixed): (usize, usize) = (0, 0);
483 readreceiptid_readreceipt
484 .keys()
485 .expect_ok()
486 .ready_for_each(|key: Key<'_>| {
487 let (room_id, _, user_id) = key;
488 let last_room = cur_room.replace(
489 room_id
490 .as_str()
491 .try_into()
492 .expect("invalid room_id in database"),
493 );
494
495 let last_user = cur_user.replace(
496 user_id
497 .as_str()
498 .try_into()
499 .expect("invalid user_id in database"),
500 );
501
502 let is_dup = cur_room == last_room && cur_user == last_user;
503 if is_dup {
504 readreceiptid_readreceipt.del(key);
505 }
506
507 fixed = fixed.saturating_add(is_dup.into());
508 total = total.saturating_add(1);
509 })
510 .await;
511
512 drop(cork);
513 info!(?total, ?fixed, "Fixed undeleted entries in readreceiptid_readreceipt.");
514
515 db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
516 db.engine.sort()
517}
518
519async fn fix_hashed_sentinel_passwords(services: &Services) -> Result {
520 use tuwunel_core::utils::hash::verify_password;
521
522 const PASSWORD_SENTINEL: &str = "*";
523
524 if services.config.identity_provider.is_empty() {
525 debug!("Skipping sentinel password migration since no SSO IdP configured.");
526 return Ok(());
527 }
528
529 let db = &services.db;
530 let cork = db.cork_and_sync();
531 let userid_password = db["userid_password"].clone();
532 let hashed_sentinel = utils::hash::password(PASSWORD_SENTINEL).map_err(|e| {
533 err!("Could not apply migration: failed to hash sentinel password: {e:?}")
534 })?;
535
536 warn!(
537 "Fixing occurrences of password-hash {hashed_sentinel:?} generated from \
538 {PASSWORD_SENTINEL:?}"
539 );
540
541 let (checked, good, bad) = userid_password
542 .stream()
543 .expect_ok()
544 .ready_fold(
545 (0, 0, 0),
546 |(mut checked, mut good, mut bad): (usize, usize, usize),
547 (key, val): (&str, &str)| {
548 let good_sentinel = val == PASSWORD_SENTINEL;
549 let bad_sentinel = !val.is_empty()
550 && !good_sentinel
551 && verify_password(PASSWORD_SENTINEL, val).is_ok();
552
553 checked = checked.saturating_add(usize::from(true));
554 good = good.saturating_add(usize::from(good_sentinel));
555 bad = bad.saturating_add(usize::from(bad_sentinel));
556
557 if bad_sentinel {
558 userid_password.insert(key, PASSWORD_SENTINEL);
559 }
560
561 (checked, good, bad)
562 },
563 )
564 .await;
565
566 drop(cork);
567 info!(?checked, ?good, ?bad, "Fixed any occurrences of hashed sentinel passwords");
568
569 db["global"].insert(b"fix_hashed_sentinel_passwords", []);
570 db.engine.sort()
571}
572
573async fn upgrade_legacy_mediaid_user(services: &Services) -> Result {
574 let db = &services.db;
575 let cork = db.cork_and_sync();
576 let mediaid_user = db["mediaid_user"].clone();
577
578 warn!("Upgrading legacy mediaid_user keys to composite (mxc, user_id) layout");
579
580 let (checked, upgraded, removed_invalid) = mediaid_user
581 .raw_stream()
582 .ignore_err()
583 .ready_fold(
584 (0_usize, 0_usize, 0_usize),
585 |(mut checked, mut upgraded, mut removed_invalid), (raw_key, raw_val)| {
586 checked = checked.saturating_add(1);
587
588 let has_sep = raw_key.contains(&SEP);
589 let user_id = str::from_utf8(raw_val)
590 .ok()
591 .and_then(|s| <&UserId>::try_from(s).ok());
592
593 match (has_sep, user_id) {
594 | (true, _) => {},
595 | (false, None) => {
596 warn!(
597 ?raw_key,
598 ?raw_val,
599 "Legacy entry has unparsable user_id, removing"
600 );
601
602 mediaid_user.remove(raw_key);
603 removed_invalid = removed_invalid.saturating_add(1);
604 },
605 | (false, Some(user_id)) => {
606 let mut new_key = raw_key.to_vec();
607 new_key.push(SEP);
608 new_key.extend_from_slice(user_id.as_bytes());
609
610 mediaid_user.put_raw(new_key, user_id.as_str());
611 mediaid_user.remove(raw_key);
612
613 upgraded = upgraded.saturating_add(1);
614 },
615 }
616
617 (checked, upgraded, removed_invalid)
618 },
619 )
620 .await;
621
622 drop(cork);
623 info!(
624 %checked,
625 %upgraded,
626 %removed_invalid,
627 "Upgraded legacy mediaid_user keys"
628 );
629
630 db["global"].insert(b"upgrade_legacy_mediaid_user", []);
631 db.engine.sort()
632}
633
634async fn remove_remote_media_userid(services: &Services) -> Result {
635 let db = &services.db;
636 let cork = db.cork_and_sync();
637 let mediaid_user = db["mediaid_user"].clone();
638
639 warn!("Removing stored user id for remote media");
640
641 let (checked, removed_remote, removed_invalid) = mediaid_user
642 .keys()
643 .expect_ok()
644 .ready_fold(
645 (0, 0, 0),
646 |(mut checked, mut removed_remote, mut removed_invalid): (usize, usize, usize),
647 (mxc_uri, user_id): (&MxcUri, &UserId)| {
648 checked = checked.saturating_add(1);
649
650 let Ok(mxc) = mxc_uri.parts() else {
651 warn!(?mxc_uri, "Invalid MXC URL, removing it");
652
653 mediaid_user.del((mxc_uri, user_id));
654
655 removed_invalid = removed_invalid.saturating_add(1);
656
657 return (checked, removed_remote, removed_invalid);
658 };
659
660 if !services.globals.server_is_ours(mxc.server_name) {
661 mediaid_user.del((mxc_uri, user_id));
662
663 removed_remote = removed_remote.saturating_add(1);
664
665 return (checked, removed_remote, removed_invalid);
666 }
667
668 (checked, removed_remote, removed_invalid)
669 },
670 )
671 .await;
672
673 drop(cork);
674 info!(
675 %checked,
676 %removed_remote,
677 %removed_invalid,
678 "Removed stored user id for remote media"
679 );
680
681 db["global"].insert(b"remove_remote_media_userid", []);
682 db.engine.sort()
683}