Skip to main content

tuwunel_api/router/
args.rs

1use std::{any::TypeId, fmt::Debug, mem, ops::Deref};
2
3use axum::{body::Body, extract::FromRequest};
4use axum_extra::extract::cookie::CookieJar;
5use bytes::{BufMut, Bytes, BytesMut};
6use http::Method;
7use ruma::{
8	CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
9	OwnedUserId, ServerName, UserId, api::IncomingRequest,
10};
11use tuwunel_core::{Err, Error, Result, err, trace, utils::string::EMPTY};
12use tuwunel_service::{Services, appservice::RegistrationInfo};
13
14use super::{
15	auth,
16	auth::{Auth, AuthDispatch},
17	request,
18	request::Request,
19};
20use crate::State;
21
22/// Extractor for Ruma request structs
23#[derive(Debug)]
24pub(crate) struct Args<T> {
25	/// Request struct body
26	pub(crate) body: T,
27
28	/// Cookies received from the useragent.
29	pub(crate) cookie: CookieJar,
30
31	/// Federation server authentication: X-Matrix origin
32	/// None when not a federation server.
33	pub(crate) origin: Option<OwnedServerName>,
34
35	/// Local user authentication: user_id.
36	/// None when not an authenticated local user.
37	pub(crate) sender_user: Option<OwnedUserId>,
38
39	/// Local user authentication: device_id.
40	/// None when not an authenticated local user or no device.
41	pub(crate) sender_device: Option<OwnedDeviceId>,
42
43	/// Appservice authentication; registration info.
44	/// None when not an appservice.
45	pub(crate) appservice_info: Option<RegistrationInfo>,
46
47	/// Parsed JSON content.
48	/// None when body is not a valid string
49	pub(crate) json_body: Option<CanonicalJsonValue>,
50}
51
52impl<T> Args<T> {
53	#[inline]
54	pub(crate) fn sender_user(&self) -> &UserId {
55		self.sender_user
56			.as_deref()
57			.expect("user must be authenticated for this handler")
58	}
59
60	#[inline]
61	pub(crate) fn origin(&self) -> &ServerName {
62		self.origin
63			.as_deref()
64			.expect("server must be authenticated for this handler")
65	}
66
67	#[inline]
68	pub(crate) fn sender_device(&self) -> Result<&DeviceId> {
69		self.sender_device
70			.as_deref()
71			.ok_or(err!(Request(Forbidden("user must be authenticated and device identified"))))
72	}
73}
74
75impl<T> Deref for Args<T>
76where
77	T: Sync,
78{
79	type Target = T;
80
81	fn deref(&self) -> &Self::Target { &self.body }
82}
83
84impl<T> FromRequest<State, Body> for Args<T>
85where
86	T: IncomingRequest + Debug + Send + Sync + 'static,
87	T::Authentication: AuthDispatch,
88{
89	type Rejection = Error;
90
91	#[tracing::instrument(
92		name = "ar",
93		level = "debug",
94		skip(services),
95		err(level = "debug")
96		ret(level = "trace"),
97	)]
98	async fn from_request(
99		request: http::Request<Body>,
100		services: &State,
101	) -> Result<Self, Self::Rejection> {
102		let mut request = request::from(services, request).await?;
103		let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
104		trace!(?request);
105
106		// An empty body defaults to `{}` like Synapse so a UIA flow can start;
107		// a present body that is not valid JSON is rejected as M_NOT_JSON.
108		let json_endpoint = matches!(
109			request.parts.method,
110			Method::POST | Method::PUT | Method::DELETE | Method::PATCH
111		) && !request.parts.uri.path().contains("/media/");
112
113		if json_body.is_none() && json_endpoint {
114			let empty = request.body.iter().all(u8::is_ascii_whitespace);
115
116			if !empty && serde_json::from_slice::<serde_json::Value>(&request.body).is_err() {
117				return Err!(Request(NotJson("Request body is not valid JSON.")));
118			}
119
120			if empty && matches!(request.parts.method, Method::POST | Method::DELETE) {
121				json_body = Some(CanonicalJsonValue::Object(CanonicalJsonObject::new()));
122			}
123		}
124
125		let auth = auth::auth::<T::Authentication>(
126			services,
127			&mut request,
128			json_body.as_ref(),
129			TypeId::of::<T>(),
130		)
131		.await?;
132
133		Ok(Self {
134			body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
135			cookie: request.cookie,
136			origin: auth.origin,
137			sender_user: auth.sender_user,
138			sender_device: auth.sender_device,
139			appservice_info: auth.appservice_info,
140			json_body,
141		})
142	}
143}
144
145fn make_body<T>(
146	services: &Services,
147	request: &mut Request,
148	json_body: Option<&mut CanonicalJsonValue>,
149	auth: &Auth,
150) -> Result<T>
151where
152	T: IncomingRequest,
153{
154	let body = take_body(services, request, json_body, auth);
155	let http_request = into_http_request(request, body);
156	T::try_from_http_request(http_request, &request.path)
157		.map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))
158}
159
160fn into_http_request(request: &Request, body: Bytes) -> http::Request<Bytes> {
161	let mut http_request = http::Request::builder()
162		.uri(request.parts.uri.clone())
163		.method(request.parts.method.clone());
164
165	*http_request
166		.headers_mut()
167		.expect("mutable http headers") = request.parts.headers.clone();
168
169	http_request
170		.body(body)
171		.expect("http request body")
172}
173
174#[expect(clippy::needless_pass_by_value)]
175fn take_body(
176	services: &Services,
177	request: &mut Request,
178	json_body: Option<&mut CanonicalJsonValue>,
179	auth: &Auth,
180) -> Bytes {
181	let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
182		return mem::take(&mut request.body);
183	};
184
185	let user_id = auth.sender_user.clone().unwrap_or_else(|| {
186		let server_name = services.globals.server_name();
187		UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
188	});
189
190	let uiaa_request = json_body
191		.get("auth")
192		.and_then(CanonicalJsonValue::as_object)
193		.and_then(|auth| auth.get("session"))
194		.and_then(CanonicalJsonValue::as_str)
195		.and_then(|session| {
196			services
197				.uiaa
198				.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
199		});
200
201	if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
202		for (key, value) in initial_request {
203			json_body.entry(key).or_insert(value);
204		}
205	}
206
207	let mut buf = BytesMut::new().writer();
208	serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
209	buf.into_inner().freeze()
210}