Skip to main content

tuwunel_service/globals/
mod.rs

1mod data;
2
3use std::{ops::Range, sync::Arc};
4
5use data::Data;
6use ruma::{OwnedUserId, RoomAliasId, ServerName, UserId};
7use tuwunel_core::{Result, Server, err, error};
8
9use crate::service;
10
/// Server-wide globals: the backing data handle (which also provides the
/// counter operations), this server's own identity, and the TURN shared secret.
pub struct Service {
	/// Backing data for this service, created by `Data::new` in `build`. The
	/// counter methods on `Service` (`next_count`, `current_count`,
	/// `pending_count`, `wait_*`) delegate to it.
	pub db: Data,
	/// The running server; its `name` is this homeserver's server name.
	server: Arc<Server>,

	/// The server's own user, `@conduit:<server_name>`.
	pub server_user: OwnedUserId,
	/// TURN shared secret: read from the configured `turn_secret_file` when that
	/// file can be read, otherwise the inline `turn_secret` config value.
	pub turn_secret: Option<String>,
}
18
19impl crate::Service for Service {
20	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
21		let db = Data::new(args);
22		let config = &args.server.config;
23
24		let turn_secret = config
25			.turn_secret_file
26			.as_ref()
27			.and_then(|path| {
28				std::fs::read_to_string(path)
29					.inspect_err(|e| {
30						error!("Failed to read the TURN secret file: {e}");
31					})
32					.ok()
33			})
34			.or_else(|| config.turn_secret.clone());
35
36		Ok(Arc::new(Self {
37			db,
38			server: args.server.clone(),
39			server_user: UserId::parse_with_server_name(
40				String::from("conduit"),
41				&args.server.name,
42			)
43			.expect("@conduit:server_name is valid"),
44			turn_secret,
45		}))
46	}
47
48	fn name(&self) -> &str { service::make_name(std::module_path!()) }
49}
50
51impl Service {
52	#[tracing::instrument(
53		level = "trace",
54		skip_all,
55		ret,
56		fields(pending = ?self.pending_count()),
57	)]
58	pub async fn wait_pending(&self) -> Result<u64> { self.db.wait_pending().await }
59
60	#[tracing::instrument(
61		level = "trace",
62		skip_all,
63		ret,
64		fields(pending = ?self.pending_count()),
65	)]
66	pub async fn wait_count(&self, count: &u64) -> Result<u64> { self.db.wait_count(count).await }
67
68	#[tracing::instrument(
69		level = "debug",
70		skip_all,
71		fields(pending = ?self.pending_count()),
72	)]
73	#[must_use]
74	pub fn next_count(&self) -> data::Permit { self.db.next_count() }
75
76	#[must_use]
77	pub fn current_count(&self) -> u64 { self.db.current_count() }
78
79	#[must_use]
80	pub fn pending_count(&self) -> Range<u64> { self.db.pending_count() }
81
82	#[inline]
83	#[must_use]
84	pub fn server_name(&self) -> &ServerName { self.server.name.as_ref() }
85
86	/// checks if `user_id` is local to us via server_name comparison
87	#[inline]
88	#[must_use]
89	pub fn user_is_local(&self, user_id: &UserId) -> bool {
90		self.server_is_ours(user_id.server_name())
91	}
92
93	#[inline]
94	#[must_use]
95	pub fn alias_is_local(&self, alias: &RoomAliasId) -> bool {
96		self.server_is_ours(alias.server_name())
97	}
98
99	#[inline]
100	#[must_use]
101	pub fn server_is_ours(&self, server_name: &ServerName) -> bool {
102		server_name == self.server_name()
103	}
104
105	#[inline]
106	#[must_use]
107	pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }
108
109	pub fn init_rustls_provider(&self) -> Result {
110		if rustls::crypto::CryptoProvider::get_default().is_none() {
111			rustls::crypto::aws_lc_rs::default_provider()
112				.install_default()
113				.map_err(|_provider| {
114					err!(error!("Error initialising aws_lc_rs rustls crypto backend"))
115				})
116		} else {
117			Ok(())
118		}
119	}
120}