Skip to main content

tuwunel_service/registration_tokens/
data.rs

1use std::{sync::Arc, time::SystemTime};
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5use tuwunel_core::{
6	Err, Result,
7	utils::{
8		self,
9		stream::{ReadyExt, TryIgnore},
10	},
11};
12use tuwunel_database::{Database, Deserialized, Json, Map};
13
14pub(super) struct Data {
15	registrationtoken_info: Arc<Map>,
16}
17
18/// Metadata of a registration token.
19#[derive(Debug, Serialize, Deserialize)]
20pub struct DatabaseTokenInfo {
21	/// The number of times this token has been used to create an account.
22	pub uses: u64,
23	/// When this token will expire, if it expires.
24	pub expires: TokenExpires,
25}
26
27impl DatabaseTokenInfo {
28	pub(super) fn new(expires: TokenExpires) -> Self { Self { uses: 0, expires } }
29
30	/// Determine whether this token info represents a valid token, i.e. one
31	/// that has not exhausted its `max_uses` or passed its `max_age`. When
32	/// both `expires.max_uses` and `expires.max_age` are `None`, this always
33	/// returns `true`.
34	#[must_use]
35	pub fn is_valid(&self) -> bool {
36		if let Some(max_uses) = self.expires.max_uses
37			&& self.uses >= max_uses
38		{
39			return false;
40		}
41
42		if let Some(max_age) = self.expires.max_age {
43			let now = SystemTime::now();
44
45			if now > max_age {
46				return false;
47			}
48		}
49
50		true
51	}
52}
53
54impl std::fmt::Display for DatabaseTokenInfo {
55	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56		write!(f, "Token used {} times. {}", self.uses, self.expires)?;
57
58		Ok(())
59	}
60}
61
62#[derive(Debug, Serialize, Deserialize)]
63pub struct TokenExpires {
64	pub max_uses: Option<u64>,
65	pub max_age: Option<SystemTime>,
66}
67
68impl std::fmt::Display for TokenExpires {
69	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70		let mut msgs = vec![];
71
72		if let Some(max_uses) = self.max_uses {
73			msgs.push(format!("after {max_uses} uses"));
74		}
75
76		if let Some(max_age) = self.max_age {
77			let now = SystemTime::now();
78			let expires_at = utils::time::format(max_age, "%F %T");
79
80			match max_age.duration_since(now) {
81				| Ok(duration) => {
82					let expires_in = utils::time::pretty(duration);
83					msgs.push(format!("in {expires_in} ({expires_at})"));
84				},
85				| Err(_) => {
86					write!(f, "Expired at {expires_at}")?;
87					return Ok(());
88				},
89			}
90		}
91
92		if !msgs.is_empty() {
93			write!(f, "Expires {}.", msgs.join(" or "))?;
94		} else {
95			write!(f, "Never expires.")?;
96		}
97
98		Ok(())
99	}
100}
101
102impl Data {
103	pub(super) fn new(db: &Arc<Database>) -> Self {
104		Self {
105			registrationtoken_info: db["registrationtoken_info"].clone(),
106		}
107	}
108
109	/// Associate a registration token with its metadata in the database.
110	pub(super) async fn save_token(
111		&self,
112		token: &str,
113		expires: TokenExpires,
114	) -> Result<DatabaseTokenInfo> {
115		if self
116			.registrationtoken_info
117			.exists(token)
118			.await
119			.is_err()
120		{
121			let info = DatabaseTokenInfo::new(expires);
122
123			self.registrationtoken_info
124				.raw_put(token, Json(&info));
125
126			Ok(info)
127		} else {
128			Err!(Request(InvalidParam("Registration token already exists")))
129		}
130	}
131
132	/// Delete a registration token.
133	pub(super) async fn revoke_token(&self, token: &str) -> Result {
134		if self
135			.registrationtoken_info
136			.exists(token)
137			.await
138			.is_ok()
139		{
140			self.registrationtoken_info.remove(token);
141
142			Ok(())
143		} else {
144			Err!(Request(NotFound("Registration token not found")))
145		}
146	}
147
148	/// Look up a registration token's metadata.
149	pub(super) async fn check_token(&self, token: &str, consume: bool) -> bool {
150		let info = self
151			.registrationtoken_info
152			.get(token)
153			.await
154			.deserialized::<DatabaseTokenInfo>()
155			.ok();
156
157		info.map(|mut info| {
158			if !info.is_valid() {
159				self.registrationtoken_info.remove(token);
160				return false;
161			}
162
163			if consume {
164				info.uses = info.uses.saturating_add(1);
165
166				if info.is_valid() {
167					self.registrationtoken_info
168						.raw_put(token, Json(info));
169				} else {
170					self.registrationtoken_info.remove(token);
171				}
172			}
173
174			true
175		})
176		.unwrap_or(false)
177	}
178
179	/// Iterate over all valid tokens and delete expired ones.
180	pub(super) fn iterate_and_clean_tokens(
181		&self,
182	) -> impl Stream<Item = (&str, DatabaseTokenInfo)> + Send + '_ {
183		self.registrationtoken_info
184			.stream()
185			.ignore_err()
186			.ready_filter_map(|(token, info): (&str, DatabaseTokenInfo)| {
187				if info.is_valid() {
188					Some((token, info))
189				} else {
190					self.registrationtoken_info.remove(token);
191					None
192				}
193			})
194	}
195}