//! `tuwunel_service/oauth/server/client.rs`: OAuth dynamic client registration
//! (DCR) storage and lookup.

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
2use serde::{Deserialize, Serialize};
3use tuwunel_core::{
4	Err, Result, err, implement,
5	utils::{hash::sha256, time::now_secs},
6};
7use tuwunel_database::{Cbor, Deserialized};
8
/// Upper bound, in bytes, on the JSON-serialized (normalized) registration
/// request.
///
/// Bounds the per-row footprint so an unauthenticated DCR endpoint cannot
/// evict every other client from the FIFO cache with one huge record.
const MAX_REGISTRATION_BYTES: usize = 4 * 1024;
12
/// Grant types the server understands. MSC2966 requires dropping any others
/// from a registration before it is stored and echoed.
const KNOWN_GRANT_TYPES: [&str; 2] = [
	"authorization_code",
	"refresh_token",
];

/// Response types the server understands; anything else is dropped for the
/// same MSC2966 reason as the unknown grant types.
const KNOWN_RESPONSE_TYPES: [&str; 1] = ["code"];
17
18#[derive(Debug, Deserialize, Serialize)]
19pub struct DcrRequest {
20	pub redirect_uris: Vec<String>,
21	pub client_name: Option<String>,
22	pub client_uri: Option<String>,
23	pub logo_uri: Option<String>,
24	#[serde(default)]
25	pub contacts: Vec<String>,
26	pub token_endpoint_auth_method: Option<String>,
27	pub grant_types: Option<Vec<String>>,
28	pub response_types: Option<Vec<String>>,
29	pub application_type: Option<String>,
30	pub policy_uri: Option<String>,
31	pub tos_uri: Option<String>,
32	pub software_id: Option<String>,
33	pub software_version: Option<String>,
34}
35
36#[derive(Clone, Debug, Deserialize, Serialize)]
37pub struct ClientRegistration {
38	pub client_id: String,
39	pub redirect_uris: Vec<String>,
40	pub client_name: Option<String>,
41	pub client_uri: Option<String>,
42	pub logo_uri: Option<String>,
43	pub contacts: Vec<String>,
44	pub token_endpoint_auth_method: String,
45	pub grant_types: Vec<String>,
46	pub response_types: Vec<String>,
47	pub application_type: Option<String>,
48	pub policy_uri: Option<String>,
49	pub tos_uri: Option<String>,
50	pub software_id: Option<String>,
51	pub software_version: Option<String>,
52	pub registered_at: u64,
53}
54
55#[implement(super::Server)]
56pub async fn register_client(&self, request: DcrRequest) -> Result<ClientRegistration> {
57	let request = normalize(request);
58	let serialized = serde_json::to_vec(&request).expect("DcrRequest is always serializable");
59
60	if serialized.len() > MAX_REGISTRATION_BYTES {
61		return Err!(Request(TooLarge(
62			"Client registration exceeds {MAX_REGISTRATION_BYTES} byte limit"
63		)));
64	}
65
66	let client_id = b64.encode(sha256::hash(&serialized));
67
68	if let Ok(existing) = self.get_client(&client_id).await {
69		return Ok(existing);
70	}
71
72	let auth_method = request
73		.token_endpoint_auth_method
74		.unwrap_or_else(|| "none".to_owned());
75
76	let response_types = request
77		.response_types
78		.unwrap_or_else(|| vec!["code".to_owned()]);
79
80	let grant_types = request
81		.grant_types
82		.unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]);
83
84	let registration = ClientRegistration {
85		client_id,
86		redirect_uris: request.redirect_uris,
87		client_name: request.client_name,
88		client_uri: request.client_uri,
89		logo_uri: request.logo_uri,
90		contacts: request.contacts,
91		token_endpoint_auth_method: auth_method,
92		grant_types,
93		response_types,
94		application_type: request.application_type,
95		policy_uri: request.policy_uri,
96		tos_uri: request.tos_uri,
97		software_id: request.software_id,
98		software_version: request.software_version,
99		registered_at: now_secs(),
100	};
101
102	self.db
103		.oidcclientid_registration
104		.raw_put(&*registration.client_id, Cbor(&registration));
105
106	Ok(registration)
107}
108
109#[implement(super::Server)]
110pub async fn get_client(&self, client_id: &str) -> Result<ClientRegistration> {
111	self.db
112		.oidcclientid_registration
113		.get(client_id)
114		.await
115		.deserialized::<Cbor<_>>()
116		.map(|cbor: Cbor<ClientRegistration>| cbor.0)
117		.map_err(|_| err!(Request(NotFound("Unknown client_id"))))
118}
119
120fn normalize(mut request: DcrRequest) -> DcrRequest {
121	request.redirect_uris.sort();
122	request.contacts.sort();
123	prune(&mut request.grant_types, &KNOWN_GRANT_TYPES);
124	prune(&mut request.response_types, &KNOWN_RESPONSE_TYPES);
125
126	request
127}
128
/// Drops every entry of `types` that is not listed in `known`, then sorts and
/// de-duplicates what remains.
///
/// `None` (the field was omitted) is left untouched so the caller can still
/// apply its defaults. If every supplied value is unknown, the result is
/// `Some` of an empty list.
fn prune(types: &mut Option<Vec<String>>, known: &[&str]) {
	let Some(types) = types else {
		return;
	};

	types.retain(|ty| known.contains(&ty.as_str()));
	types.sort_unstable();
	// Repeated values are meaningless. Collapsing them keeps equivalent
	// requests canonical for hashing and avoids echoing duplicates back.
	types.dedup();
}