Skip to main content

tuwunel_admin/query/oauth/
associate.rs

1use futures::StreamExt;
2use ruma::OwnedUserId;
3use tuwunel_core::{Err, Result, apply, err, itertools::Itertools, utils::stream::ReadyExt};
4
5use crate::admin_command;
6
7#[admin_command]
8pub(super) async fn oauth_associate(
9	&self,
10	provider: String,
11	user_id: OwnedUserId,
12	claim: Vec<String>,
13	force: bool,
14) -> Result {
15	if !self.services.globals.user_is_local(&user_id) {
16		return Err!(Request(NotFound("User {user_id:?} does not belong to this server.")));
17	}
18
19	if !self.services.users.exists(&user_id).await {
20		return Err!(Request(NotFound("User {user_id:?} is not registered")));
21	}
22
23	let provider = self
24		.services
25		.oauth
26		.providers
27		.get(&provider)
28		.await?;
29
30	let claim = claim
31		.iter()
32		.map(|kv| {
33			let (key, val) = kv
34				.split_once('=')
35				.ok_or_else(|| err!("Missing '=' in --claim {kv}=???"))?;
36
37			if !key.is_empty() && !val.is_empty() {
38				Ok((key, val))
39			} else {
40				Err!("Missing key or value in --claim=key=value argument")
41			}
42		})
43		.map_ok(apply!(2, ToOwned::to_owned))
44		.collect::<Result<_>>()?;
45
46	let committed = self
47		.services
48		.oauth
49		.user_sessions(&user_id)
50		.ready_filter_map(Result::ok)
51		.count()
52		.await;
53
54	if committed > 0 && !force {
55		return Err!(
56			"{user_id} already has {committed} committed OAuth session(s); the pending claim \
57			 would be shadowed at login. Re-run with --force to replace existing sessions, or \
58			 run `query oauth delete {user_id} --force` first."
59		);
60	}
61
62	if committed > 0 {
63		self.services
64			.oauth
65			.delete_user_sessions(&user_id)
66			.await;
67	}
68
69	let replaced = self
70		.services
71		.oauth
72		.sessions
73		.set_user_association_pending(provider.id(), &user_id, claim);
74
75	let lead = match committed {
76		| 0 => format!("Pending association {}", replaced.map_or("added", |_| "replaced")),
77		| n => format!(
78			"Replaced {n} committed session(s) across all providers and added pending \
79			 association"
80		),
81	};
82
83	writeln!(self, "{lead} for {user_id} on provider {}.", provider.id()).await
84}