Skip to main content

tuwunel_service/oauth/
providers.rs

1use std::collections::BTreeMap;
2
3use serde_json::{Map as JsonObject, Value as JsonValue};
4use tokio::sync::RwLock;
5pub use tuwunel_core::config::IdentityProvider as Provider;
6use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
7use url::Url;
8
9use crate::SelfServices;
10
/// Discovered providers.
///
/// Runtime cache of identity providers whose admin configuration has been
/// reconciled with discovery. Entries are populated lazily by `get` and keyed
/// by the provider's `id()`.
#[derive(Default)]
pub struct Providers {
	/// Handle to sibling services; used here for the server config and the
	/// OAuth HTTP client.
	services: SelfServices,
	/// Reconciled providers keyed by `ProviderId`; written by `get`, read by
	/// `get_cached`.
	providers: RwLock<BTreeMap<ProviderId, Provider>>,
}
17
/// Identity Provider ID; the owned form of `Provider::id()`, used as the key of
/// the discovered-provider cache.
pub type ProviderId = String;
20
21#[implement(Providers)]
22pub(super) fn build(args: &crate::Args<'_>) -> Self {
23	Self {
24		services: args.services.clone(),
25		..Default::default()
26	}
27}
28
29/// Get the Provider configuration after any discovery and adjustments
30/// made on top of the admin's configuration. This incurs network-based
31/// discovery on the first call but responds from cache on subsequent calls.
32#[implement(Providers)]
33#[tracing::instrument(level = "debug", skip(self))]
34pub async fn get(&self, id: &str) -> Result<Provider> {
35	if let Some(provider) = self.get_cached(id).await {
36		return Ok(provider);
37	}
38
39	let config = self.get_config(id)?;
40	let id = config.id().to_owned();
41	let mut map = self.providers.write().await;
42	let provider = self.configure(config).await?;
43
44	debug!(?id, ?provider);
45	_ = map.insert(id, provider.clone());
46
47	Ok(provider)
48}
49
50/// Get the admin-configured Provider which exists prior to any
51/// reconciliation with the well-known discovery (the server's config is
52/// immutable); though it is important to note the server config can be
53/// reloaded. This will Err NotFound for a non-existent idp.
54///
55/// When no provider is found with a matching client_id, providers are then
56/// searched by brand. Brand matching will be invalidated when more than one
57/// provider matches the brand.
58#[implement(Providers)]
59pub fn get_config(&self, id: &str) -> Result<Provider> {
60	let providers = &self.services.config.identity_provider;
61
62	if let Some(provider) = providers
63		.values()
64		.find(|config| config.id() == id)
65		.cloned()
66	{
67		return Ok(provider);
68	}
69
70	if let Some(provider) = providers
71		.values()
72		.find(|config| config.brand.eq_ignore_ascii_case(id))
73		.filter(|_| {
74			providers
75				.values()
76				.filter(|config| config.brand.eq_ignore_ascii_case(id))
77				.count()
78				.eq(&1)
79		})
80		.cloned()
81	{
82		return Ok(provider);
83	}
84
85	Err!(Request(NotFound("Unrecognized Identity Provider")))
86}
87
88/// Get the ID of the provider considered "default" as selected by the admin or
89/// by fallback.
90#[implement(Providers)]
91pub fn get_default_id(&self) -> Option<String> {
92	self.services
93		.config
94		.identity_provider
95		.values()
96		.find(|idp| idp.default)
97		.or_else(|| {
98			self.services
99				.config
100				.identity_provider
101				.values()
102				.next()
103		})
104		.map(Provider::id)
105		.map(ToOwned::to_owned)
106}
107
108/// Get the discovered provider from the runtime cache. ID may be client_id or
109/// brand if brand is unique among provider configurations.
110#[implement(Providers)]
111async fn get_cached(&self, id: &str) -> Option<Provider> {
112	let providers = self.providers.read().await;
113
114	if let Some(provider) = providers.get(id).cloned() {
115		return Some(provider);
116	}
117
118	providers
119		.values()
120		.find(|provider| provider.brand.eq_ignore_ascii_case(id))
121		.filter(|_| {
122			providers
123				.values()
124				.filter(|provider| provider.brand.eq_ignore_ascii_case(id))
125				.count()
126				.eq(&1)
127		})
128		.cloned()
129}
130
131/// Configure an identity provider; takes the admin-configured instance from the
132/// server's config, queries the provider for discovery, and then returns an
133/// updated config based on the proper reconciliation. This final config is then
134/// cached in memory to avoid repeating this process.
135#[implement(Providers)]
136#[tracing::instrument(
137	level = INFO_SPAN_LEVEL,
138	ret(level = "debug"),
139	skip(self),
140)]
141async fn configure(&self, mut provider: Provider) -> Result<Provider> {
142	_ = provider
143		.name
144		.get_or_insert_with(|| provider.brand.clone());
145
146	if provider.issuer_url.is_none() {
147		_ = provider
148			.issuer_url
149			.replace(match provider.brand.as_str() {
150				| "github" => "https://github.com".try_into()?,
151				| "gitlab" => "https://gitlab.com".try_into()?,
152				| "google" => "https://accounts.google.com".try_into()?,
153				| _ => return Err!(Config("issuer_url", "Required for this provider.")),
154			});
155	}
156
157	if provider.base_path.is_none() {
158		provider.base_path = match provider.brand.as_str() {
159			| "github" => Some("login/oauth/".to_owned()),
160			| _ => None,
161		};
162	}
163
164	let response = self
165		.discover(&provider)
166		.await
167		.and_then(|response| {
168			response.as_object().cloned().ok_or_else(|| {
169				err!(Request(NotJson("Expecting JSON object for discovery response")))
170			})
171		})
172		.and_then(|response| check_issuer(response, &provider))?;
173
174	if provider.authorization_url.is_none() {
175		response
176			.get("authorization_endpoint")
177			.and_then(JsonValue::as_str)
178			.map(Url::parse)
179			.transpose()?
180			.or_else(|| make_url(&provider, "authorize").ok())
181			.map(|url| provider.authorization_url.replace(url));
182	}
183
184	if provider.revocation_url.is_none() {
185		response
186			.get("revocation_endpoint")
187			.and_then(JsonValue::as_str)
188			.map(Url::parse)
189			.transpose()?
190			.or_else(|| make_url(&provider, "revocation").ok())
191			.map(|url| provider.revocation_url.replace(url));
192	}
193
194	if provider.introspection_url.is_none() {
195		response
196			.get("introspection_endpoint")
197			.and_then(JsonValue::as_str)
198			.map(Url::parse)
199			.transpose()?
200			.or_else(|| make_url(&provider, "introspection").ok())
201			.map(|url| provider.introspection_url.replace(url));
202	}
203
204	if provider.userinfo_url.is_none() {
205		response
206			.get("userinfo_endpoint")
207			.and_then(JsonValue::as_str)
208			.map(Url::parse)
209			.transpose()?
210			.or_else(|| match provider.brand.as_str() {
211				| "github" => "https://api.github.com/user".try_into().ok(),
212				| _ => make_url(&provider, "userinfo").ok(),
213			})
214			.map(|url| provider.userinfo_url.replace(url));
215	}
216
217	if provider.token_url.is_none() {
218		response
219			.get("token_endpoint")
220			.and_then(JsonValue::as_str)
221			.map(Url::parse)
222			.transpose()?
223			.or_else(|| {
224				let path = if provider.brand == "github" {
225					"access_token"
226				} else {
227					"token"
228				};
229
230				make_url(&provider, path).ok()
231			})
232			.map(|url| provider.token_url.replace(url));
233	}
234
235	if provider.callback_url.is_none()
236		&& let Some(server_url) = self.services.config.well_known.client.as_ref()
237	{
238		let callback_path =
239			format!("_matrix/client/unstable/login/sso/callback/{}", provider.client_id);
240
241		provider.callback_url = Some(server_url.join(&callback_path)?);
242	}
243
244	Ok(provider)
245}
246
247/// Send a network request to a provider at the computed location of the
248/// `.well-known/openid-configuration`, returning the configuration.
249#[implement(Providers)]
250#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
251pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
252	self.services
253		.client
254		.oauth
255		.get(discovery_url(provider)?)
256		.send()
257		.await?
258		.error_for_status()?
259		.json()
260		.await
261		.map_err(Into::into)
262}
263
264/// Compute the location of the `/.well-known/openid-configuration` based on the
265/// local provider config.
266fn discovery_url(provider: &Provider) -> Result<Url> {
267	let default_url = provider
268		.discovery
269		.then(|| make_url(provider, ".well-known/openid-configuration"))
270		.transpose()?;
271
272	let Some(url) = provider
273		.discovery_url
274		.clone()
275		.filter(|_| provider.discovery)
276		.or(default_url)
277	else {
278		return Err!(Config(
279			"discovery_url",
280			"Failed to determine URL for discovery of provider {}",
281			provider.id()
282		));
283	};
284
285	Ok(url)
286}
287
288/// Validate that the locally configured `issuer_url` matches the issuer claimed
289/// in any response. todo: cryptographic validation is not yet implemented here.
290fn check_issuer(
291	response: JsonObject<String, JsonValue>,
292	provider: &Provider,
293) -> Result<JsonObject<String, JsonValue>> {
294	let expected = provider
295		.issuer_url
296		.as_ref()
297		.map(Url::as_str)
298		.map(|url| url.trim_end_matches('/'));
299
300	let responded = response
301		.get("issuer")
302		.and_then(JsonValue::as_str)
303		.map(|url| url.trim_end_matches('/'));
304
305	if expected != responded {
306		return Err!(Request(Unauthorized(
307			"Configured issuer_url {expected:?} does not match discovered {responded:?}",
308		)));
309	}
310
311	Ok(response)
312}
313
314/// Generate a full URL for a request to the idp based on the idp's derived
315/// configuration.
316fn make_url(provider: &Provider, path: &str) -> Result<Url> {
317	let mut suffix = provider.base_path.clone().unwrap_or_default();
318
319	suffix.push_str(path);
320	let issuer = provider.issuer_url.as_ref().ok_or_else(|| {
321		let id = &provider.client_id;
322		err!(Config("issuer_url", "Provider {id:?} required field"))
323	})?;
324	let issuer_path = issuer.path();
325
326	if issuer_path.ends_with('/') {
327		Ok(issuer.join(suffix.as_str())?)
328	} else {
329		let mut url = issuer.to_owned();
330		url.set_path((issuer_path.to_owned() + "/").as_str());
331		Ok(url.join(&suffix)?)
332	}
333}