tuwunel_admin/query/oauth/
associate.rs1use futures::StreamExt;
2use ruma::OwnedUserId;
3use tuwunel_core::{Err, Result, apply, err, itertools::Itertools, utils::stream::ReadyExt};
4
5use crate::admin_command;
6
7#[admin_command]
8pub(super) async fn oauth_associate(
9 &self,
10 provider: String,
11 user_id: OwnedUserId,
12 claim: Vec<String>,
13 force: bool,
14) -> Result {
15 if !self.services.globals.user_is_local(&user_id) {
16 return Err!(Request(NotFound("User {user_id:?} does not belong to this server.")));
17 }
18
19 if !self.services.users.exists(&user_id).await {
20 return Err!(Request(NotFound("User {user_id:?} is not registered")));
21 }
22
23 let provider = self
24 .services
25 .oauth
26 .providers
27 .get(&provider)
28 .await?;
29
30 let claim = claim
31 .iter()
32 .map(|kv| {
33 let (key, val) = kv
34 .split_once('=')
35 .ok_or_else(|| err!("Missing '=' in --claim {kv}=???"))?;
36
37 if !key.is_empty() && !val.is_empty() {
38 Ok((key, val))
39 } else {
40 Err!("Missing key or value in --claim=key=value argument")
41 }
42 })
43 .map_ok(apply!(2, ToOwned::to_owned))
44 .collect::<Result<_>>()?;
45
46 let committed = self
47 .services
48 .oauth
49 .user_sessions(&user_id)
50 .ready_filter_map(Result::ok)
51 .count()
52 .await;
53
54 if committed > 0 && !force {
55 return Err!(
56 "{user_id} already has {committed} committed OAuth session(s); the pending claim \
57 would be shadowed at login. Re-run with --force to replace existing sessions, or \
58 run `query oauth delete {user_id} --force` first."
59 );
60 }
61
62 if committed > 0 {
63 self.services
64 .oauth
65 .delete_user_sessions(&user_id)
66 .await;
67 }
68
69 let replaced = self
70 .services
71 .oauth
72 .sessions
73 .set_user_association_pending(provider.id(), &user_id, claim);
74
75 let lead = match committed {
76 | 0 => format!("Pending association {}", replaced.map_or("added", |_| "replaced")),
77 | n => format!(
78 "Replaced {n} committed session(s) across all providers and added pending \
79 association"
80 ),
81 };
82
83 writeln!(self, "{lead} for {user_id} on provider {}.", provider.id()).await
84}