Skip to main content

tuwunel_api/router/
args.rs

1use std::{any::TypeId, fmt::Debug, mem, ops::Deref};
2
3use axum::{body::Body, extract::FromRequest};
4use axum_extra::extract::cookie::CookieJar;
5use bytes::{BufMut, Bytes, BytesMut};
6use ruma::{
7	CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
8	OwnedUserId, ServerName, UserId, api::IncomingRequest,
9};
10use tuwunel_core::{Error, Result, debug_warn, err, trace, utils::string::EMPTY};
11use tuwunel_service::{Services, appservice::RegistrationInfo};
12
13use super::{
14	auth,
15	auth::{Auth, AuthDispatch},
16	request,
17	request::Request,
18};
19use crate::State;
20
21/// Extractor for Ruma request structs
22#[derive(Debug)]
23pub(crate) struct Args<T> {
24	/// Request struct body
25	pub(crate) body: T,
26
27	/// Cookies received from the useragent.
28	pub(crate) cookie: CookieJar,
29
30	/// Federation server authentication: X-Matrix origin
31	/// None when not a federation server.
32	pub(crate) origin: Option<OwnedServerName>,
33
34	/// Local user authentication: user_id.
35	/// None when not an authenticated local user.
36	pub(crate) sender_user: Option<OwnedUserId>,
37
38	/// Local user authentication: device_id.
39	/// None when not an authenticated local user or no device.
40	pub(crate) sender_device: Option<OwnedDeviceId>,
41
42	/// Appservice authentication; registration info.
43	/// None when not an appservice.
44	pub(crate) appservice_info: Option<RegistrationInfo>,
45
46	/// Parsed JSON content.
47	/// None when body is not a valid string
48	pub(crate) json_body: Option<CanonicalJsonValue>,
49}
50
51impl<T> Args<T> {
52	#[inline]
53	pub(crate) fn sender_user(&self) -> &UserId {
54		self.sender_user
55			.as_deref()
56			.expect("user must be authenticated for this handler")
57	}
58
59	#[inline]
60	pub(crate) fn origin(&self) -> &ServerName {
61		self.origin
62			.as_deref()
63			.expect("server must be authenticated for this handler")
64	}
65
66	#[inline]
67	pub(crate) fn sender_device(&self) -> Result<&DeviceId> {
68		self.sender_device
69			.as_deref()
70			.ok_or(err!(Request(Forbidden("user must be authenticated and device identified"))))
71	}
72}
73
74impl<T> Deref for Args<T>
75where
76	T: Sync,
77{
78	type Target = T;
79
80	fn deref(&self) -> &Self::Target { &self.body }
81}
82
83impl<T> FromRequest<State, Body> for Args<T>
84where
85	T: IncomingRequest + Debug + Send + Sync + 'static,
86	T::Authentication: AuthDispatch,
87{
88	type Rejection = Error;
89
90	#[tracing::instrument(
91		name = "ar",
92		level = "debug",
93		skip(services),
94		err(level = "debug")
95		ret(level = "trace"),
96	)]
97	async fn from_request(
98		request: http::Request<Body>,
99		services: &State,
100	) -> Result<Self, Self::Rejection> {
101		let mut request = request::from(services, request).await?;
102		let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
103		trace!(?request);
104
105		// while very unusual and really shouldn't be recommended, Synapse accepts POST
106		// requests with a completely empty body. very old clients, libraries, and some
107		// appservices still call APIs like /join like this. so let's just default to
108		// empty object `{}` to copy synapse's behaviour
109		if json_body.is_none()
110			&& request.parts.method == http::Method::POST
111			&& !request.parts.uri.path().contains("/media/")
112		{
113			debug_warn!(
114				"received a POST request with an empty body, defaulting/assuming to {{}} like \
115				 Synapse does"
116			);
117			json_body = Some(CanonicalJsonValue::Object(CanonicalJsonObject::new()));
118		}
119
120		let auth = auth::auth::<T::Authentication>(
121			services,
122			&mut request,
123			json_body.as_ref(),
124			TypeId::of::<T>(),
125		)
126		.await?;
127
128		Ok(Self {
129			body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
130			cookie: request.cookie,
131			origin: auth.origin,
132			sender_user: auth.sender_user,
133			sender_device: auth.sender_device,
134			appservice_info: auth.appservice_info,
135			json_body,
136		})
137	}
138}
139
140fn make_body<T>(
141	services: &Services,
142	request: &mut Request,
143	json_body: Option<&mut CanonicalJsonValue>,
144	auth: &Auth,
145) -> Result<T>
146where
147	T: IncomingRequest,
148{
149	let body = take_body(services, request, json_body, auth);
150	let http_request = into_http_request(request, body);
151	T::try_from_http_request(http_request, &request.path)
152		.map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))
153}
154
155fn into_http_request(request: &Request, body: Bytes) -> http::Request<Bytes> {
156	let mut http_request = http::Request::builder()
157		.uri(request.parts.uri.clone())
158		.method(request.parts.method.clone());
159
160	*http_request
161		.headers_mut()
162		.expect("mutable http headers") = request.parts.headers.clone();
163
164	http_request
165		.body(body)
166		.expect("http request body")
167}
168
169#[expect(clippy::needless_pass_by_value)]
170fn take_body(
171	services: &Services,
172	request: &mut Request,
173	json_body: Option<&mut CanonicalJsonValue>,
174	auth: &Auth,
175) -> Bytes {
176	let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
177		return mem::take(&mut request.body);
178	};
179
180	let user_id = auth.sender_user.clone().unwrap_or_else(|| {
181		let server_name = services.globals.server_name();
182		UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
183	});
184
185	let uiaa_request = json_body
186		.get("auth")
187		.and_then(CanonicalJsonValue::as_object)
188		.and_then(|auth| auth.get("session"))
189		.and_then(CanonicalJsonValue::as_str)
190		.and_then(|session| {
191			services
192				.uiaa
193				.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
194		});
195
196	if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
197		for (key, value) in initial_request {
198			json_body.entry(key).or_insert(value);
199		}
200	}
201
202	let mut buf = BytesMut::new().writer();
203	serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
204	buf.into_inner().freeze()
205}