tuwunel_service/oauth/
providers.rs1use std::collections::BTreeMap;
2
3use serde_json::{Map as JsonObject, Value as JsonValue};
4use tokio::sync::RwLock;
5pub use tuwunel_core::config::IdentityProvider as Provider;
6use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
7use url::Url;
8
9use crate::{SelfServices, client::read_response_capped};
10
11#[derive(Default)]
13pub struct Providers {
14 services: SelfServices,
15 providers: RwLock<BTreeMap<ProviderId, Provider>>,
16}
17
/// Key under which a configured provider is cached (the value of `Provider::id()`).
pub type ProviderId = String;
20
21#[implement(Providers)]
22pub(super) fn build(args: &crate::Args<'_>) -> Self {
23 Self {
24 services: args.services.clone(),
25 ..Default::default()
26 }
27}
28
29#[implement(Providers)]
33#[tracing::instrument(level = "debug", skip(self))]
34pub async fn get(&self, id: &str) -> Result<Provider> {
35 if let Some(provider) = self.get_cached(id).await {
36 return Ok(provider);
37 }
38
39 let config = self.get_config(id)?;
40 let id = config.id().to_owned();
41 let mut map = self.providers.write().await;
42 let provider = self.configure(config).await?;
43
44 debug!(?id, ?provider);
45 _ = map.insert(id, provider.clone());
46
47 Ok(provider)
48}
49
50#[implement(Providers)]
59pub fn get_config(&self, id: &str) -> Result<Provider> {
60 let providers = &self.services.config.identity_provider;
61
62 if let Some(provider) = providers
63 .values()
64 .find(|config| config.id() == id)
65 .cloned()
66 {
67 return Ok(provider);
68 }
69
70 if let Some(provider) = providers
71 .values()
72 .find(|config| config.brand.eq_ignore_ascii_case(id))
73 .filter(|_| {
74 providers
75 .values()
76 .filter(|config| config.brand.eq_ignore_ascii_case(id))
77 .count()
78 .eq(&1)
79 })
80 .cloned()
81 {
82 return Ok(provider);
83 }
84
85 Err!(Request(NotFound("Unrecognized Identity Provider")))
86}
87
88#[implement(Providers)]
91pub fn get_default_id(&self) -> Option<String> {
92 self.services
93 .config
94 .identity_provider
95 .values()
96 .find(|idp| idp.default)
97 .or_else(|| {
98 self.services
99 .config
100 .identity_provider
101 .values()
102 .next()
103 })
104 .map(Provider::id)
105 .map(ToOwned::to_owned)
106}
107
108#[implement(Providers)]
111async fn get_cached(&self, id: &str) -> Option<Provider> {
112 let providers = self.providers.read().await;
113
114 if let Some(provider) = providers.get(id).cloned() {
115 return Some(provider);
116 }
117
118 providers
119 .values()
120 .find(|provider| provider.brand.eq_ignore_ascii_case(id))
121 .filter(|_| {
122 providers
123 .values()
124 .filter(|provider| provider.brand.eq_ignore_ascii_case(id))
125 .count()
126 .eq(&1)
127 })
128 .cloned()
129}
130
131#[implement(Providers)]
136#[tracing::instrument(
137 level = INFO_SPAN_LEVEL,
138 ret(level = "debug"),
139 skip(self),
140)]
141async fn configure(&self, mut provider: Provider) -> Result<Provider> {
142 _ = provider
143 .name
144 .get_or_insert_with(|| provider.brand.clone());
145
146 if provider.issuer_url.is_none() {
147 _ = provider
148 .issuer_url
149 .replace(match provider.brand.as_str() {
150 | "github" => "https://github.com/login/oauth".try_into()?,
151 | "gitlab" => "https://gitlab.com".try_into()?,
152 | "google" => "https://accounts.google.com".try_into()?,
153 | _ => return Err!(Config("issuer_url", "Required for this provider.")),
154 });
155 }
156
157 let response = self
158 .discover(&provider)
159 .await
160 .and_then(|response| {
161 response.as_object().cloned().ok_or_else(|| {
162 err!(Request(NotJson("Expecting JSON object for discovery response")))
163 })
164 })
165 .and_then(|response| check_issuer(response, &provider))?;
166
167 if provider.authorization_url.is_none() {
168 response
169 .get("authorization_endpoint")
170 .and_then(JsonValue::as_str)
171 .map(Url::parse)
172 .transpose()?
173 .or_else(|| make_url(&provider, "authorize").ok())
174 .map(|url| provider.authorization_url.replace(url));
175 }
176
177 if provider.revocation_url.is_none() {
178 response
179 .get("revocation_endpoint")
180 .and_then(JsonValue::as_str)
181 .map(Url::parse)
182 .transpose()?
183 .or_else(|| make_url(&provider, "revocation").ok())
184 .map(|url| provider.revocation_url.replace(url));
185 }
186
187 if provider.introspection_url.is_none() {
188 response
189 .get("introspection_endpoint")
190 .and_then(JsonValue::as_str)
191 .map(Url::parse)
192 .transpose()?
193 .or_else(|| make_url(&provider, "introspection").ok())
194 .map(|url| provider.introspection_url.replace(url));
195 }
196
197 if provider.userinfo_url.is_none() {
198 response
199 .get("userinfo_endpoint")
200 .and_then(JsonValue::as_str)
201 .map(Url::parse)
202 .transpose()?
203 .or_else(|| match provider.brand.as_str() {
204 | "github" => "https://api.github.com/user".try_into().ok(),
205 | _ => make_url(&provider, "userinfo").ok(),
206 })
207 .map(|url| provider.userinfo_url.replace(url));
208 }
209
210 if provider.token_url.is_none() {
211 response
212 .get("token_endpoint")
213 .and_then(JsonValue::as_str)
214 .map(Url::parse)
215 .transpose()?
216 .or_else(|| {
217 let path = if provider.brand == "github" {
218 "access_token"
219 } else {
220 "token"
221 };
222
223 make_url(&provider, path).ok()
224 })
225 .map(|url| provider.token_url.replace(url));
226 }
227
228 if provider.callback_url.is_none()
229 && let Some(server_url) = self.services.config.well_known.client.as_ref()
230 {
231 let callback_path =
232 format!("_matrix/client/unstable/login/sso/callback/{}", provider.client_id);
233
234 provider.callback_url = Some(server_url.join(&callback_path)?);
235 }
236
237 Ok(provider)
238}
239
240#[implement(Providers)]
243#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
244pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
245 let limit = self.services.config.max_response_size;
246 let response = self
247 .services
248 .client
249 .oauth
250 .get(discovery_url(provider)?)
251 .send()
252 .await?
253 .error_for_status()?;
254
255 let body = read_response_capped(response, limit).await?;
256
257 serde_json::from_slice(&body).map_err(Into::into)
258}
259
260fn discovery_url(provider: &Provider) -> Result<Url> {
263 let default_url = provider
264 .discovery
265 .then(|| make_url(provider, ".well-known/openid-configuration"))
266 .transpose()?;
267
268 let Some(url) = provider
269 .discovery_url
270 .clone()
271 .filter(|_| provider.discovery)
272 .or(default_url)
273 else {
274 return Err!(Config(
275 "discovery_url",
276 "Failed to determine URL for discovery of provider {}",
277 provider.id()
278 ));
279 };
280
281 Ok(url)
282}
283
284fn check_issuer(
287 response: JsonObject<String, JsonValue>,
288 provider: &Provider,
289) -> Result<JsonObject<String, JsonValue>> {
290 let expected = provider
291 .issuer_url
292 .as_ref()
293 .map(Url::as_str)
294 .map(|url| url.trim_end_matches('/'));
295
296 let responded = response
297 .get("issuer")
298 .and_then(JsonValue::as_str)
299 .map(|url| url.trim_end_matches('/'));
300
301 if expected != responded {
302 return Err!(Request(Unauthorized(
303 "Configured issuer_url {expected:?} does not match discovered {responded:?}",
304 )));
305 }
306
307 Ok(response)
308}
309
310fn make_url(provider: &Provider, path: &str) -> Result<Url> {
313 let mut suffix = provider.base_path.clone().unwrap_or_default();
314
315 suffix.push_str(path);
316 let issuer = provider.issuer_url.as_ref().ok_or_else(|| {
317 let id = &provider.client_id;
318 err!(Config("issuer_url", "Provider {id:?} required field"))
319 })?;
320 let issuer_path = issuer.path();
321
322 if issuer_path.ends_with('/') {
323 Ok(issuer.join(suffix.as_str())?)
324 } else {
325 let mut url = issuer.to_owned();
326 url.set_path((issuer_path.to_owned() + "/").as_str());
327 Ok(url.join(&suffix)?)
328 }
329}