tuwunel_service/oauth/server/
device.rs1use std::time::{Duration, SystemTime};
2
3use ruma::OwnedUserId;
4use serde::{Deserialize, Serialize};
5use tuwunel_core::{Err, Result, err, implement, utils};
6use tuwunel_database::{Cbor, Deserialized};
7
8#[derive(Clone, Debug, Deserialize, Serialize)]
12pub struct DeviceGrant {
13 pub device_code: String,
14
15 pub user_code: String,
18
19 pub client_id: String,
20 pub scope: String,
21 pub status: DeviceGrantStatus,
22 pub attempts: u32,
23 pub created_at: SystemTime,
24 pub expires_at: SystemTime,
25}
26
27#[derive(Clone, Debug, Deserialize, Serialize)]
28pub enum DeviceGrantStatus {
29 Pending,
30 Approved {
31 user_id: OwnedUserId,
32 idp_id: Option<String>,
33 },
34 Denied,
35}
36
37pub struct ApprovedDeviceGrant {
39 pub client_id: String,
40 pub scope: String,
41 pub user_id: OwnedUserId,
42 pub idp_id: Option<String>,
43}
44
45pub enum DeviceGrantPoll {
48 Pending,
49 Approved(ApprovedDeviceGrant),
50 Denied,
51 Expired,
52}
53
/// Length of the random device code returned to the client.
const DEVICE_CODE_LENGTH: usize = 64;

/// Length of the user code before display formatting inserts a separator.
const USER_CODE_LENGTH: usize = 10;

/// Alphabet for user codes: 20 uppercase consonants with no vowels (Y
/// included) and no digits — the base-20 set suggested by RFC 8628 §6.1.
const USER_CODE_CHARSET: &[u8] = b"BCDFGHJKLMNPQRSTVWXZ";

/// Number of `verify_device_grant` calls a grant tolerates; the call after
/// that discards the grant.
const MAX_VERIFY_ATTEMPTS: u32 = 10;

/// How long a grant remains usable after creation (30 minutes).
pub const DEVICE_GRANT_LIFETIME: Duration = Duration::from_secs(30 * 60);

/// Polling interval in seconds. Not referenced in this file; presumably
/// advertised to clients as the device-flow `interval` — confirm at call site.
pub const DEVICE_GRANT_INTERVAL_SECS: u64 = 5;
74
75#[implement(super::Server)]
76#[must_use]
77pub fn create_device_grant(&self, client_id: &str, scope: &str) -> DeviceGrant {
78 let now = SystemTime::now();
79 let device_code = utils::random_string(DEVICE_CODE_LENGTH);
80 let user_code = utils::random_string_from(USER_CODE_CHARSET, USER_CODE_LENGTH);
81 let grant = DeviceGrant {
82 device_code: device_code.clone(),
83 user_code: user_code.clone(),
84 client_id: client_id.to_owned(),
85 scope: scope.to_owned(),
86 status: DeviceGrantStatus::Pending,
87 attempts: 0,
88 created_at: now,
89 expires_at: now
90 .checked_add(DEVICE_GRANT_LIFETIME)
91 .unwrap_or(now),
92 };
93
94 self.db
95 .oidcdevicecode_devicegrant
96 .raw_put(&*device_code, Cbor(&grant));
97
98 self.db
99 .oidcusercode_devicecode
100 .raw_put(&*user_code, Cbor(&device_code));
101
102 grant
103}
104
105#[implement(super::Server)]
108async fn resolve_device_code(&self, user_code: &str) -> Result<String> {
109 let user_code = normalize_user_code(user_code);
110
111 self.db
112 .oidcusercode_devicecode
113 .get(&user_code)
114 .await
115 .deserialized::<Cbor<_>>()
116 .map(|cbor: Cbor<String>| cbor.0)
117 .map_err(|_| err!(Request(NotFound("Unknown or expired user code"))))
118}
119
120#[implement(super::Server)]
123pub async fn verify_device_grant(&self, user_code: &str) -> Result<DeviceGrant> {
124 let device_code = self.resolve_device_code(user_code).await?;
125 let _lock = self.device_locks.lock(&device_code).await;
126
127 let mut grant = self.get_device_grant(&device_code).await?;
128
129 if SystemTime::now() > grant.expires_at {
130 self.remove_device_grant(&grant.device_code, &grant.user_code);
131
132 return Err!(Request(NotFound("The device authorization has expired")));
133 }
134
135 if !matches!(grant.status, DeviceGrantStatus::Pending) {
136 return Err!(Request(Forbidden("The device authorization was already resolved")));
137 }
138
139 grant.attempts = grant.attempts.saturating_add(1);
140 if grant.attempts > MAX_VERIFY_ATTEMPTS {
141 self.remove_device_grant(&grant.device_code, &grant.user_code);
142
143 return Err!(Request(Forbidden("Too many attempts; request a new code")));
144 }
145
146 self.db
147 .oidcdevicecode_devicegrant
148 .raw_put(&*grant.device_code, Cbor(&grant));
149
150 Ok(grant)
151}
152
153#[implement(super::Server)]
154pub async fn approve_device_grant(
155 &self,
156 user_code: &str,
157 user_id: OwnedUserId,
158 idp_id: Option<String>,
159) -> Result {
160 self.set_device_grant_status(user_code, DeviceGrantStatus::Approved { user_id, idp_id })
161 .await
162}
163
164#[implement(super::Server)]
165pub async fn deny_device_grant(&self, user_code: &str) -> Result {
166 self.set_device_grant_status(user_code, DeviceGrantStatus::Denied)
167 .await
168}
169
170#[implement(super::Server)]
173pub async fn poll_device_grant(
174 &self,
175 device_code: &str,
176 client_id: &str,
177) -> Result<DeviceGrantPoll> {
178 let _lock = self.device_locks.lock(device_code).await;
181
182 let grant = self.get_device_grant(device_code).await?;
183
184 if grant.client_id != client_id {
185 return Err!(Request(Forbidden("client_id mismatch")));
186 }
187
188 if SystemTime::now() > grant.expires_at {
189 self.remove_device_grant(&grant.device_code, &grant.user_code);
190
191 return Ok(DeviceGrantPoll::Expired);
192 }
193
194 match grant.status {
195 | DeviceGrantStatus::Pending => Ok(DeviceGrantPoll::Pending),
196 | DeviceGrantStatus::Denied => {
197 self.remove_device_grant(&grant.device_code, &grant.user_code);
198
199 Ok(DeviceGrantPoll::Denied)
200 },
201 | DeviceGrantStatus::Approved { user_id, idp_id } => {
202 self.remove_device_grant(&grant.device_code, &grant.user_code);
203
204 Ok(DeviceGrantPoll::Approved(ApprovedDeviceGrant {
205 client_id: grant.client_id,
206 scope: grant.scope,
207 user_id,
208 idp_id,
209 }))
210 },
211 }
212}
213
214#[implement(super::Server)]
215async fn get_device_grant(&self, device_code: &str) -> Result<DeviceGrant> {
216 self.db
217 .oidcdevicecode_devicegrant
218 .get(device_code)
219 .await
220 .deserialized::<Cbor<_>>()
221 .map(|cbor: Cbor<DeviceGrant>| cbor.0)
222 .map_err(|_| err!(Request(Forbidden("Invalid or expired device code"))))
223}
224
225#[implement(super::Server)]
226async fn set_device_grant_status(&self, user_code: &str, status: DeviceGrantStatus) -> Result {
227 let device_code = self.resolve_device_code(user_code).await?;
228 let _lock = self.device_locks.lock(&device_code).await;
229
230 let mut grant = self.get_device_grant(&device_code).await?;
231
232 if SystemTime::now() > grant.expires_at {
233 self.remove_device_grant(&grant.device_code, &grant.user_code);
234
235 return Err!(Request(NotFound("The device authorization has expired")));
236 }
237
238 if !matches!(grant.status, DeviceGrantStatus::Pending) {
239 return Err!(Request(Forbidden("The device authorization was already resolved")));
240 }
241
242 grant.status = status;
243 self.db
244 .oidcdevicecode_devicegrant
245 .raw_put(&*grant.device_code, Cbor(&grant));
246
247 Ok(())
248}
249
250#[implement(super::Server)]
251fn remove_device_grant(&self, device_code: &str, user_code: &str) {
252 self.db
253 .oidcdevicecode_devicegrant
254 .remove(device_code);
255 self.db.oidcusercode_devicecode.remove(user_code);
256}
257
258fn normalize_user_code(input: &str) -> String {
261 input
262 .bytes()
263 .map(|b| b.to_ascii_uppercase())
264 .filter(|b| USER_CODE_CHARSET.contains(b))
265 .map(char::from)
266 .collect()
267}
268
/// Renders a user code for display by joining its two halves with a hyphen.
///
/// `"BCDFGHJK"` becomes `"BCDF-GHJK"`; with an odd length the extra character
/// falls in the second half. A code that cannot be split into two non-empty
/// halves, or whose midpoint is not a UTF-8 character boundary, is returned
/// unchanged.
#[must_use]
pub fn format_user_code(code: &str) -> String {
    let mid = code.len() / 2;

    match code.split_at_checked(mid) {
        | Some((head, tail)) if !head.is_empty() && !tail.is_empty() => {
            format!("{head}-{tail}")
        },
        | _ => code.to_owned(),
    }
}
277
#[cfg(test)]
mod tests {
    use super::{USER_CODE_CHARSET, USER_CODE_LENGTH, format_user_code, normalize_user_code};

    /// Formatting a code for display and normalizing it back is lossless.
    #[test]
    fn format_then_normalize_round_trips() {
        let original = "BCDFGHJK";
        let displayed = format_user_code(original);

        assert_eq!(normalize_user_code(&displayed), original);
    }

    /// Hyphens, whitespace and lowercase input collapse to the canonical form.
    #[test]
    fn normalize_strips_separators_and_uppercases() {
        for typed in ["bcdf-ghjk", " bc df ghjk "] {
            assert_eq!(normalize_user_code(typed), "BCDFGHJK");
        }
    }

    /// Digits and vowels are silently dropped rather than rejected.
    #[test]
    fn normalize_drops_out_of_charset_characters() {
        assert_eq!(normalize_user_code("B0C1DAEF"), "BCDF");
    }

    /// The display form splits the code in half with exactly one hyphen.
    #[test]
    fn format_inserts_a_single_separator() {
        assert_eq!(format_user_code("BCDFGHJK"), "BCDF-GHJK");
    }

    /// The alphabet is base-20, contains no vowels (Y included) and no digits,
    /// and codes are ten characters long.
    #[test]
    fn charset_is_base20_without_vowels_or_digits() {
        assert_eq!(USER_CODE_CHARSET.len(), 20);
        assert_eq!(USER_CODE_LENGTH, 10);

        let forbidden = b"AEIOUY0123456789";
        let leaked = forbidden
            .iter()
            .find(|&&b| USER_CODE_CHARSET.contains(&b));

        assert_eq!(leaked, None);
    }
}