tuwunel_service/oauth/
providers.rs1use std::collections::BTreeMap;
2
3use serde_json::{Map as JsonObject, Value as JsonValue};
4use tokio::sync::RwLock;
5pub use tuwunel_core::config::IdentityProvider as Provider;
6use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
7use url::Url;
8
9use crate::SelfServices;
10
11#[derive(Default)]
13pub struct Providers {
14 services: SelfServices,
15 providers: RwLock<BTreeMap<ProviderId, Provider>>,
16}
17
18pub type ProviderId = String;
20
21#[implement(Providers)]
22pub(super) fn build(args: &crate::Args<'_>) -> Self {
23 Self {
24 services: args.services.clone(),
25 ..Default::default()
26 }
27}
28
29#[implement(Providers)]
33#[tracing::instrument(level = "debug", skip(self))]
34pub async fn get(&self, id: &str) -> Result<Provider> {
35 if let Some(provider) = self.get_cached(id).await {
36 return Ok(provider);
37 }
38
39 let config = self.get_config(id)?;
40 let id = config.id().to_owned();
41 let mut map = self.providers.write().await;
42 let provider = self.configure(config).await?;
43
44 debug!(?id, ?provider);
45 _ = map.insert(id, provider.clone());
46
47 Ok(provider)
48}
49
50#[implement(Providers)]
59pub fn get_config(&self, id: &str) -> Result<Provider> {
60 let providers = &self.services.config.identity_provider;
61
62 if let Some(provider) = providers
63 .values()
64 .find(|config| config.id() == id)
65 .cloned()
66 {
67 return Ok(provider);
68 }
69
70 if let Some(provider) = providers
71 .values()
72 .find(|config| config.brand.eq_ignore_ascii_case(id))
73 .filter(|_| {
74 providers
75 .values()
76 .filter(|config| config.brand.eq_ignore_ascii_case(id))
77 .count()
78 .eq(&1)
79 })
80 .cloned()
81 {
82 return Ok(provider);
83 }
84
85 Err!(Request(NotFound("Unrecognized Identity Provider")))
86}
87
88#[implement(Providers)]
91pub fn get_default_id(&self) -> Option<String> {
92 self.services
93 .config
94 .identity_provider
95 .values()
96 .find(|idp| idp.default)
97 .or_else(|| {
98 self.services
99 .config
100 .identity_provider
101 .values()
102 .next()
103 })
104 .map(Provider::id)
105 .map(ToOwned::to_owned)
106}
107
108#[implement(Providers)]
111async fn get_cached(&self, id: &str) -> Option<Provider> {
112 let providers = self.providers.read().await;
113
114 if let Some(provider) = providers.get(id).cloned() {
115 return Some(provider);
116 }
117
118 providers
119 .values()
120 .find(|provider| provider.brand.eq_ignore_ascii_case(id))
121 .filter(|_| {
122 providers
123 .values()
124 .filter(|provider| provider.brand.eq_ignore_ascii_case(id))
125 .count()
126 .eq(&1)
127 })
128 .cloned()
129}
130
131#[implement(Providers)]
136#[tracing::instrument(
137 level = INFO_SPAN_LEVEL,
138 ret(level = "debug"),
139 skip(self),
140)]
141async fn configure(&self, mut provider: Provider) -> Result<Provider> {
142 _ = provider
143 .name
144 .get_or_insert_with(|| provider.brand.clone());
145
146 if provider.issuer_url.is_none() {
147 _ = provider
148 .issuer_url
149 .replace(match provider.brand.as_str() {
150 | "github" => "https://github.com".try_into()?,
151 | "gitlab" => "https://gitlab.com".try_into()?,
152 | "google" => "https://accounts.google.com".try_into()?,
153 | _ => return Err!(Config("issuer_url", "Required for this provider.")),
154 });
155 }
156
157 if provider.base_path.is_none() {
158 provider.base_path = match provider.brand.as_str() {
159 | "github" => Some("login/oauth/".to_owned()),
160 | _ => None,
161 };
162 }
163
164 let response = self
165 .discover(&provider)
166 .await
167 .and_then(|response| {
168 response.as_object().cloned().ok_or_else(|| {
169 err!(Request(NotJson("Expecting JSON object for discovery response")))
170 })
171 })
172 .and_then(|response| check_issuer(response, &provider))?;
173
174 if provider.authorization_url.is_none() {
175 response
176 .get("authorization_endpoint")
177 .and_then(JsonValue::as_str)
178 .map(Url::parse)
179 .transpose()?
180 .or_else(|| make_url(&provider, "authorize").ok())
181 .map(|url| provider.authorization_url.replace(url));
182 }
183
184 if provider.revocation_url.is_none() {
185 response
186 .get("revocation_endpoint")
187 .and_then(JsonValue::as_str)
188 .map(Url::parse)
189 .transpose()?
190 .or_else(|| make_url(&provider, "revocation").ok())
191 .map(|url| provider.revocation_url.replace(url));
192 }
193
194 if provider.introspection_url.is_none() {
195 response
196 .get("introspection_endpoint")
197 .and_then(JsonValue::as_str)
198 .map(Url::parse)
199 .transpose()?
200 .or_else(|| make_url(&provider, "introspection").ok())
201 .map(|url| provider.introspection_url.replace(url));
202 }
203
204 if provider.userinfo_url.is_none() {
205 response
206 .get("userinfo_endpoint")
207 .and_then(JsonValue::as_str)
208 .map(Url::parse)
209 .transpose()?
210 .or_else(|| match provider.brand.as_str() {
211 | "github" => "https://api.github.com/user".try_into().ok(),
212 | _ => make_url(&provider, "userinfo").ok(),
213 })
214 .map(|url| provider.userinfo_url.replace(url));
215 }
216
217 if provider.token_url.is_none() {
218 response
219 .get("token_endpoint")
220 .and_then(JsonValue::as_str)
221 .map(Url::parse)
222 .transpose()?
223 .or_else(|| {
224 let path = if provider.brand == "github" {
225 "access_token"
226 } else {
227 "token"
228 };
229
230 make_url(&provider, path).ok()
231 })
232 .map(|url| provider.token_url.replace(url));
233 }
234
235 if provider.callback_url.is_none()
236 && let Some(server_url) = self.services.config.well_known.client.as_ref()
237 {
238 let callback_path =
239 format!("_matrix/client/unstable/login/sso/callback/{}", provider.client_id);
240
241 provider.callback_url = Some(server_url.join(&callback_path)?);
242 }
243
244 Ok(provider)
245}
246
247#[implement(Providers)]
250#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
251pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
252 self.services
253 .client
254 .oauth
255 .get(discovery_url(provider)?)
256 .send()
257 .await?
258 .error_for_status()?
259 .json()
260 .await
261 .map_err(Into::into)
262}
263
264fn discovery_url(provider: &Provider) -> Result<Url> {
267 let default_url = provider
268 .discovery
269 .then(|| make_url(provider, ".well-known/openid-configuration"))
270 .transpose()?;
271
272 let Some(url) = provider
273 .discovery_url
274 .clone()
275 .filter(|_| provider.discovery)
276 .or(default_url)
277 else {
278 return Err!(Config(
279 "discovery_url",
280 "Failed to determine URL for discovery of provider {}",
281 provider.id()
282 ));
283 };
284
285 Ok(url)
286}
287
288fn check_issuer(
291 response: JsonObject<String, JsonValue>,
292 provider: &Provider,
293) -> Result<JsonObject<String, JsonValue>> {
294 let expected = provider
295 .issuer_url
296 .as_ref()
297 .map(Url::as_str)
298 .map(|url| url.trim_end_matches('/'));
299
300 let responded = response
301 .get("issuer")
302 .and_then(JsonValue::as_str)
303 .map(|url| url.trim_end_matches('/'));
304
305 if expected != responded {
306 return Err!(Request(Unauthorized(
307 "Configured issuer_url {expected:?} does not match discovered {responded:?}",
308 )));
309 }
310
311 Ok(response)
312}
313
314fn make_url(provider: &Provider, path: &str) -> Result<Url> {
317 let mut suffix = provider.base_path.clone().unwrap_or_default();
318
319 suffix.push_str(path);
320 let issuer = provider.issuer_url.as_ref().ok_or_else(|| {
321 let id = &provider.client_id;
322 err!(Config("issuer_url", "Provider {id:?} required field"))
323 })?;
324 let issuer_path = issuer.path();
325
326 if issuer_path.ends_with('/') {
327 Ok(issuer.join(suffix.as_str())?)
328 } else {
329 let mut url = issuer.to_owned();
330 url.set_path((issuer_path.to_owned() + "/").as_str());
331 Ok(url.join(&suffix)?)
332 }
333}