tuwunel_api/router/
args.rs1use std::{any::TypeId, fmt::Debug, mem, ops::Deref};
2
3use axum::{body::Body, extract::FromRequest};
4use axum_extra::extract::cookie::CookieJar;
5use bytes::{BufMut, Bytes, BytesMut};
6use ruma::{
7 CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
8 OwnedUserId, ServerName, UserId, api::IncomingRequest,
9};
10use tuwunel_core::{Error, Result, debug_warn, err, trace, utils::string::EMPTY};
11use tuwunel_service::{Services, appservice::RegistrationInfo};
12
13use super::{
14 auth,
15 auth::{Auth, AuthDispatch},
16 request,
17 request::Request,
18};
19use crate::State;
20
21#[derive(Debug)]
23pub(crate) struct Args<T> {
24 pub(crate) body: T,
26
27 pub(crate) cookie: CookieJar,
29
30 pub(crate) origin: Option<OwnedServerName>,
33
34 pub(crate) sender_user: Option<OwnedUserId>,
37
38 pub(crate) sender_device: Option<OwnedDeviceId>,
41
42 pub(crate) appservice_info: Option<RegistrationInfo>,
45
46 pub(crate) json_body: Option<CanonicalJsonValue>,
49}
50
51impl<T> Args<T> {
52 #[inline]
53 pub(crate) fn sender_user(&self) -> &UserId {
54 self.sender_user
55 .as_deref()
56 .expect("user must be authenticated for this handler")
57 }
58
59 #[inline]
60 pub(crate) fn origin(&self) -> &ServerName {
61 self.origin
62 .as_deref()
63 .expect("server must be authenticated for this handler")
64 }
65
66 #[inline]
67 pub(crate) fn sender_device(&self) -> Result<&DeviceId> {
68 self.sender_device
69 .as_deref()
70 .ok_or(err!(Request(Forbidden("user must be authenticated and device identified"))))
71 }
72}
73
74impl<T> Deref for Args<T>
75where
76 T: Sync,
77{
78 type Target = T;
79
80 fn deref(&self) -> &Self::Target { &self.body }
81}
82
83impl<T> FromRequest<State, Body> for Args<T>
84where
85 T: IncomingRequest + Debug + Send + Sync + 'static,
86 T::Authentication: AuthDispatch,
87{
88 type Rejection = Error;
89
90 #[tracing::instrument(
91 name = "ar",
92 level = "debug",
93 skip(services),
94 err(level = "debug")
95 ret(level = "trace"),
96 )]
97 async fn from_request(
98 request: http::Request<Body>,
99 services: &State,
100 ) -> Result<Self, Self::Rejection> {
101 let mut request = request::from(services, request).await?;
102 let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
103 trace!(?request);
104
105 if json_body.is_none()
110 && request.parts.method == http::Method::POST
111 && !request.parts.uri.path().contains("/media/")
112 {
113 debug_warn!(
114 "received a POST request with an empty body, defaulting/assuming to {{}} like \
115 Synapse does"
116 );
117 json_body = Some(CanonicalJsonValue::Object(CanonicalJsonObject::new()));
118 }
119
120 let auth = auth::auth::<T::Authentication>(
121 services,
122 &mut request,
123 json_body.as_ref(),
124 TypeId::of::<T>(),
125 )
126 .await?;
127
128 Ok(Self {
129 body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
130 cookie: request.cookie,
131 origin: auth.origin,
132 sender_user: auth.sender_user,
133 sender_device: auth.sender_device,
134 appservice_info: auth.appservice_info,
135 json_body,
136 })
137 }
138}
139
140fn make_body<T>(
141 services: &Services,
142 request: &mut Request,
143 json_body: Option<&mut CanonicalJsonValue>,
144 auth: &Auth,
145) -> Result<T>
146where
147 T: IncomingRequest,
148{
149 let body = take_body(services, request, json_body, auth);
150 let http_request = into_http_request(request, body);
151 T::try_from_http_request(http_request, &request.path)
152 .map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))
153}
154
155fn into_http_request(request: &Request, body: Bytes) -> http::Request<Bytes> {
156 let mut http_request = http::Request::builder()
157 .uri(request.parts.uri.clone())
158 .method(request.parts.method.clone());
159
160 *http_request
161 .headers_mut()
162 .expect("mutable http headers") = request.parts.headers.clone();
163
164 http_request
165 .body(body)
166 .expect("http request body")
167}
168
169#[expect(clippy::needless_pass_by_value)]
170fn take_body(
171 services: &Services,
172 request: &mut Request,
173 json_body: Option<&mut CanonicalJsonValue>,
174 auth: &Auth,
175) -> Bytes {
176 let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
177 return mem::take(&mut request.body);
178 };
179
180 let user_id = auth.sender_user.clone().unwrap_or_else(|| {
181 let server_name = services.globals.server_name();
182 UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
183 });
184
185 let uiaa_request = json_body
186 .get("auth")
187 .and_then(CanonicalJsonValue::as_object)
188 .and_then(|auth| auth.get("session"))
189 .and_then(CanonicalJsonValue::as_str)
190 .and_then(|session| {
191 services
192 .uiaa
193 .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
194 });
195
196 if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
197 for (key, value) in initial_request {
198 json_body.entry(key).or_insert(value);
199 }
200 }
201
202 let mut buf = BytesMut::new().writer();
203 serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
204 buf.into_inner().freeze()
205}