Skip to main content

tuwunel_service/registration_tokens/
mod.rs

1mod data;
2
3use std::{collections::HashSet, sync::Arc};
4
5use data::Data;
6pub use data::{DatabaseTokenInfo, TokenExpires};
7use futures::{Stream, StreamExt, pin_mut};
8use tuwunel_core::{
9	Err, Result, error,
10	utils::{self, IterStream},
11};
12
/// Length argument passed to `utils::random_string` when issuing a new
/// database registration token.
const RANDOM_TOKEN_LENGTH: usize = 16;
14
15pub struct Service {
16	db: Data,
17	services: Arc<crate::services::OnceServices>,
18}
19
20/// A validated registration token which may be used to create an account.
21#[derive(Debug)]
22pub struct ValidToken {
23	pub token: String,
24	pub source: ValidTokenSource,
25}
26
27impl std::fmt::Display for ValidToken {
28	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29		write!(f, "`{}` --- {}", self.token, self.source)
30	}
31}
32
33impl PartialEq<str> for ValidToken {
34	fn eq(&self, other: &str) -> bool { self.token == other }
35}
36
37/// The source of a valid database token.
38#[derive(Debug)]
39pub enum ValidTokenSource {
40	/// The static token set in the homeserver's config file, which is
41	/// always valid.
42	ConfigFile,
43	/// A database token which has been checked to be valid.
44	Database(DatabaseTokenInfo),
45}
46
47impl std::fmt::Display for ValidTokenSource {
48	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49		match self {
50			| Self::ConfigFile => write!(f, "Token defined in config."),
51			| Self::Database(info) => info.fmt(f),
52		}
53	}
54}
55
56impl crate::Service for Service {
57	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
58		Ok(Arc::new(Self {
59			db: Data::new(args.db),
60			services: args.services.clone(),
61		}))
62	}
63
64	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
65}
66
67impl Service {
68	/// Issue a new registration token and save it in the database.
69	pub async fn issue_token(
70		&self,
71		expires: TokenExpires,
72	) -> Result<(String, DatabaseTokenInfo)> {
73		let token = utils::random_string(RANDOM_TOKEN_LENGTH);
74
75		let info = self.db.save_token(&token, expires).await?;
76
77		Ok((token, info))
78	}
79
80	pub async fn is_enabled(&self) -> bool {
81		let stream = self.iterate_tokens();
82
83		pin_mut!(stream);
84
85		stream.next().await.is_some()
86	}
87
88	pub fn get_config_tokens(&self) -> HashSet<String> {
89		let mut tokens = HashSet::new();
90		if let Some(file) = &self
91			.services
92			.server
93			.config
94			.registration_token_file
95			.as_ref()
96		{
97			match std::fs::read_to_string(file) {
98				| Err(e) => error!("Failed to read the registration token file: {e}"),
99				| Ok(text) => {
100					text.split_ascii_whitespace().for_each(|token| {
101						tokens.insert(token.to_owned());
102					});
103				},
104			}
105		}
106
107		if let Some(token) = &self.services.server.config.registration_token {
108			tokens.insert(token.to_owned());
109		}
110
111		tokens
112	}
113
114	pub async fn is_token_valid(&self, token: &str) -> Result { self.check(token, false).await }
115
116	pub async fn try_consume(&self, token: &str) -> Result { self.check(token, true).await }
117
118	async fn check(&self, token: &str, consume: bool) -> Result {
119		if self.get_config_tokens().contains(token) || self.db.check_token(token, consume).await {
120			return Ok(());
121		}
122
123		Err!(Request(Forbidden("Registration token not valid")))
124	}
125
126	/// Try to revoke a valid token.
127	///
128	/// Note that tokens set in the config file cannot be revoked.
129	pub async fn revoke_token(&self, token: &str) -> Result {
130		if self.get_config_tokens().contains(token) {
131			return Err!(
132				"The token set in the config file cannot be revoked. Edit the config file to \
133				 change it."
134			);
135		}
136
137		self.db.revoke_token(token).await
138	}
139
140	/// Iterate over all valid registration tokens.
141	pub fn iterate_tokens(&self) -> impl Stream<Item = ValidToken> + Send + '_ {
142		let config_tokens = self
143			.get_config_tokens()
144			.into_iter()
145			.map(|token| ValidToken {
146				token,
147				source: ValidTokenSource::ConfigFile,
148			})
149			.stream();
150
151		let db_tokens = self
152			.db
153			.iterate_and_clean_tokens()
154			.map(|(token, info)| ValidToken {
155				token: token.to_owned(),
156				source: ValidTokenSource::Database(info),
157			});
158
159		config_tokens.chain(db_tokens)
160	}
161}