tuwunel_service/oauth/sessions/
association.rs1use std::collections::BTreeMap;
2
3use ruma::{OwnedUserId, UserId};
4use serde_json::Value;
5use tuwunel_core::{debug, implement, trace};
6
7use super::{Sessions, UserInfo};
8
/// All pending associations, keyed by identity-provider id.
pub(super) type Pending = BTreeMap<String, Claimants>;
/// The users awaiting association with a single provider, each mapped to the
/// claims that must be presented to complete the association.
type Claimants = BTreeMap<OwnedUserId, Claims>;
/// Claim name mapped to the string value it is expected to carry.
pub type Claims = BTreeMap<String, String>;
12
13#[implement(Sessions)]
14pub fn set_user_association_pending(
15 &self,
16 idp_id: &str,
17 user_id: &UserId,
18 claims: Claims,
19) -> Option<Claims> {
20 self.association_pending
21 .lock()
22 .expect("locked")
23 .entry(idp_id.into())
24 .or_default()
25 .insert(user_id.into(), claims)
26}
27
28#[implement(Sessions)]
29pub fn find_user_association_pending(
30 &self,
31 idp_id: &str,
32 userinfo: &UserInfo,
33) -> Option<OwnedUserId> {
34 let claiming = serde_json::to_value(userinfo)
35 .expect("Failed to transform user_info into serde_json::Value");
36
37 let claiming = claiming
38 .as_object()
39 .expect("Failed to interpret user_info as object");
40
41 assert!(
42 !claiming.is_empty(),
43 "Expecting at least one claim from user_info such as `sub`"
44 );
45
46 debug!(?idp_id, ?claiming, "finding pending association",);
47 self.association_pending
48 .lock()
49 .expect("locked")
50 .get(idp_id)
51 .into_iter()
52 .flat_map(Claimants::iter)
53 .find_map(|(user_id, claimant)| {
54 trace!(?user_id, ?claimant, "checking against pending association");
55
56 assert!(
57 !claimant.is_empty(),
58 "Must not match empty set of claims; should not exist in association_pending."
59 );
60
61 for (claim, value) in claimant {
62 if claiming.get(claim).and_then(Value::as_str) != Some(value) {
63 return None;
64 }
65 }
66
67 Some(user_id.clone())
68 })
69}
70
71#[implement(Sessions)]
72pub fn remove_provider_associations_pending(&self, idp_id: &str) {
73 self.association_pending
74 .lock()
75 .expect("locked")
76 .remove(idp_id);
77}
78
79#[implement(Sessions)]
80pub fn remove_user_association_pending(&self, user_id: &UserId, idp_id: Option<&str>) {
81 self.association_pending
82 .lock()
83 .expect("locked")
84 .iter_mut()
85 .filter(|(provider, _)| idp_id == Some(provider))
86 .for_each(|(_, claiming)| {
87 claiming.remove(user_id);
88 });
89}
90
91#[implement(Sessions)]
92pub fn is_user_association_pending(&self, user_id: &UserId) -> bool {
93 self.association_pending
94 .lock()
95 .expect("locked")
96 .values()
97 .any(|claiming| claiming.contains_key(user_id))
98}