Skip to main content

tuwunel_service/oauth/server/
client.rs

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
2use serde::{Deserialize, Serialize};
3use tuwunel_core::{
4	Err, Result, err, implement,
5	utils::{hash::sha256, time::now_secs},
6};
7use tuwunel_database::{Cbor, Deserialized};
8
// Upper bound on the serialized size of a single registration request. The
// dynamic client registration endpoint is reachable without authentication,
// so without this cap one oversized record could push every other client out
// of the FIFO cache.
const MAX_REGISTRATION_BYTES: usize = 4 * 1024;
12
/// Client metadata submitted to the dynamic client registration (DCR)
/// endpoint.
///
/// Field names mirror the RFC 7591 / OpenID Connect client metadata names.
/// `redirect_uris` is required; `contacts` defaults to an empty list and every
/// other field to `None` when omitted from the request.
///
/// NOTE: `register_client` derives the `client_id` from the SHA-256 of this
/// struct's JSON serialization (after `normalize`), so field order and field
/// names are part of a client's identity: reordering or renaming fields
/// changes the id produced for an otherwise identical request.
#[derive(Debug, Deserialize, Serialize)]
pub struct DcrRequest {
	/// Redirect URIs requested by the client; sorted during normalization.
	pub redirect_uris: Vec<String>,
	pub client_name: Option<String>,
	pub client_uri: Option<String>,
	pub logo_uri: Option<String>,
	/// Contact addresses; absent in the request means empty. Sorted during
	/// normalization.
	#[serde(default)]
	pub contacts: Vec<String>,
	/// When omitted, the stored registration uses `"none"`.
	pub token_endpoint_auth_method: Option<String>,
	/// When omitted, the stored registration uses `authorization_code` and
	/// `refresh_token`. Sorted during normalization when present.
	pub grant_types: Option<Vec<String>>,
	/// When omitted, the stored registration uses `code`. Sorted during
	/// normalization when present.
	pub response_types: Option<Vec<String>>,
	pub application_type: Option<String>,
	pub policy_uri: Option<String>,
	pub tos_uri: Option<String>,
	pub software_id: Option<String>,
	pub software_version: Option<String>,
}
30
/// A registered OAuth client.
///
/// Produced by `register_client`, persisted CBOR-encoded in the
/// `oidcclientid_registration` map keyed by `client_id`, and read back by
/// `get_client`. Optional request fields that have a server-side default
/// (`token_endpoint_auth_method`, `grant_types`, `response_types`) are stored
/// here already resolved.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClientRegistration {
	/// Unpadded base64url encoding of the SHA-256 of the normalized,
	/// JSON-serialized registration request.
	pub client_id: String,
	pub redirect_uris: Vec<String>,
	pub client_name: Option<String>,
	pub client_uri: Option<String>,
	pub logo_uri: Option<String>,
	pub contacts: Vec<String>,
	/// The requested method, or `"none"` when the request omitted it.
	pub token_endpoint_auth_method: String,
	/// The requested grant types, or `authorization_code` + `refresh_token`.
	pub grant_types: Vec<String>,
	/// The requested response types, or `code`.
	pub response_types: Vec<String>,
	pub application_type: Option<String>,
	pub policy_uri: Option<String>,
	pub tos_uri: Option<String>,
	pub software_id: Option<String>,
	pub software_version: Option<String>,
	/// Time of registration as returned by `now_secs()` (presumably Unix
	/// seconds — confirm).
	pub registered_at: u64,
}
49
50#[implement(super::Server)]
51pub async fn register_client(&self, request: DcrRequest) -> Result<ClientRegistration> {
52	let request = normalize(request);
53	let serialized = serde_json::to_vec(&request).expect("DcrRequest is always serializable");
54
55	if serialized.len() > MAX_REGISTRATION_BYTES {
56		return Err!(Request(TooLarge(
57			"Client registration exceeds {MAX_REGISTRATION_BYTES} byte limit"
58		)));
59	}
60
61	let client_id = b64.encode(sha256::hash(&serialized));
62
63	if let Ok(existing) = self.get_client(&client_id).await {
64		return Ok(existing);
65	}
66
67	let auth_method = request
68		.token_endpoint_auth_method
69		.unwrap_or_else(|| "none".to_owned());
70
71	let response_types = request
72		.response_types
73		.unwrap_or_else(|| vec!["code".to_owned()]);
74
75	let grant_types = request
76		.grant_types
77		.unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]);
78
79	let registration = ClientRegistration {
80		client_id,
81		redirect_uris: request.redirect_uris,
82		client_name: request.client_name,
83		client_uri: request.client_uri,
84		logo_uri: request.logo_uri,
85		contacts: request.contacts,
86		token_endpoint_auth_method: auth_method,
87		grant_types,
88		response_types,
89		application_type: request.application_type,
90		policy_uri: request.policy_uri,
91		tos_uri: request.tos_uri,
92		software_id: request.software_id,
93		software_version: request.software_version,
94		registered_at: now_secs(),
95	};
96
97	self.db
98		.oidcclientid_registration
99		.raw_put(&*registration.client_id, Cbor(&registration));
100
101	Ok(registration)
102}
103
104#[implement(super::Server)]
105pub async fn get_client(&self, client_id: &str) -> Result<ClientRegistration> {
106	self.db
107		.oidcclientid_registration
108		.get(client_id)
109		.await
110		.deserialized::<Cbor<_>>()
111		.map(|cbor: Cbor<ClientRegistration>| cbor.0)
112		.map_err(|_| err!(Request(NotFound("Unknown client_id"))))
113}
114
115fn normalize(mut request: DcrRequest) -> DcrRequest {
116	request.redirect_uris.sort();
117	request.contacts.sort();
118	request
119		.grant_types
120		.iter_mut()
121		.for_each(|v| v.sort());
122	request
123		.response_types
124		.iter_mut()
125		.for_each(|v| v.sort());
126
127	request
128}