Skip to main content

tuwunel_service/oauth/
providers.rs

1use std::collections::BTreeMap;
2
3use serde_json::{Map as JsonObject, Value as JsonValue};
4use tokio::sync::RwLock;
5pub use tuwunel_core::config::IdentityProvider as Provider;
6use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
7use url::Url;
8
9use crate::{SelfServices, client::read_response_capped};
10
11/// Discovered providers
12#[derive(Default)]
13pub struct Providers {
14	services: SelfServices,
15	providers: RwLock<BTreeMap<ProviderId, Provider>>,
16}
17
18/// Identity Provider ID
19pub type ProviderId = String;
20
21#[implement(Providers)]
22pub(super) fn build(args: &crate::Args<'_>) -> Self {
23	Self {
24		services: args.services.clone(),
25		..Default::default()
26	}
27}
28
29/// Get the Provider configuration after any discovery and adjustments
30/// made on top of the admin's configuration. This incurs network-based
31/// discovery on the first call but responds from cache on subsequent calls.
32#[implement(Providers)]
33#[tracing::instrument(level = "debug", skip(self))]
34pub async fn get(&self, id: &str) -> Result<Provider> {
35	if let Some(provider) = self.get_cached(id).await {
36		return Ok(provider);
37	}
38
39	let config = self.get_config(id)?;
40	let id = config.id().to_owned();
41	let mut map = self.providers.write().await;
42	let provider = self.configure(config).await?;
43
44	debug!(?id, ?provider);
45	_ = map.insert(id, provider.clone());
46
47	Ok(provider)
48}
49
50/// Get the admin-configured Provider which exists prior to any
51/// reconciliation with the well-known discovery (the server's config is
52/// immutable); though it is important to note the server config can be
53/// reloaded. This will Err NotFound for a non-existent idp.
54///
55/// When no provider is found with a matching client_id, providers are then
56/// searched by brand. Brand matching will be invalidated when more than one
57/// provider matches the brand.
58#[implement(Providers)]
59pub fn get_config(&self, id: &str) -> Result<Provider> {
60	let providers = &self.services.config.identity_provider;
61
62	if let Some(provider) = providers
63		.values()
64		.find(|config| config.id() == id)
65		.cloned()
66	{
67		return Ok(provider);
68	}
69
70	if let Some(provider) = providers
71		.values()
72		.find(|config| config.brand.eq_ignore_ascii_case(id))
73		.filter(|_| {
74			providers
75				.values()
76				.filter(|config| config.brand.eq_ignore_ascii_case(id))
77				.count()
78				.eq(&1)
79		})
80		.cloned()
81	{
82		return Ok(provider);
83	}
84
85	Err!(Request(NotFound("Unrecognized Identity Provider")))
86}
87
88/// Get the ID of the provider considered "default" as selected by the admin or
89/// by fallback.
90#[implement(Providers)]
91pub fn get_default_id(&self) -> Option<String> {
92	self.services
93		.config
94		.identity_provider
95		.values()
96		.find(|idp| idp.default)
97		.or_else(|| {
98			self.services
99				.config
100				.identity_provider
101				.values()
102				.next()
103		})
104		.map(Provider::id)
105		.map(ToOwned::to_owned)
106}
107
108/// Get the discovered provider from the runtime cache. ID may be client_id or
109/// brand if brand is unique among provider configurations.
110#[implement(Providers)]
111async fn get_cached(&self, id: &str) -> Option<Provider> {
112	let providers = self.providers.read().await;
113
114	if let Some(provider) = providers.get(id).cloned() {
115		return Some(provider);
116	}
117
118	providers
119		.values()
120		.find(|provider| provider.brand.eq_ignore_ascii_case(id))
121		.filter(|_| {
122			providers
123				.values()
124				.filter(|provider| provider.brand.eq_ignore_ascii_case(id))
125				.count()
126				.eq(&1)
127		})
128		.cloned()
129}
130
131/// Configure an identity provider; takes the admin-configured instance from the
132/// server's config, queries the provider for discovery, and then returns an
133/// updated config based on the proper reconciliation. This final config is then
134/// cached in memory to avoid repeating this process.
135#[implement(Providers)]
136#[tracing::instrument(
137	level = INFO_SPAN_LEVEL,
138	ret(level = "debug"),
139	skip(self),
140)]
141async fn configure(&self, mut provider: Provider) -> Result<Provider> {
142	_ = provider
143		.name
144		.get_or_insert_with(|| provider.brand.clone());
145
146	if provider.issuer_url.is_none() {
147		_ = provider
148			.issuer_url
149			.replace(match provider.brand.as_str() {
150				| "github" => "https://github.com/login/oauth".try_into()?,
151				| "gitlab" => "https://gitlab.com".try_into()?,
152				| "google" => "https://accounts.google.com".try_into()?,
153				| _ => return Err!(Config("issuer_url", "Required for this provider.")),
154			});
155	}
156
157	let response = self
158		.discover(&provider)
159		.await
160		.and_then(|response| {
161			response.as_object().cloned().ok_or_else(|| {
162				err!(Request(NotJson("Expecting JSON object for discovery response")))
163			})
164		})
165		.and_then(|response| check_issuer(response, &provider))?;
166
167	if provider.authorization_url.is_none() {
168		response
169			.get("authorization_endpoint")
170			.and_then(JsonValue::as_str)
171			.map(Url::parse)
172			.transpose()?
173			.or_else(|| make_url(&provider, "authorize").ok())
174			.map(|url| provider.authorization_url.replace(url));
175	}
176
177	if provider.revocation_url.is_none() {
178		response
179			.get("revocation_endpoint")
180			.and_then(JsonValue::as_str)
181			.map(Url::parse)
182			.transpose()?
183			.or_else(|| make_url(&provider, "revocation").ok())
184			.map(|url| provider.revocation_url.replace(url));
185	}
186
187	if provider.introspection_url.is_none() {
188		response
189			.get("introspection_endpoint")
190			.and_then(JsonValue::as_str)
191			.map(Url::parse)
192			.transpose()?
193			.or_else(|| make_url(&provider, "introspection").ok())
194			.map(|url| provider.introspection_url.replace(url));
195	}
196
197	if provider.userinfo_url.is_none() {
198		response
199			.get("userinfo_endpoint")
200			.and_then(JsonValue::as_str)
201			.map(Url::parse)
202			.transpose()?
203			.or_else(|| match provider.brand.as_str() {
204				| "github" => "https://api.github.com/user".try_into().ok(),
205				| _ => make_url(&provider, "userinfo").ok(),
206			})
207			.map(|url| provider.userinfo_url.replace(url));
208	}
209
210	if provider.token_url.is_none() {
211		response
212			.get("token_endpoint")
213			.and_then(JsonValue::as_str)
214			.map(Url::parse)
215			.transpose()?
216			.or_else(|| {
217				let path = if provider.brand == "github" {
218					"access_token"
219				} else {
220					"token"
221				};
222
223				make_url(&provider, path).ok()
224			})
225			.map(|url| provider.token_url.replace(url));
226	}
227
228	if provider.callback_url.is_none()
229		&& let Some(server_url) = self.services.config.well_known.client.as_ref()
230	{
231		let callback_path =
232			format!("_matrix/client/unstable/login/sso/callback/{}", provider.client_id);
233
234		provider.callback_url = Some(server_url.join(&callback_path)?);
235	}
236
237	Ok(provider)
238}
239
240/// Send a network request to a provider at the computed location of the
241/// `.well-known/openid-configuration`, returning the configuration.
242#[implement(Providers)]
243#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
244pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
245	let limit = self.services.config.max_response_size;
246	let response = self
247		.services
248		.client
249		.oauth
250		.get(discovery_url(provider)?)
251		.send()
252		.await?
253		.error_for_status()?;
254
255	let body = read_response_capped(response, limit).await?;
256
257	serde_json::from_slice(&body).map_err(Into::into)
258}
259
260/// Compute the location of the `/.well-known/openid-configuration` based on the
261/// local provider config.
262fn discovery_url(provider: &Provider) -> Result<Url> {
263	let default_url = provider
264		.discovery
265		.then(|| make_url(provider, ".well-known/openid-configuration"))
266		.transpose()?;
267
268	let Some(url) = provider
269		.discovery_url
270		.clone()
271		.filter(|_| provider.discovery)
272		.or(default_url)
273	else {
274		return Err!(Config(
275			"discovery_url",
276			"Failed to determine URL for discovery of provider {}",
277			provider.id()
278		));
279	};
280
281	Ok(url)
282}
283
284/// Validate that the locally configured `issuer_url` matches the issuer claimed
285/// in any response. todo: cryptographic validation is not yet implemented here.
286fn check_issuer(
287	response: JsonObject<String, JsonValue>,
288	provider: &Provider,
289) -> Result<JsonObject<String, JsonValue>> {
290	let expected = provider
291		.issuer_url
292		.as_ref()
293		.map(Url::as_str)
294		.map(|url| url.trim_end_matches('/'));
295
296	let responded = response
297		.get("issuer")
298		.and_then(JsonValue::as_str)
299		.map(|url| url.trim_end_matches('/'));
300
301	if expected != responded {
302		return Err!(Request(Unauthorized(
303			"Configured issuer_url {expected:?} does not match discovered {responded:?}",
304		)));
305	}
306
307	Ok(response)
308}
309
310/// Generate a full URL for a request to the idp based on the idp's derived
311/// configuration.
312fn make_url(provider: &Provider, path: &str) -> Result<Url> {
313	let mut suffix = provider.base_path.clone().unwrap_or_default();
314
315	suffix.push_str(path);
316	let issuer = provider.issuer_url.as_ref().ok_or_else(|| {
317		let id = &provider.client_id;
318		err!(Config("issuer_url", "Provider {id:?} required field"))
319	})?;
320	let issuer_path = issuer.path();
321
322	if issuer_path.ends_with('/') {
323		Ok(issuer.join(suffix.as_str())?)
324	} else {
325		let mut url = issuer.to_owned();
326		url.set_path((issuer_path.to_owned() + "/").as_str());
327		Ok(url.join(&suffix)?)
328	}
329}