tuwunel_api/router/
args.rs1use std::{any::TypeId, fmt::Debug, mem, ops::Deref};
2
3use axum::{body::Body, extract::FromRequest};
4use axum_extra::extract::cookie::CookieJar;
5use bytes::{BufMut, Bytes, BytesMut};
6use http::Method;
7use ruma::{
8 CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
9 OwnedUserId, ServerName, UserId, api::IncomingRequest,
10};
11use tuwunel_core::{Err, Error, Result, err, trace, utils::string::EMPTY};
12use tuwunel_service::{Services, appservice::RegistrationInfo};
13
14use super::{
15 auth,
16 auth::{Auth, AuthDispatch},
17 request,
18 request::Request,
19};
20use crate::State;
21
/// Extracted arguments for a route handler: the deserialized request `T`
/// together with the authentication results and request metadata gathered
/// while extracting it.
#[derive(Debug)]
pub(crate) struct Args<T> {
	/// The deserialized request.
	pub(crate) body: T,

	/// Cookies sent with the request.
	pub(crate) cookie: CookieJar,

	/// The server the request was authenticated as coming from, if any.
	pub(crate) origin: Option<OwnedServerName>,

	/// The user the request was authenticated as, if any.
	pub(crate) sender_user: Option<OwnedUserId>,

	/// The device of the authenticated user, if one was identified.
	pub(crate) sender_device: Option<OwnedDeviceId>,

	/// Registration of the appservice the request was authenticated as, if any.
	/// NOTE(review): presumably only set for appservice-token requests — confirm
	/// against `auth`.
	pub(crate) appservice_info: Option<RegistrationInfo>,

	/// The request body parsed as canonical JSON, when it could be parsed.
	///
	/// A blank `POST`/`DELETE` body on a JSON endpoint is represented as `{}`.
	/// A JSON object body may also have been extended in place with fields of an
	/// earlier request belonging to the same UIAA session while `body` was built.
	pub(crate) json_body: Option<CanonicalJsonValue>,
}
51
52impl<T> Args<T> {
53 #[inline]
54 pub(crate) fn sender_user(&self) -> &UserId {
55 self.sender_user
56 .as_deref()
57 .expect("user must be authenticated for this handler")
58 }
59
60 #[inline]
61 pub(crate) fn origin(&self) -> &ServerName {
62 self.origin
63 .as_deref()
64 .expect("server must be authenticated for this handler")
65 }
66
67 #[inline]
68 pub(crate) fn sender_device(&self) -> Result<&DeviceId> {
69 self.sender_device
70 .as_deref()
71 .ok_or(err!(Request(Forbidden("user must be authenticated and device identified"))))
72 }
73}
74
75impl<T> Deref for Args<T>
76where
77 T: Sync,
78{
79 type Target = T;
80
81 fn deref(&self) -> &Self::Target { &self.body }
82}
83
84impl<T> FromRequest<State, Body> for Args<T>
85where
86 T: IncomingRequest + Debug + Send + Sync + 'static,
87 T::Authentication: AuthDispatch,
88{
89 type Rejection = Error;
90
91 #[tracing::instrument(
92 name = "ar",
93 level = "debug",
94 skip(services),
95 err(level = "debug")
96 ret(level = "trace"),
97 )]
98 async fn from_request(
99 request: http::Request<Body>,
100 services: &State,
101 ) -> Result<Self, Self::Rejection> {
102 let mut request = request::from(services, request).await?;
103 let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
104 trace!(?request);
105
106 let json_endpoint = matches!(
109 request.parts.method,
110 Method::POST | Method::PUT | Method::DELETE | Method::PATCH
111 ) && !request.parts.uri.path().contains("/media/");
112
113 if json_body.is_none() && json_endpoint {
114 let empty = request.body.iter().all(u8::is_ascii_whitespace);
115
116 if !empty && serde_json::from_slice::<serde_json::Value>(&request.body).is_err() {
117 return Err!(Request(NotJson("Request body is not valid JSON.")));
118 }
119
120 if empty && matches!(request.parts.method, Method::POST | Method::DELETE) {
121 json_body = Some(CanonicalJsonValue::Object(CanonicalJsonObject::new()));
122 }
123 }
124
125 let auth = auth::auth::<T::Authentication>(
126 services,
127 &mut request,
128 json_body.as_ref(),
129 TypeId::of::<T>(),
130 )
131 .await?;
132
133 Ok(Self {
134 body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
135 cookie: request.cookie,
136 origin: auth.origin,
137 sender_user: auth.sender_user,
138 sender_device: auth.sender_device,
139 appservice_info: auth.appservice_info,
140 json_body,
141 })
142 }
143}
144
145fn make_body<T>(
146 services: &Services,
147 request: &mut Request,
148 json_body: Option<&mut CanonicalJsonValue>,
149 auth: &Auth,
150) -> Result<T>
151where
152 T: IncomingRequest,
153{
154 let body = take_body(services, request, json_body, auth);
155 let http_request = into_http_request(request, body);
156 T::try_from_http_request(http_request, &request.path)
157 .map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))
158}
159
160fn into_http_request(request: &Request, body: Bytes) -> http::Request<Bytes> {
161 let mut http_request = http::Request::builder()
162 .uri(request.parts.uri.clone())
163 .method(request.parts.method.clone());
164
165 *http_request
166 .headers_mut()
167 .expect("mutable http headers") = request.parts.headers.clone();
168
169 http_request
170 .body(body)
171 .expect("http request body")
172}
173
174#[expect(clippy::needless_pass_by_value)]
175fn take_body(
176 services: &Services,
177 request: &mut Request,
178 json_body: Option<&mut CanonicalJsonValue>,
179 auth: &Auth,
180) -> Bytes {
181 let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
182 return mem::take(&mut request.body);
183 };
184
185 let user_id = auth.sender_user.clone().unwrap_or_else(|| {
186 let server_name = services.globals.server_name();
187 UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
188 });
189
190 let uiaa_request = json_body
191 .get("auth")
192 .and_then(CanonicalJsonValue::as_object)
193 .and_then(|auth| auth.get("session"))
194 .and_then(CanonicalJsonValue::as_str)
195 .and_then(|session| {
196 services
197 .uiaa
198 .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
199 });
200
201 if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
202 for (key, value) in initial_request {
203 json_body.entry(key).or_insert(value);
204 }
205 }
206
207 let mut buf = BytesMut::new().writer();
208 serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
209 buf.into_inner().freeze()
210}