1use std::any::TypeId;
2
3use ruma::{
4 CanonicalJsonValue,
5 api::{
6 auth_scheme::{
7 AccessToken, AccessTokenOptional, AppserviceToken, AppserviceTokenOptional,
8 AuthScheme, NoAccessToken, NoAuthentication,
9 },
10 error::{ErrorKind, UnknownTokenErrorData},
11 federation::authentication::ServerSignatures,
12 },
13};
14use tuwunel_core::{Err, Error, Result};
15use tuwunel_service::Services;
16
17use super::{Auth, Request, Token, appservice::auth_appservice, server::auth_server};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq)]
26pub(in crate::router) enum Scheme {
27 None,
28 AccessToken,
29 AccessTokenOptional,
30 AppserviceToken,
31 AppserviceTokenOptional,
32 ServerSignatures,
33}
34
35pub(in crate::router) trait AuthDispatch: AuthScheme {
41 const SCHEME: Scheme;
42
43 fn dispatch(
44 services: &Services,
45 request: &mut Request,
46 json_body: Option<&CanonicalJsonValue>,
47 token: Token,
48 route: TypeId,
49 ) -> impl Future<Output = Result<Auth>> + Send;
50}
51
52impl AuthDispatch for NoAccessToken {
53 const SCHEME: Scheme = Scheme::None;
54
55 async fn dispatch(
56 services: &Services,
57 request: &mut Request,
58 json_body: Option<&CanonicalJsonValue>,
59 token: Token,
60 route: TypeId,
61 ) -> Result<Auth> {
62 <NoAuthentication as AuthDispatch>::dispatch(services, request, json_body, token, route)
63 .await
64 }
65}
66
67impl AuthDispatch for NoAuthentication {
68 const SCHEME: Scheme = Scheme::None;
69
70 async fn dispatch(
71 _services: &Services,
72 _request: &mut Request,
73 _json_body: Option<&CanonicalJsonValue>,
74 token: Token,
75 _route: TypeId,
76 ) -> Result<Auth> {
77 match token {
78 | Token::Invalid | Token::Expired(_) | Token::None => Ok(Auth::default()),
81
82 | Token::User(user) => Ok(Auth {
83 sender_user: Some(user.0),
84 sender_device: Some(user.1),
85 _expires_at: user.2,
86 ..Auth::default()
87 }),
88
89 | Token::Appservice(info) => Ok(Auth {
90 appservice_info: Some(*info),
91 ..Auth::default()
92 }),
93 }
94 }
95}
96
97impl AuthDispatch for AccessToken {
98 const SCHEME: Scheme = Scheme::AccessToken;
99
100 async fn dispatch(
101 services: &Services,
102 request: &mut Request,
103 _json_body: Option<&CanonicalJsonValue>,
104 token: Token,
105 _route: TypeId,
106 ) -> Result<Auth> {
107 match token {
108 | Token::Invalid => unknown_token(),
109 | Token::Expired(access_token) => expired_token(services, &access_token).await,
110 | Token::Appservice(info) => Ok(auth_appservice(services, request, info).await?),
111 | Token::User(user) => Ok(Auth {
112 sender_user: Some(user.0),
113 sender_device: Some(user.1),
114 _expires_at: user.2,
115 ..Auth::default()
116 }),
117
118 | Token::None => Err!(Request(MissingToken("Missing access token."))),
119 }
120 }
121}
122
123impl AuthDispatch for AccessTokenOptional {
124 const SCHEME: Scheme = Scheme::AccessTokenOptional;
125
126 async fn dispatch(
127 services: &Services,
128 _request: &mut Request,
129 _json_body: Option<&CanonicalJsonValue>,
130 token: Token,
131 _route: TypeId,
132 ) -> Result<Auth> {
133 match token {
134 | Token::Invalid => unknown_token(),
135 | Token::Expired(access_token) => expired_token(services, &access_token).await,
136 | Token::User(user) => Ok(Auth {
137 sender_user: Some(user.0),
138 sender_device: Some(user.1),
139 _expires_at: user.2,
140 ..Auth::default()
141 }),
142 | Token::Appservice(info) => Ok(Auth {
143 appservice_info: Some(*info),
144 ..Auth::default()
145 }),
146 | Token::None => Ok(Auth::default()),
147 }
148 }
149}
150
151impl AuthDispatch for AppserviceToken {
152 const SCHEME: Scheme = Scheme::AppserviceToken;
153
154 async fn dispatch(
155 services: &Services,
156 _request: &mut Request,
157 _json_body: Option<&CanonicalJsonValue>,
158 token: Token,
159 _route: TypeId,
160 ) -> Result<Auth> {
161 match token {
162 | Token::Invalid => unknown_token(),
163 | Token::Expired(access_token) => expired_token(services, &access_token).await,
164 | Token::User(_) =>
165 Err!(Request(Unauthorized("Appservice tokens must be used on this endpoint."))),
166 | Token::Appservice(info) => Ok(Auth {
167 appservice_info: Some(*info),
168 ..Auth::default()
169 }),
170 | Token::None => Err!(Request(MissingToken("Missing access token."))),
171 }
172 }
173}
174
175impl AuthDispatch for AppserviceTokenOptional {
176 const SCHEME: Scheme = Scheme::AppserviceTokenOptional;
177
178 async fn dispatch(
179 services: &Services,
180 _request: &mut Request,
181 _json_body: Option<&CanonicalJsonValue>,
182 token: Token,
183 _route: TypeId,
184 ) -> Result<Auth> {
185 match token {
186 | Token::Invalid => unknown_token(),
187 | Token::Expired(access_token) => expired_token(services, &access_token).await,
188 | Token::User(user) => Ok(Auth {
189 sender_user: Some(user.0),
190 sender_device: Some(user.1),
191 _expires_at: user.2,
192 ..Auth::default()
193 }),
194 | Token::Appservice(info) => Ok(Auth {
195 appservice_info: Some(*info),
196 ..Auth::default()
197 }),
198 | Token::None => Ok(Auth::default()),
199 }
200 }
201}
202
203impl AuthDispatch for ServerSignatures {
204 const SCHEME: Scheme = Scheme::ServerSignatures;
205
206 async fn dispatch(
207 services: &Services,
208 request: &mut Request,
209 json_body: Option<&CanonicalJsonValue>,
210 token: Token,
211 _route: TypeId,
212 ) -> Result<Auth> {
213 match token {
214 | Token::Invalid => unknown_token(),
215 | Token::Expired(access_token) => expired_token(services, &access_token).await,
216 | Token::Appservice(_) | Token::User(_) =>
217 Err!(Request(Unauthorized("Server signatures must be used on this endpoint."))),
218 | Token::None => Ok(auth_server(services, request, json_body).await?),
219 }
220 }
221}
222
223fn unknown_token() -> Result<Auth> {
224 Err(Error::BadRequest(
225 ErrorKind::UnknownToken(UnknownTokenErrorData::new()),
226 "Unknown access token.",
227 ))
228}
229
230async fn expired_token(services: &Services, access_token: &str) -> Result<Auth> {
231 services
232 .users
233 .remove_access_token_value(access_token)
234 .await;
235
236 Err(Error::BadRequest(
237 ErrorKind::UnknownToken(UnknownTokenErrorData { soft_logout: true }),
238 "Expired access token.",
239 ))
240}