tuwunel_service/threepid/
pending.rs1use std::time::{Duration, SystemTime};
2
3use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
4use ruma::thirdparty::Medium;
5use serde::{Deserialize, Serialize};
6use subtle::ConstantTimeEq;
7use tuwunel_core::{
8 Err, Result, implement,
9 utils::{self, hash::sha256},
10};
11use tuwunel_database::{Cbor, Deserialized};
12
13use super::Association;
14
/// Length passed to `utils::random_string` when minting a verification token
/// for a new (or re-sent) pending session.
const TOKEN_LENGTH: usize = 48;

/// Number of failed verification attempts after which a pending session is
/// deleted outright (see `validate_pending_token`).
const MAX_VERIFY_ATTEMPTS: u32 = 5;
22
/// A pending third-party-identifier verification session, stored CBOR-encoded
/// in the `threepidsid_pending` map under its session id.
///
/// NOTE(review): this derives `Debug`, so formatting a `Pending` with `{:?}`
/// would include `client_secret` and `token`; avoid logging it.
/// NOTE(review): records are persisted through serde; renaming or removing a
/// field likely breaks decoding of sessions already stored — confirm first.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Pending {
	/// Client-chosen secret; checked with `ct_eq` when the session is
	/// validated, consumed, or queried.
	client_secret: String,
	/// Kind of third-party identifier being verified.
	medium: Medium,
	/// The third-party identifier being verified.
	address: String,
	/// Random verification token the user must present to validate.
	token: String,
	/// `send_attempt` of the request that minted the current token; requests
	/// with a `send_attempt` not greater than this reuse the session.
	send_attempt: u64,
	/// Failed verification attempts so far; the session is deleted once this
	/// reaches `MAX_VERIFY_ATTEMPTS`.
	attempts: u32,
	/// Time the token was successfully validated; `None` while unvalidated.
	validated_at: Option<SystemTime>,
	/// Expiry time; `None` means the session never expires (this also happens
	/// when `now + ttl` overflows at creation).
	expires_at: Option<SystemTime>,
}
37
/// Outcome of `create_or_reuse_pending`.
#[derive(Clone, Debug)]
pub struct PendingOutcome {
	/// Session id of the pending verification (derived from medium, address
	/// and client secret).
	pub sid: String,
	/// The newly generated token when the session was (re)created; `None` when
	/// an existing session was reused and no token was generated.
	pub freshly_minted_token: Option<String>,
}
46
47#[implement(super::Service)]
53#[tracing::instrument(level = "debug", skip(self, client_secret))]
54pub async fn create_or_reuse_pending(
55 &self,
56 client_secret: &str,
57 medium: Medium,
58 address: &str,
59 send_attempt: u64,
60 ttl: Duration,
61) -> Result<PendingOutcome> {
62 let sid = derive_sid(&medium, address, client_secret);
63
64 if let Ok(existing) = self.get_pending(&sid).await
65 && existing.validated_at.is_none()
66 && send_attempt <= existing.send_attempt
67 {
68 return Ok(PendingOutcome { sid, freshly_minted_token: None });
69 }
70
71 let token = utils::random_string(TOKEN_LENGTH);
72 let expires_at = SystemTime::now().checked_add(ttl);
73 let pending = Pending {
74 client_secret: client_secret.to_owned(),
75 medium,
76 address: address.to_owned(),
77 token: token.clone(),
78 send_attempt,
79 attempts: 0,
80 validated_at: None,
81 expires_at,
82 };
83
84 self.persist_pending(&sid, &pending);
85
86 Ok(PendingOutcome { sid, freshly_minted_token: Some(token) })
87}
88
89#[implement(super::Service)]
94#[tracing::instrument(level = "debug", skip(self, client_secret, token))]
95pub async fn validate_pending_token(
96 &self,
97 sid: &str,
98 client_secret: &str,
99 token: &str,
100) -> Result<()> {
101 let pending = self.get_pending(sid).await?;
102
103 if expired(&pending) {
104 self.delete_pending(sid);
105
106 return Err!(Request(NotFound("The verification session has expired")));
107 }
108
109 let secret_ok = ct_eq(&pending.client_secret, client_secret);
110 let token_ok = ct_eq(&pending.token, token);
111
112 if !secret_ok || !token_ok {
113 let attempts = pending.attempts.saturating_add(1);
114 match attempts >= MAX_VERIFY_ATTEMPTS {
115 | true => self.delete_pending(sid),
116 | false => self.persist_pending(sid, &Pending { attempts, ..pending }),
117 }
118
119 return Err!(Request(ThreepidAuthFailed("Invalid verification token")));
120 }
121
122 let validated_at = Some(SystemTime::now());
123 self.persist_pending(sid, &Pending { validated_at, ..pending });
124
125 Ok(())
126}
127
128#[implement(super::Service)]
132#[tracing::instrument(level = "debug", skip(self, client_secret))]
133pub async fn consume_validated(&self, sid: &str, client_secret: &str) -> Result<Association> {
134 let pending = self.get_pending(sid).await?;
135
136 if expired(&pending) {
137 self.delete_pending(sid);
138
139 return Err!(Request(NotFound("The verification session has expired")));
140 }
141
142 if !ct_eq(&pending.client_secret, client_secret) {
143 return Err!(Request(ThreepidAuthFailed("Client secret does not match")));
144 }
145
146 if pending.validated_at.is_none() {
147 return Err!(Request(ThreepidAuthFailed("The address has not been validated")));
148 }
149
150 self.delete_pending(sid);
151
152 Ok(Association {
153 medium: pending.medium,
154 address: pending.address,
155 })
156}
157
158#[implement(super::Service)]
163#[tracing::instrument(level = "debug", skip(self, client_secret))]
164pub async fn session_validated(&self, sid: &str, client_secret: &str) -> bool {
165 let Ok(pending) = self.get_pending(sid).await else {
166 return false;
167 };
168
169 !expired(&pending)
170 && ct_eq(&pending.client_secret, client_secret)
171 && pending.validated_at.is_some()
172}
173
174#[implement(super::Service)]
175fn persist_pending(&self, sid: &str, pending: &Pending) {
176 self.db
177 .threepidsid_pending
178 .raw_put(sid, Cbor(pending));
179}
180
181#[implement(super::Service)]
183#[tracing::instrument(level = "debug", skip(self))]
184pub fn delete_pending(&self, sid: &str) { self.db.threepidsid_pending.remove(sid); }
185
186#[implement(super::Service)]
187async fn get_pending(&self, sid: &str) -> Result<Pending> {
188 self.db
189 .threepidsid_pending
190 .get(sid)
191 .await
192 .deserialized::<Cbor<_>>()
193 .map(|Cbor(pending)| pending)
194}
195
196fn derive_sid(medium: &Medium, address: &str, client_secret: &str) -> String {
198 let parts = [medium.as_str().as_bytes(), address.as_bytes(), client_secret.as_bytes()];
199 let digest = sha256::delimited(parts.into_iter());
200
201 b64encode.encode(digest)
202}
203
204fn expired(pending: &Pending) -> bool {
205 pending
206 .expires_at
207 .is_some_and(|expires_at| SystemTime::now() > expires_at)
208}
209
210fn ct_eq(a: &str, b: &str) -> bool { a.as_bytes().ct_eq(b.as_bytes()).into() }