Skip to main content

tuwunel_api/router/
auth.rs

1mod appservice;
2mod dispatch;
3mod server;
4mod uiaa;
5
6use std::{any::TypeId, fmt::Debug, time::SystemTime};
7
8use axum::RequestPartsExt;
9use axum_extra::{
10	TypedHeader,
11	headers::{Authorization, authorization::Bearer},
12};
13use futures::{
14	TryFutureExt,
15	future::{
16		Either::{Left, Right},
17		select_ok, try_join,
18	},
19	pin_mut,
20};
21use ruma::{
22	CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId,
23	api::client::{
24		directory::get_public_rooms,
25		profile::{
26			delete_profile_field, get_avatar_url, get_display_name, get_profile,
27			get_profile_field, set_avatar_url, set_display_name, set_profile_field,
28		},
29		session::{logout, logout_all},
30	},
31};
32use tuwunel_core::{Err, Result, is_less_than, smallstr::SmallString};
33use tuwunel_service::{Services, appservice::RegistrationInfo};
34
35pub(super) use self::dispatch::AuthDispatch;
36use self::dispatch::Scheme;
37pub(crate) use self::uiaa::auth_uiaa;
38use super::request::Request;
39
/// Owned copy of a raw access token, held in a `SmallString` with 32 bytes
/// of inline storage. Used to carry the token string in `Token::Expired`.
type AccessToken = SmallString<[u8; 32]>;
41
/// Outcome of resolving the access token presented with a request, before
/// the route's auth scheme is applied by `AuthDispatch::dispatch`.
pub(super) enum Token {
	/// The token matched an appservice registration.
	Appservice(Box<RegistrationInfo>),
	/// The token matched a user session: `(user, device, expires_at)`.
	/// An `expires_at` of `None` is never treated as expired.
	User((OwnedUserId, OwnedDeviceId, Option<SystemTime>)),
	/// A user token whose `expires_at` lies in the past; carries the raw
	/// token string that was presented.
	Expired(AccessToken),
	/// A token was presented but matched neither a user nor an appservice.
	Invalid,
	/// No token was presented (neither header nor query parameter).
	None,
}
49
/// Identity established for a request, as returned by `auth`.
///
/// Which fields are populated depends on the route's auth scheme and on the
/// credentials presented; every field defaults to `None`.
#[derive(Debug, Default)]
pub(super) struct Auth {
	/// Requesting server; presumably only set for server-to-server
	/// (federation) requests — TODO confirm against the `server` module.
	pub(super) origin: Option<OwnedServerName>,
	/// Authenticated user. When present, `auth` runs the locked- and
	/// suspended-account checks against this user.
	pub(super) sender_user: Option<OwnedUserId>,
	/// Device associated with the authenticating token, presumably the
	/// device from `Token::User` — verify in `dispatch`.
	pub(super) sender_device: Option<OwnedDeviceId>,
	/// Registration of the appservice that authenticated the request,
	/// presumably from `Token::Appservice` — verify in `dispatch`.
	pub(super) appservice_info: Option<RegistrationInfo>,
	/// Token expiry; underscore-prefixed and not read anywhere in this file.
	pub(super) _expires_at: Option<SystemTime>,
}
58
59#[tracing::instrument(
60	level = "trace",
61	skip(services, request, json_body),
62	err(level = "debug"),
63	ret
64)]
65pub(super) async fn auth<A: AuthDispatch>(
66	services: &Services,
67	request: &mut Request,
68	json_body: Option<&CanonicalJsonValue>,
69	route: TypeId,
70) -> Result<Auth> {
71	let bearer: Option<TypedHeader<Authorization<Bearer>>> =
72		request.parts.extract().await.unwrap_or(None);
73
74	let access_token = match &bearer {
75		| Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
76		| None => request.query.access_token.as_deref(),
77	};
78
79	let token = match find_token(services, access_token).await? {
80		| Token::User((_, _, expires_at))
81			if expires_at.is_some_and(is_less_than!(SystemTime::now())) =>
82			Token::Expired(access_token.unwrap_or_default().into()),
83
84		| token => token,
85	};
86
87	if A::SCHEME == Scheme::None {
88		check_auth_still_required(services, &token, route)?;
89	}
90
91	let auth = A::dispatch(services, request, json_body, token, route).await?;
92
93	try_join(
94		locked_account_check(services, &auth, route),
95		suspended_account_check(services, &auth, route),
96	)
97	.await?;
98
99	Ok(auth)
100}
101
102/// MSC3939: 401 `M_USER_LOCKED` for locked accounts; logout endpoints
103/// bypass. `soft_logout: true` is emitted by ruma for this errcode.
104#[inline(never)]
105async fn locked_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
106	let Some(user_id) = auth.sender_user.as_deref() else {
107		return Ok(());
108	};
109
110	let is_logout = route == TypeId::of::<logout::v3::Request>()
111		|| route == TypeId::of::<logout_all::v3::Request>();
112
113	if is_logout || !services.users.is_locked(user_id).await {
114		return Ok(());
115	}
116
117	Err!(Request(UserLocked("This account has been locked.")))
118}
119
120/// MSC3823: 403 `M_USER_SUSPENDED` on `set_display_name` / `set_avatar_url`
121/// for suspended callers. Companion checks: per-field in the profile
122/// handlers, per-PDU in `timeline::build_and_append_pdu`.
123#[inline(never)]
124async fn suspended_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
125	let Some(user_id) = auth.sender_user.as_deref() else {
126		return Ok(());
127	};
128
129	let blocked = route == TypeId::of::<set_display_name::v3::Request>()
130		|| route == TypeId::of::<set_avatar_url::v3::Request>()
131		|| route == TypeId::of::<set_profile_field::v3::Request>()
132		|| route == TypeId::of::<delete_profile_field::v3::Request>();
133
134	if !blocked || !services.users.is_suspended(user_id).await {
135		return Ok(());
136	}
137
138	Err!(Request(UserSuspended("Account is suspended.")))
139}
140
141#[inline(never)]
142fn check_auth_still_required(services: &Services, token: &Token, route: TypeId) -> Result {
143	let is_profile = route == TypeId::of::<get_profile::v3::Request>()
144		|| route == TypeId::of::<get_profile_field::v3::Request>()
145		|| route == TypeId::of::<get_display_name::v3::Request>()
146		|| route == TypeId::of::<get_avatar_url::v3::Request>();
147
148	let is_public_rooms = route == TypeId::of::<get_public_rooms::v3::Request>();
149
150	if (is_profile
151		&& services
152			.server
153			.config
154			.require_auth_for_profile_requests)
155		|| (is_public_rooms
156			&& !services
157				.server
158				.config
159				.allow_public_room_directory_without_auth)
160	{
161		match token {
162			| Token::Appservice(_) | Token::User(_) => Ok(()),
163			| Token::None | Token::Expired(_) | Token::Invalid =>
164				Err!(Request(MissingToken("Missing or invalid access token."))),
165		}
166	} else {
167		Ok(())
168	}
169}
170
171async fn find_token(services: &Services, token: Option<&str>) -> Result<Token> {
172	let Some(token) = token else {
173		return Ok(Token::None);
174	};
175
176	let user_token = services
177		.users
178		.find_from_token(token)
179		.map_ok(Token::User);
180
181	let appservice_token = services
182		.appservice
183		.find_from_access_token(token)
184		.map_ok(Box::new)
185		.map_ok(Token::Appservice);
186
187	pin_mut!(user_token, appservice_token);
188	match select_ok([Left(user_token), Right(appservice_token)]).await {
189		| Err(e) if !e.is_not_found() => Err(e),
190		| Ok((token, _)) => Ok(token),
191		| _ => Ok(Token::Invalid),
192	}
193}