// tuwunel_service/oauth/server/client.rs

use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use serde::{Deserialize, Serialize};
use tuwunel_core::{
	Err, Result, err, implement,
	utils::{hash::sha256, time::now_secs},
};
use tuwunel_database::{Cbor, Deserialized};

/// Largest accepted size, in bytes, of a registration request once it has been
/// normalized and serialized to JSON; larger requests are rejected by
/// `register_client` with a `TooLarge` error.
const MAX_REGISTRATION_BYTES: usize = 4 * 1024;

13#[derive(Debug, Deserialize, Serialize)]
14pub struct DcrRequest {
15 pub redirect_uris: Vec<String>,
16 pub client_name: Option<String>,
17 pub client_uri: Option<String>,
18 pub logo_uri: Option<String>,
19 #[serde(default)]
20 pub contacts: Vec<String>,
21 pub token_endpoint_auth_method: Option<String>,
22 pub grant_types: Option<Vec<String>>,
23 pub response_types: Option<Vec<String>>,
24 pub application_type: Option<String>,
25 pub policy_uri: Option<String>,
26 pub tos_uri: Option<String>,
27 pub software_id: Option<String>,
28 pub software_version: Option<String>,
29}
30
31#[derive(Clone, Debug, Deserialize, Serialize)]
32pub struct ClientRegistration {
33 pub client_id: String,
34 pub redirect_uris: Vec<String>,
35 pub client_name: Option<String>,
36 pub client_uri: Option<String>,
37 pub logo_uri: Option<String>,
38 pub contacts: Vec<String>,
39 pub token_endpoint_auth_method: String,
40 pub grant_types: Vec<String>,
41 pub response_types: Vec<String>,
42 pub application_type: Option<String>,
43 pub policy_uri: Option<String>,
44 pub tos_uri: Option<String>,
45 pub software_id: Option<String>,
46 pub software_version: Option<String>,
47 pub registered_at: u64,
48}
49
50#[implement(super::Server)]
51pub async fn register_client(&self, request: DcrRequest) -> Result<ClientRegistration> {
52 let request = normalize(request);
53 let serialized = serde_json::to_vec(&request).expect("DcrRequest is always serializable");
54
55 if serialized.len() > MAX_REGISTRATION_BYTES {
56 return Err!(Request(TooLarge(
57 "Client registration exceeds {MAX_REGISTRATION_BYTES} byte limit"
58 )));
59 }
60
61 let client_id = b64.encode(sha256::hash(&serialized));
62
63 if let Ok(existing) = self.get_client(&client_id).await {
64 return Ok(existing);
65 }
66
67 let auth_method = request
68 .token_endpoint_auth_method
69 .unwrap_or_else(|| "none".to_owned());
70
71 let response_types = request
72 .response_types
73 .unwrap_or_else(|| vec!["code".to_owned()]);
74
75 let grant_types = request
76 .grant_types
77 .unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]);
78
79 let registration = ClientRegistration {
80 client_id,
81 redirect_uris: request.redirect_uris,
82 client_name: request.client_name,
83 client_uri: request.client_uri,
84 logo_uri: request.logo_uri,
85 contacts: request.contacts,
86 token_endpoint_auth_method: auth_method,
87 grant_types,
88 response_types,
89 application_type: request.application_type,
90 policy_uri: request.policy_uri,
91 tos_uri: request.tos_uri,
92 software_id: request.software_id,
93 software_version: request.software_version,
94 registered_at: now_secs(),
95 };
96
97 self.db
98 .oidcclientid_registration
99 .raw_put(&*registration.client_id, Cbor(®istration));
100
101 Ok(registration)
102}
103
104#[implement(super::Server)]
105pub async fn get_client(&self, client_id: &str) -> Result<ClientRegistration> {
106 self.db
107 .oidcclientid_registration
108 .get(client_id)
109 .await
110 .deserialized::<Cbor<_>>()
111 .map(|cbor: Cbor<ClientRegistration>| cbor.0)
112 .map_err(|_| err!(Request(NotFound("Unknown client_id"))))
113}
114
115fn normalize(mut request: DcrRequest) -> DcrRequest {
116 request.redirect_uris.sort();
117 request.contacts.sort();
118 request
119 .grant_types
120 .iter_mut()
121 .for_each(|v| v.sort());
122 request
123 .response_types
124 .iter_mut()
125 .for_each(|v| v.sort());
126
127 request
128}