Skip to main content

tuwunel_router/
request.rs

1use std::{
2	fmt::Debug,
3	sync::{Arc, atomic::Ordering},
4	time::Duration,
5};
6
7use axum::{
8	extract::State,
9	response::{IntoResponse, Response},
10};
11use futures::FutureExt;
12use http::{Method, StatusCode, Uri};
13use ruma::api::error::ErrorKind;
14use tokio::{sync::Notify, task, time::sleep};
15use tracing::Span;
16use tuwunel_core::{Error, Result, debug, debug_error, debug_warn, defer, error, trace};
17use tuwunel_service::Services;
18
19#[tracing::instrument(
20	name = "request",
21	level = "debug",
22	skip_all,
23	err(Debug, level = "debug")
24	fields(
25		task = %task::id(),
26		id = %services
27			.server
28			.metrics
29			.requests_count
30			.fetch_add(1, Ordering::Relaxed)
31	)
32)]
33pub(crate) async fn handle(
34	State(services): State<Arc<Services>>,
35	req: http::Request<axum::body::Body>,
36	next: axum::middleware::Next,
37) -> Result<Response, StatusCode> {
38	if !services.server.is_running() {
39		debug_warn!(
40			method = %req.method(),
41			uri = %req.uri(),
42			"unavailable pending shutdown"
43		);
44
45		return Err(StatusCode::SERVICE_UNAVAILABLE);
46	}
47
48	let uri = req.uri().clone();
49	let method = req.method().clone();
50	let parent = Span::current();
51	let response = match method {
52		| Method::PUT | Method::POST | Method::DELETE | Method::PATCH =>
53			spawn_execute(services, req, next, parent).await?,
54		| _ => execute(&services, req, next, &parent).await,
55	};
56
57	handle_result(&method, &uri, response)
58}
59
60async fn spawn_execute(
61	services: Arc<Services>,
62	mut req: http::Request<axum::body::Body>,
63	next: axum::middleware::Next,
64	parent: Span,
65) -> Result<Response, StatusCode> {
66	let detached = Arc::new(Notify::new());
67	req.extensions_mut().insert(detached.clone());
68
69	let task = services
70		.clone()
71		.server
72		.runtime()
73		.spawn(async move {
74			tokio::select! {
75				response = execute(&services, req, next, &parent) => response,
76				response = services.server.until_shutdown()
77					.then(|()| {
78						let timeout = services.config.client_shutdown_timeout;
79						sleep(Duration::from_secs(timeout))
80					})
81					.map(|()| StatusCode::SERVICE_UNAVAILABLE)
82					.map(IntoResponse::into_response) => response,
83			}
84		});
85
86	let abort = task.abort_handle();
87	defer! {{
88		if !abort.is_finished() {
89			debug_warn!(
90				task = ?abort.id(),
91				"Client disconnected; detached request."
92			);
93
94			detached.notify_one();
95		}
96	}};
97
98	task.await.map_err(unhandled)
99}
100
101#[tracing::instrument(
102	name = "handle",
103	level = "debug",
104	parent = parent,
105	skip_all,
106	ret(level = "trace"),
107	fields(
108		task = %task::id(),
109	)
110)]
111#[cfg_attr(not(debug_assertions), expect(unused_variables))]
112async fn execute(
113	// we made a safety contract that Services will not go out of scope
114	// during the request; this ensures a reference is accounted for at
115	// the base frame of the task regardless of its detachment.
116	services: &Arc<Services>,
117	req: http::Request<axum::body::Body>,
118	next: axum::middleware::Next,
119	parent: &Span,
120) -> Response {
121	#[cfg(debug_assertions)]
122	services
123		.server
124		.metrics
125		.requests_handle_active
126		.fetch_add(1, Ordering::Relaxed);
127
128	#[cfg(debug_assertions)]
129	defer! {{
130		_ = services.server
131			.metrics
132			.requests_handle_finished
133			.fetch_add(1, Ordering::Relaxed);
134		_ = services.server
135			.metrics
136			.requests_handle_active
137			.fetch_sub(1, Ordering::Relaxed);
138	}};
139
140	next.run(req).await
141}
142
143fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result<Response, StatusCode> {
144	let status = result.status();
145	let code = status.as_u16();
146	let reason = status
147		.canonical_reason()
148		.unwrap_or("Unknown Reason");
149
150	if status.is_server_error() {
151		error!(method = ?method, uri = ?uri, "{code} {reason}");
152	} else if status.is_client_error() {
153		debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
154	} else if status.is_redirection() {
155		debug!(method = ?method, uri = ?uri, "{code} {reason}");
156	} else {
157		trace!(method = ?method, uri = ?uri, "{code} {reason}");
158	}
159
160	if status == StatusCode::METHOD_NOT_ALLOWED {
161		return Ok(Error::Request(
162			ErrorKind::Unrecognized,
163			"Method Not Allowed".into(),
164			StatusCode::METHOD_NOT_ALLOWED,
165		)
166		.into_response());
167	}
168
169	Ok(result)
170}
171
172#[cold]
173fn unhandled<Error: Debug>(e: Error) -> StatusCode {
174	error!("unhandled error or panic during request: {e:?}");
175
176	StatusCode::INTERNAL_SERVER_ERROR
177}