// tuwunel_service/oauth/server/client.rs
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use serde::{Deserialize, Serialize};
use tuwunel_core::{
	Err, Result, err, implement,
	utils::{hash::sha256, time::now_secs},
};
use tuwunel_database::{Cbor, Deserialized};

/// Largest accepted size, in bytes, of a normalized registration request once
/// serialized to JSON; anything bigger is rejected by `register_client`.
const MAX_REGISTRATION_BYTES: usize = 4096;

/// Grant types kept from a registration request; any other value is dropped
/// during normalization.
const KNOWN_GRANT_TYPES: [&str; 2] = ["authorization_code", "refresh_token"];

/// Response types kept from a registration request; any other value is
/// dropped during normalization.
const KNOWN_RESPONSE_TYPES: [&str; 1] = ["code"];

18#[derive(Debug, Deserialize, Serialize)]
19pub struct DcrRequest {
20 pub redirect_uris: Vec<String>,
21 pub client_name: Option<String>,
22 pub client_uri: Option<String>,
23 pub logo_uri: Option<String>,
24 #[serde(default)]
25 pub contacts: Vec<String>,
26 pub token_endpoint_auth_method: Option<String>,
27 pub grant_types: Option<Vec<String>>,
28 pub response_types: Option<Vec<String>>,
29 pub application_type: Option<String>,
30 pub policy_uri: Option<String>,
31 pub tos_uri: Option<String>,
32 pub software_id: Option<String>,
33 pub software_version: Option<String>,
34}
35
36#[derive(Clone, Debug, Deserialize, Serialize)]
37pub struct ClientRegistration {
38 pub client_id: String,
39 pub redirect_uris: Vec<String>,
40 pub client_name: Option<String>,
41 pub client_uri: Option<String>,
42 pub logo_uri: Option<String>,
43 pub contacts: Vec<String>,
44 pub token_endpoint_auth_method: String,
45 pub grant_types: Vec<String>,
46 pub response_types: Vec<String>,
47 pub application_type: Option<String>,
48 pub policy_uri: Option<String>,
49 pub tos_uri: Option<String>,
50 pub software_id: Option<String>,
51 pub software_version: Option<String>,
52 pub registered_at: u64,
53}
54
55#[implement(super::Server)]
56pub async fn register_client(&self, request: DcrRequest) -> Result<ClientRegistration> {
57 let request = normalize(request);
58 let serialized = serde_json::to_vec(&request).expect("DcrRequest is always serializable");
59
60 if serialized.len() > MAX_REGISTRATION_BYTES {
61 return Err!(Request(TooLarge(
62 "Client registration exceeds {MAX_REGISTRATION_BYTES} byte limit"
63 )));
64 }
65
66 let client_id = b64.encode(sha256::hash(&serialized));
67
68 if let Ok(existing) = self.get_client(&client_id).await {
69 return Ok(existing);
70 }
71
72 let auth_method = request
73 .token_endpoint_auth_method
74 .unwrap_or_else(|| "none".to_owned());
75
76 let response_types = request
77 .response_types
78 .unwrap_or_else(|| vec!["code".to_owned()]);
79
80 let grant_types = request
81 .grant_types
82 .unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]);
83
84 let registration = ClientRegistration {
85 client_id,
86 redirect_uris: request.redirect_uris,
87 client_name: request.client_name,
88 client_uri: request.client_uri,
89 logo_uri: request.logo_uri,
90 contacts: request.contacts,
91 token_endpoint_auth_method: auth_method,
92 grant_types,
93 response_types,
94 application_type: request.application_type,
95 policy_uri: request.policy_uri,
96 tos_uri: request.tos_uri,
97 software_id: request.software_id,
98 software_version: request.software_version,
99 registered_at: now_secs(),
100 };
101
102 self.db
103 .oidcclientid_registration
104 .raw_put(&*registration.client_id, Cbor(®istration));
105
106 Ok(registration)
107}
108
109#[implement(super::Server)]
110pub async fn get_client(&self, client_id: &str) -> Result<ClientRegistration> {
111 self.db
112 .oidcclientid_registration
113 .get(client_id)
114 .await
115 .deserialized::<Cbor<_>>()
116 .map(|cbor: Cbor<ClientRegistration>| cbor.0)
117 .map_err(|_| err!(Request(NotFound("Unknown client_id"))))
118}
119
120fn normalize(mut request: DcrRequest) -> DcrRequest {
121 request.redirect_uris.sort();
122 request.contacts.sort();
123 prune(&mut request.grant_types, &KNOWN_GRANT_TYPES);
124 prune(&mut request.response_types, &KNOWN_RESPONSE_TYPES);
125
126 request
127}
128
/// Restricts `types` to the values listed in `known`, then sorts and
/// de-duplicates what remains so equivalent requests serialize identically.
///
/// `None` (the field was omitted) is left untouched so the caller can apply
/// its defaults. `Some` may become an empty list if every entry is unknown.
fn prune(types: &mut Option<Vec<String>>, known: &[&str]) {
	let Some(types) = types else {
		return;
	};

	types.retain(|ty| known.contains(&ty.as_str()));
	types.sort_unstable();
	// Repeats would otherwise change the hashed serialization (and so the
	// derived client_id) without changing meaning.
	types.dedup();
}