Skip to main content

tuwunel_api/router/
auth.rs

1mod appservice;
2mod dispatch;
3mod server;
4mod uiaa;
5
6use std::{any::TypeId, fmt::Debug, time::SystemTime};
7
8use axum::RequestPartsExt;
9use axum_extra::{
10	TypedHeader,
11	headers::{Authorization, authorization::Bearer},
12};
13use futures::{
14	TryFutureExt,
15	future::{
16		Either::{Left, Right},
17		select_ok, try_join,
18	},
19	pin_mut,
20};
21use ruma::{
22	CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId,
23	api::client::{
24		directory::get_public_rooms,
25		profile::{
26			get_avatar_url, get_display_name, get_profile, get_profile_field, set_avatar_url,
27			set_display_name,
28		},
29		session::{logout, logout_all},
30	},
31};
32use tuwunel_core::{Err, Result, is_less_than};
33use tuwunel_service::{Services, appservice::RegistrationInfo};
34
35pub(super) use self::dispatch::AuthDispatch;
36use self::dispatch::Scheme;
37pub(crate) use self::uiaa::auth_uiaa;
38use super::request::Request;
39
/// Outcome of resolving the access token presented with a request.
///
/// Produced by `find_token`; `auth` then rewrites a `User` token whose expiry
/// time has passed into `Expired` before handing it to `AuthDispatch::dispatch`.
pub(super) enum Token {
	/// The token belongs to a registered appservice.
	Appservice(Box<RegistrationInfo>),
	/// The token belongs to a user's device. The third element is the
	/// token's expiry time, if it has one.
	User((OwnedUserId, OwnedDeviceId, Option<SystemTime>)),
	/// A user-device token whose expiry time is already in the past.
	Expired((OwnedUserId, OwnedDeviceId)),
	/// A token was presented but matched neither a user device nor an
	/// appservice.
	Invalid,
	/// No token was presented at all.
	None,
}
47
/// Authentication result for a request, returned by `auth`.
///
/// Which fields get populated depends on the route's auth scheme and the
/// credentials presented. That decision is made in `AuthDispatch::dispatch`,
/// which is not shown here.
#[derive(Debug, Default)]
pub(super) struct Auth {
	/// Originating server. Presumably set for server-to-server requests;
	/// confirm in `dispatch`.
	pub(super) origin: Option<OwnedServerName>,
	/// Authenticated user, if any. Read by the locked/suspended account
	/// checks in this module.
	pub(super) sender_user: Option<OwnedUserId>,
	/// Device associated with the sender, if any. Presumably the device the
	/// access token was issued to; verify in `dispatch`.
	pub(super) sender_device: Option<OwnedDeviceId>,
	/// Registration of the calling appservice, when the request was made with
	/// an appservice token.
	pub(super) appservice_info: Option<RegistrationInfo>,
	/// Token expiry time. The leading underscore suggests nothing reads it
	/// yet; confirm before relying on it.
	pub(super) _expires_at: Option<SystemTime>,
}
56
57#[tracing::instrument(
58	level = "trace",
59	skip(services, request, json_body),
60	err(level = "debug"),
61	ret
62)]
63pub(super) async fn auth<A: AuthDispatch>(
64	services: &Services,
65	request: &mut Request,
66	json_body: Option<&CanonicalJsonValue>,
67	route: TypeId,
68) -> Result<Auth> {
69	let bearer: Option<TypedHeader<Authorization<Bearer>>> =
70		request.parts.extract().await.unwrap_or(None);
71
72	let token = match &bearer {
73		| Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
74		| None => request.query.access_token.as_deref(),
75	};
76
77	let token = match find_token(services, token).await? {
78		| Token::User((user_id, device_id, expires_at))
79			if expires_at.is_some_and(is_less_than!(SystemTime::now())) =>
80			Token::Expired((user_id, device_id)),
81
82		| token => token,
83	};
84
85	if A::SCHEME == Scheme::None {
86		check_auth_still_required(services, &token, route)?;
87	}
88
89	let auth = A::dispatch(services, request, json_body, token, route).await?;
90
91	try_join(
92		locked_account_check(services, &auth, route),
93		suspended_account_check(services, &auth, route),
94	)
95	.await?;
96
97	Ok(auth)
98}
99
100/// MSC3939: 401 `M_USER_LOCKED` for locked accounts; logout endpoints
101/// bypass. `soft_logout: true` is emitted by ruma for this errcode.
102#[inline(never)]
103async fn locked_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
104	let Some(user_id) = auth.sender_user.as_deref() else {
105		return Ok(());
106	};
107
108	let is_logout = route == TypeId::of::<logout::v3::Request>()
109		|| route == TypeId::of::<logout_all::v3::Request>();
110
111	if is_logout || !services.users.is_locked(user_id).await {
112		return Ok(());
113	}
114
115	Err!(Request(UserLocked("This account has been locked.")))
116}
117
118/// MSC3823: 403 `M_USER_SUSPENDED` on `set_display_name` / `set_avatar_url`
119/// for suspended callers. Companion checks: per-field in the profile
120/// handlers, per-PDU in `timeline::build_and_append_pdu`.
121#[inline(never)]
122async fn suspended_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
123	let Some(user_id) = auth.sender_user.as_deref() else {
124		return Ok(());
125	};
126
127	let blocked = route == TypeId::of::<set_display_name::v3::Request>()
128		|| route == TypeId::of::<set_avatar_url::v3::Request>();
129
130	if !blocked || !services.users.is_suspended(user_id).await {
131		return Ok(());
132	}
133
134	Err!(Request(UserSuspended("Account is suspended.")))
135}
136
137#[inline(never)]
138fn check_auth_still_required(services: &Services, token: &Token, route: TypeId) -> Result {
139	let is_profile = route == TypeId::of::<get_profile::v3::Request>()
140		|| route == TypeId::of::<get_profile_field::v3::Request>()
141		|| route == TypeId::of::<get_display_name::v3::Request>()
142		|| route == TypeId::of::<get_avatar_url::v3::Request>();
143
144	let is_public_rooms = route == TypeId::of::<get_public_rooms::v3::Request>();
145
146	if (is_profile
147		&& services
148			.server
149			.config
150			.require_auth_for_profile_requests)
151		|| (is_public_rooms
152			&& !services
153				.server
154				.config
155				.allow_public_room_directory_without_auth)
156	{
157		match token {
158			| Token::Appservice(_) | Token::User(_) => Ok(()),
159			| Token::None | Token::Expired(_) | Token::Invalid =>
160				Err!(Request(MissingToken("Missing or invalid access token."))),
161		}
162	} else {
163		Ok(())
164	}
165}
166
167async fn find_token(services: &Services, token: Option<&str>) -> Result<Token> {
168	let Some(token) = token else {
169		return Ok(Token::None);
170	};
171
172	let user_token = services
173		.users
174		.find_from_token(token)
175		.map_ok(Token::User);
176
177	let appservice_token = services
178		.appservice
179		.find_from_access_token(token)
180		.map_ok(Box::new)
181		.map_ok(Token::Appservice);
182
183	pin_mut!(user_token, appservice_token);
184	match select_ok([Left(user_token), Right(appservice_token)]).await {
185		| Err(e) if !e.is_not_found() => Err(e),
186		| Ok((token, _)) => Ok(token),
187		| _ => Ok(Token::Invalid),
188	}
189}