Skip to main content

tuwunel_api/router/
handler.rs

1use std::fmt::Debug;
2
3use axum::{
4	Router,
5	extract::FromRequestParts,
6	response::IntoResponse,
7	routing::{MethodFilter, on},
8};
9use futures::{Future, TryFutureExt};
10use http::Method;
11use ruma::api::{IncomingRequest, path_builder::PathBuilder};
12use tuwunel_core::{Result, trace};
13
14use super::{Ruma, RumaResponse, State, auth::AuthDispatch};
15
16pub(in super::super) trait RumaHandler<T> {
17	fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State>;
18	fn add_routes(&'static self, router: Router<State>) -> Router<State>;
19}
20
21pub(in super::super) trait RouterExt {
22	fn ruma_route<H, T>(self, handler: &'static H) -> Self
23	where
24		H: RumaHandler<T>;
25}
26
27impl RouterExt for Router<State> {
28	fn ruma_route<H, T>(self, handler: &'static H) -> Self
29	where
30		H: RumaHandler<T>,
31	{
32		handler.add_routes(self)
33	}
34}
35
// Generates a blanket `RumaHandler` impl for functions taking the listed
// extractor types (`$tx`, zero or more) followed by a final `Ruma<Req>`
// argument. The trait's type parameter is the tuple of those argument types,
// so each arity produced by an invocation gets its own distinct impl.
macro_rules! ruma_handler {
	( $($tx:ident),* $(,)? ) => {
		// NOTE(review): presumably the crate lints against `allow` in favour of
		// `expect`. `expect(non_snake_case)` would be unfulfilled for the
		// zero-extractor invocation (no `$tx` closure arguments), so plain
		// `allow` is used and this attribute silences that lint.
		#[allow(clippy::allow_attributes)]
		// The type-parameter idents `$tx` are reused as closure argument names in
		// `add_route`, which would otherwise trip `non_snake_case`.
		#[allow(non_snake_case)]
		impl<Err, Req, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
		where
			// The handler itself: an async fn (or closure) over the extractors and
			// the typed request.
			Fun: Fn($($tx,)* Ruma<Req>,) -> Fut + Send + Sync + 'static,
			Fut: Future<Output = Result<Req::OutgoingResponse, Err>> + Send,
			Req: IncomingRequest + Debug + Send + Sync + 'static,
			// Not used directly in this impl; presumably required by the
			// `Ruma<Req>` extractor for request authentication — confirm in
			// `super::auth`.
			Req::Authentication: AuthDispatch,
			// Errors are returned to axum unchanged, so they must render as responses.
			Err: IntoResponse + Debug + Send,
			// `Debug` is needed by the `trace!` of successful responses below.
			<Req as IncomingRequest>::OutgoingResponse: Debug + Send,
			// Every leading argument is an axum extractor taken from request parts.
			$( $tx: FromRequestParts<State> + Send + Sync + 'static, )*
		{
			// Registers the handler once per path yielded by the request type's
			// path builder, threading the router through each registration.
			fn add_routes(&'static self, router: Router<State>) -> Router<State> {
				Req::PATH_BUILDER
					.all_paths()
					.fold(router, |router, path| self.add_route(router, path))
			}

			// Registers the handler on `path` for the request type's HTTP method.
			fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State> {
				// Panics inside `method_to_filter` if the method has no axum filter.
				let method = method_to_filter(&Req::METHOD);
				// Forwards extractors and request to the handler, traces a successful
				// response, and wraps it in `RumaResponse`. `map_ok`/`inspect_ok` only
				// touch the `Ok` case, so errors pass through as-is.
				let action = |$($tx,)* req| {
					self($($tx,)* req)
						.inspect_ok(|result| trace!(?result))
						.map_ok(RumaResponse)
				};

				router.route(path, on(method, action))
			}
		}
	}
}
69ruma_handler!();
70ruma_handler!(T1);
71ruma_handler!(T1, T2);
72ruma_handler!(T1, T2, T3);
73ruma_handler!(T1, T2, T3, T4);
74
75fn method_to_filter(method: &Method) -> MethodFilter {
76	match method {
77		| &Method::DELETE => MethodFilter::DELETE,
78		| &Method::GET => MethodFilter::GET,
79		| &Method::HEAD => MethodFilter::HEAD,
80		| &Method::OPTIONS => MethodFilter::OPTIONS,
81		| &Method::PATCH => MethodFilter::PATCH,
82		| &Method::POST => MethodFilter::POST,
83		| &Method::PUT => MethodFilter::PUT,
84		| &Method::TRACE => MethodFilter::TRACE,
85		| _ => panic!("Unsupported HTTP method"),
86	}
87}