1mod appservice;
2mod dispatch;
3mod server;
4mod uiaa;
5
6use std::{any::TypeId, fmt::Debug, time::SystemTime};
7
8use axum::RequestPartsExt;
9use axum_extra::{
10 TypedHeader,
11 headers::{Authorization, authorization::Bearer},
12};
13use futures::{
14 TryFutureExt,
15 future::{
16 Either::{Left, Right},
17 select_ok, try_join,
18 },
19 pin_mut,
20};
21use ruma::{
22 CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId,
23 api::client::{
24 directory::get_public_rooms,
25 profile::{
26 delete_profile_field, get_avatar_url, get_display_name, get_profile,
27 get_profile_field, set_avatar_url, set_display_name, set_profile_field,
28 },
29 session::{logout, logout_all},
30 },
31};
32use tuwunel_core::{Err, Result, is_less_than, smallstr::SmallString};
33use tuwunel_service::{Services, appservice::RegistrationInfo};
34
35pub(super) use self::dispatch::AuthDispatch;
36use self::dispatch::Scheme;
37pub(crate) use self::uiaa::auth_uiaa;
38use super::request::Request;
39
/// Owned copy of a raw access token, as carried by `Token::Expired`. The
/// `[u8; 32]` parameter gives `SmallString` a 32-byte inline buffer.
type AccessToken = SmallString<[u8; 32]>;
41
/// Outcome of resolving the request's access token (produced by `find_token`
/// and possibly downgraded to `Expired` by `auth`).
pub(super) enum Token {
	/// The token matched an appservice registration.
	Appservice(Box<RegistrationInfo>),
	/// The token matched a user: `(user, device, expiry)`. `auth` replaces this
	/// with `Expired` when the expiry time is earlier than the current time.
	User((OwnedUserId, OwnedDeviceId, Option<SystemTime>)),
	/// A user token whose expiry time has passed; holds the raw token string.
	Expired(AccessToken),
	/// A token was supplied but both lookups reported not-found.
	Invalid,
	/// No token was supplied (neither header nor query parameter).
	None,
}
49
/// Result of authenticating a request, as returned by `AuthDispatch::dispatch`.
#[derive(Debug, Default)]
pub(super) struct Auth {
	/// Originating server name; presumably set for server-to-server requests —
	/// TODO confirm against the `server` scheme.
	pub(super) origin: Option<OwnedServerName>,
	/// Authenticated user, if any; the locked/suspended checks key off this.
	pub(super) sender_user: Option<OwnedUserId>,
	/// Device associated with the sender, if any.
	pub(super) sender_device: Option<OwnedDeviceId>,
	/// Appservice registration, presumably set when the request authenticated
	/// with an appservice token — verify in the dispatch code.
	pub(super) appservice_info: Option<RegistrationInfo>,
	/// Token expiry time; not read in this module (the leading underscore
	/// marks it as intentionally unused).
	pub(super) _expires_at: Option<SystemTime>,
}
58
59#[tracing::instrument(
60 level = "trace",
61 skip(services, request, json_body),
62 err(level = "debug"),
63 ret
64)]
65pub(super) async fn auth<A: AuthDispatch>(
66 services: &Services,
67 request: &mut Request,
68 json_body: Option<&CanonicalJsonValue>,
69 route: TypeId,
70) -> Result<Auth> {
71 let bearer: Option<TypedHeader<Authorization<Bearer>>> =
72 request.parts.extract().await.unwrap_or(None);
73
74 let access_token = match &bearer {
75 | Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
76 | None => request.query.access_token.as_deref(),
77 };
78
79 let token = match find_token(services, access_token).await? {
80 | Token::User((_, _, expires_at))
81 if expires_at.is_some_and(is_less_than!(SystemTime::now())) =>
82 Token::Expired(access_token.unwrap_or_default().into()),
83
84 | token => token,
85 };
86
87 if A::SCHEME == Scheme::None {
88 check_auth_still_required(services, &token, route)?;
89 }
90
91 let auth = A::dispatch(services, request, json_body, token, route).await?;
92
93 try_join(
94 locked_account_check(services, &auth, route),
95 suspended_account_check(services, &auth, route),
96 )
97 .await?;
98
99 Ok(auth)
100}
101
102#[inline(never)]
105async fn locked_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
106 let Some(user_id) = auth.sender_user.as_deref() else {
107 return Ok(());
108 };
109
110 let is_logout = route == TypeId::of::<logout::v3::Request>()
111 || route == TypeId::of::<logout_all::v3::Request>();
112
113 if is_logout || !services.users.is_locked(user_id).await {
114 return Ok(());
115 }
116
117 Err!(Request(UserLocked("This account has been locked.")))
118}
119
120#[inline(never)]
124async fn suspended_account_check(services: &Services, auth: &Auth, route: TypeId) -> Result {
125 let Some(user_id) = auth.sender_user.as_deref() else {
126 return Ok(());
127 };
128
129 let blocked = route == TypeId::of::<set_display_name::v3::Request>()
130 || route == TypeId::of::<set_avatar_url::v3::Request>()
131 || route == TypeId::of::<set_profile_field::v3::Request>()
132 || route == TypeId::of::<delete_profile_field::v3::Request>();
133
134 if !blocked || !services.users.is_suspended(user_id).await {
135 return Ok(());
136 }
137
138 Err!(Request(UserSuspended("Account is suspended.")))
139}
140
141#[inline(never)]
142fn check_auth_still_required(services: &Services, token: &Token, route: TypeId) -> Result {
143 let is_profile = route == TypeId::of::<get_profile::v3::Request>()
144 || route == TypeId::of::<get_profile_field::v3::Request>()
145 || route == TypeId::of::<get_display_name::v3::Request>()
146 || route == TypeId::of::<get_avatar_url::v3::Request>();
147
148 let is_public_rooms = route == TypeId::of::<get_public_rooms::v3::Request>();
149
150 if (is_profile
151 && services
152 .server
153 .config
154 .require_auth_for_profile_requests)
155 || (is_public_rooms
156 && !services
157 .server
158 .config
159 .allow_public_room_directory_without_auth)
160 {
161 match token {
162 | Token::Appservice(_) | Token::User(_) => Ok(()),
163 | Token::None | Token::Expired(_) | Token::Invalid =>
164 Err!(Request(MissingToken("Missing or invalid access token."))),
165 }
166 } else {
167 Ok(())
168 }
169}
170
171async fn find_token(services: &Services, token: Option<&str>) -> Result<Token> {
172 let Some(token) = token else {
173 return Ok(Token::None);
174 };
175
176 let user_token = services
177 .users
178 .find_from_token(token)
179 .map_ok(Token::User);
180
181 let appservice_token = services
182 .appservice
183 .find_from_access_token(token)
184 .map_ok(Box::new)
185 .map_ok(Token::Appservice);
186
187 pin_mut!(user_token, appservice_token);
188 match select_ok([Left(user_token), Right(appservice_token)]).await {
189 | Err(e) if !e.is_not_found() => Err(e),
190 | Ok((token, _)) => Ok(token),
191 | _ => Ok(Token::Invalid),
192 }
193}