Skip to main content

tuwunel_api/server/
send.rs

1use std::{
2	collections::BTreeMap,
3	net::IpAddr,
4	sync::atomic::{AtomicBool, Ordering},
5	time::{Duration, Instant},
6};
7
8use axum::extract::State;
9use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
10use ruma::{
11	CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName,
12	TransactionId, UserId,
13	api::{
14		error::ErrorKind,
15		federation::transactions::{
16			edu::{
17				DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
18				PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, SigningKeyUpdateContent,
19				TypingContent,
20			},
21			send_transaction_message,
22		},
23	},
24	events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
25	serde::Raw,
26	to_device::DeviceIdOrAllDevices,
27};
28use tuwunel_core::{
29	Err, Error, Result, debug,
30	debug::INFO_SPAN_LEVEL,
31	debug_warn, defer, err, error,
32	itertools::Itertools,
33	result::LogErr,
34	smallvec::SmallVec,
35	trace,
36	utils::{
37		debug::str_truncated,
38		future::TryExtExt,
39		millis_since_unix_epoch,
40		stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, automatic_width},
41	},
42	warn,
43};
44use tuwunel_service::{
45	Services,
46	sending::{EDU_LIMIT, PDU_LIMIT},
47};
48
49use crate::{ClientIp, Ruma};
50
51type ResolvedMap = BTreeMap<OwnedEventId, Result>;
52type RoomsPdus = SmallVec<[RoomPdus; 1]>;
53type RoomPdus = (OwnedRoomId, TxnPdus);
54type TxnPdus = SmallVec<[(usize, Pdu); 1]>;
55type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
56
57/// # `PUT /_matrix/federation/v1/send/{txnId}`
58///
59/// Push EDUs and PDUs to this server.
60#[tracing::instrument(
61	name = "txn",
62	level = INFO_SPAN_LEVEL,
63	skip_all,
64	fields(
65		txn = str_truncated(body.transaction_id.as_str(), 20),
66		origin = body.origin().as_str(),
67		%client,
68	),
69)]
70pub(crate) async fn send_transaction_message_route(
71	State(services): State<crate::State>,
72	ClientIp(client): ClientIp,
73	body: Ruma<send_transaction_message::v1::Request>,
74) -> Result<send_transaction_message::v1::Response> {
75	if body.origin() != body.body.origin {
76		return Err!(Request(Forbidden(
77			"Not allowed to send transactions on behalf of other servers"
78		)));
79	}
80
81	if body.pdus.len() > PDU_LIMIT {
82		return Err!(Request(Forbidden(
83			"Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
84		)));
85	}
86
87	if body.edus.len() > EDU_LIMIT {
88		return Err!(Request(Forbidden(
89			"Not allowed to send more than {EDU_LIMIT} EDUs in one transaction"
90		)));
91	}
92
93	let txn_start_time = Instant::now();
94	trace!(
95		pdus = body.pdus.len(),
96		edus = body.edus.len(),
97		elapsed = ?txn_start_time.elapsed(),
98		"Starting txn",
99	);
100
101	let pdus = body
102		.pdus
103		.iter()
104		.stream()
105		.enumerate()
106		.broad_filter_map(|(i, pdu)| {
107			services
108				.event_handler
109				.parse_incoming_pdu(pdu)
110				.inspect_err(move |e| debug_warn!("Could not parse PDU[{i}]: {e}"))
111				.map_ok(move |pdu| (i, pdu))
112				.ok()
113		});
114
115	let edus = body
116		.edus
117		.iter()
118		.stream()
119		.enumerate()
120		.ready_filter_map(|(i, edu)| {
121			serde_json::from_str(edu.json().get())
122				.inspect_err(|e| debug_warn!("Could not parse EDU[{i}]: {e}"))
123				.map(|edu| (i, edu))
124				.ok()
125		});
126
127	let results = handle(
128		&services,
129		&client,
130		body.origin(),
131		&body.transaction_id,
132		txn_start_time,
133		pdus,
134		edus,
135	)
136	.await?;
137
138	debug!(
139		pdus = body.pdus.len(),
140		edus = body.edus.len(),
141		elapsed = ?txn_start_time.elapsed(),
142		"Finished txn",
143	);
144
145	for (id, result) in &results {
146		if let Err(e) = result
147			&& matches!(e, Error::BadRequest(ErrorKind::NotFound, _))
148		{
149			warn!("Incoming PDU failed {id}: {e:?}");
150		}
151	}
152
153	Ok(send_transaction_message::v1::Response {
154		pdus: results
155			.into_iter()
156			.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
157			.collect(),
158	})
159}
160
161async fn handle(
162	services: &Services,
163	client: &IpAddr,
164	origin: &ServerName,
165	txn_id: &TransactionId,
166	started: Instant,
167	pdus: impl Stream<Item = (usize, Pdu)> + Send,
168	edus: impl Stream<Item = (usize, Edu)> + Send,
169) -> Result<ResolvedMap> {
170	let results = handle_pdus(services, client, origin, txn_id, started, pdus).await?;
171
172	handle_edus(services, client, origin, txn_id, edus).await?;
173
174	Ok(results)
175}
176
177async fn handle_pdus(
178	services: &Services,
179	client: &IpAddr,
180	origin: &ServerName,
181	txn_id: &TransactionId,
182	started: Instant,
183	pdus: impl Stream<Item = (usize, Pdu)> + Send,
184) -> Result<ResolvedMap> {
185	pdus.collect()
186		.map(Ok)
187		.map_ok(|pdus: TxnPdus| {
188			pdus.into_iter()
189				.sorted_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b))
190				.into_grouping_map_by(|(_, (room_id, ..))| room_id.clone())
191				.collect()
192				.into_iter()
193				.try_stream()
194		})
195		.try_flatten_stream()
196		.try_collect::<RoomsPdus>()
197		.map_ok(IntoIterator::into_iter)
198		.map_ok(IterStream::try_stream)
199		.try_flatten_stream()
200		.broad_and_then(async |(room_id, pdus)| {
201			handle_room(services, client, origin, txn_id, started, room_id, pdus.into_iter())
202				.map_ok(ResolvedMap::into_iter)
203				.map_ok(IterStream::try_stream)
204				.await
205		})
206		.try_flatten()
207		.try_collect()
208		.await
209}
210
211#[tracing::instrument(
212	name = "room",
213	level = INFO_SPAN_LEVEL,
214	skip_all,
215	fields(%room_id)
216)]
217async fn handle_room(
218	services: &Services,
219	_client: &IpAddr,
220	origin: &ServerName,
221	txn_id: &TransactionId,
222	txn_start_time: Instant,
223	ref room_id: OwnedRoomId,
224	pdus: impl Iterator<Item = (usize, Pdu)> + Send,
225) -> Result<ResolvedMap> {
226	services
227		.event_handler
228		.mutex_federation
229		.lock(room_id)
230		.then(async |_lock| {
231			pdus.enumerate()
232				.try_stream()
233				.and_then(async |pdu| {
234					services.server.check_running().map(|()| pdu) // interruption point
235				})
236				.and_then(|(ri, (ti, (room_id, event_id, value)))| {
237					let meta = (origin, txn_id, txn_start_time, ti);
238					let pdu = (ri, (room_id, event_id, value));
239					handle_pdu(services, meta, pdu).map(Ok)
240				})
241				.try_collect()
242				.await
243		})
244		.await
245}
246
247#[tracing::instrument(
248	name = "pdu",
249	level = INFO_SPAN_LEVEL,
250	skip_all,
251	fields(%event_id, %ti, %ri)
252)]
253async fn handle_pdu(
254	services: &Services,
255	(origin, txn_id, txn_start_time, ti): (&ServerName, &TransactionId, Instant, usize),
256	(ri, (ref room_id, event_id, value)): (usize, Pdu),
257) -> (OwnedEventId, Result) {
258	let pdu_start_time = Instant::now();
259	let completed: AtomicBool = Default::default();
260	defer! {{
261		if completed.load(Ordering::Acquire) {
262			return;
263		}
264
265		if pdu_start_time.elapsed() >= Duration::from_secs(services.config.client_request_timeout) {
266			error!(
267				%origin, %txn_id, %room_id, %event_id, %ri, %ti,
268				elapsed = ?pdu_start_time.elapsed(),
269				"Incoming transaction processing timed out.",
270			);
271		} else {
272			debug_warn!(
273				%origin, %txn_id, %room_id, %event_id, %ri, %ti,
274				elapsed = ?pdu_start_time.elapsed(),
275				"Incoming transaction processing interrupted.",
276			);
277		}
278	}}
279
280	let result = services
281		.event_handler
282		.handle_incoming_pdu(origin, room_id, &event_id, value, true)
283		.map_ok(|_| ())
284		.boxed()
285		.await;
286
287	completed.store(true, Ordering::Release);
288	debug!(
289		%event_id, ri, ti,
290		pdu_elapsed = ?pdu_start_time.elapsed(),
291		txn_elapsed = ?txn_start_time.elapsed(),
292		"Finished PDU",
293	);
294
295	(event_id.clone(), result)
296}
297
298#[tracing::instrument(name = "edus", level = "debug", skip_all)]
299async fn handle_edus(
300	services: &Services,
301	client: &IpAddr,
302	origin: &ServerName,
303	txn_id: &TransactionId,
304	edus: impl Stream<Item = (usize, Edu)> + Send,
305) -> Result {
306	edus.for_each_concurrent(automatic_width(), |(i, edu)| {
307		handle_edu(services, client, origin, txn_id, i, edu)
308	})
309	.await;
310
311	Ok(())
312}
313
314#[tracing::instrument(
315	name = "edu",
316	level = "debug",
317	skip_all,
318	fields(%i),
319)]
320async fn handle_edu(
321	services: &Services,
322	client: &IpAddr,
323	origin: &ServerName,
324	_txn_id: &TransactionId,
325	i: usize,
326	edu: Edu,
327) {
328	match edu {
329		| Edu::Presence(presence) if services.server.config.allow_incoming_presence =>
330			handle_edu_presence(services, client, origin, presence).await,
331
332		| Edu::Receipt(receipt)
333			if services
334				.server
335				.config
336				.allow_incoming_read_receipts =>
337			handle_edu_receipt(services, client, origin, receipt).await,
338
339		| Edu::Typing(typing) if services.server.config.allow_incoming_typing =>
340			handle_edu_typing(services, client, origin, typing).await,
341
342		| Edu::DeviceListUpdate(content) =>
343			handle_edu_device_list_update(services, client, origin, content).await,
344
345		| Edu::DirectToDevice(content) =>
346			handle_edu_direct_to_device(services, client, origin, content).await,
347
348		| Edu::SigningKeyUpdate(content) =>
349			handle_edu_signing_key_update(services, client, origin, content).await,
350
351		| Edu::_Custom(ref _custom) => debug_warn!(?i, ?edu, "received custom/unknown EDU"),
352
353		| _ => trace!(?i, ?edu, "skipped"),
354	}
355}
356
357async fn handle_edu_presence(
358	services: &Services,
359	_client: &IpAddr,
360	origin: &ServerName,
361	presence: PresenceContent,
362) {
363	presence
364		.push
365		.into_iter()
366		.stream()
367		.for_each_concurrent(automatic_width(), |update| {
368			handle_edu_presence_update(services, origin, update)
369		})
370		.await;
371}
372
373async fn handle_edu_presence_update(
374	services: &Services,
375	origin: &ServerName,
376	update: PresenceUpdate,
377) {
378	if update.user_id.server_name() != origin {
379		debug_warn!(
380			%update.user_id, %origin,
381			"received presence EDU for user not belonging to origin"
382		);
383		return;
384	}
385
386	services
387		.presence
388		.set_presence_from_federation(
389			&update.user_id,
390			&update.presence,
391			update.currently_active,
392			update.last_active_ago,
393			update.status_msg.clone(),
394		)
395		.await
396		.log_err()
397		.ok();
398}
399
400async fn handle_edu_receipt(
401	services: &Services,
402	_client: &IpAddr,
403	origin: &ServerName,
404	receipt: ReceiptContent,
405) {
406	receipt
407		.receipts
408		.into_iter()
409		.stream()
410		.for_each_concurrent(automatic_width(), |(room_id, room_updates)| {
411			handle_edu_receipt_room(services, origin, room_id, room_updates)
412		})
413		.await;
414}
415
416async fn handle_edu_receipt_room(
417	services: &Services,
418	origin: &ServerName,
419	room_id: OwnedRoomId,
420	room_updates: ReceiptMap,
421) {
422	if services
423		.event_handler
424		.acl_check(origin, &room_id)
425		.await
426		.is_err()
427	{
428		debug_warn!(
429			%origin, %room_id,
430			"received read receipt EDU from ACL'd server"
431		);
432		return;
433	}
434
435	let room_id = &room_id;
436	room_updates
437		.read
438		.into_iter()
439		.stream()
440		.for_each_concurrent(automatic_width(), async |(user_id, user_updates)| {
441			handle_edu_receipt_room_user(services, origin, room_id, &user_id, user_updates).await;
442		})
443		.await;
444}
445
446async fn handle_edu_receipt_room_user(
447	services: &Services,
448	origin: &ServerName,
449	room_id: &RoomId,
450	user_id: &UserId,
451	user_updates: ReceiptData,
452) {
453	if user_id.server_name() != origin {
454		debug_warn!(
455			%user_id, %origin,
456			"received read receipt EDU for user not belonging to origin"
457		);
458		return;
459	}
460
461	if !services
462		.state_cache
463		.server_in_room(origin, room_id)
464		.await
465	{
466		debug_warn!(
467			%user_id, %room_id, %origin,
468			"received read receipt EDU from server who does not have a member in the room",
469		);
470		return;
471	}
472
473	let data = &user_updates.data;
474	user_updates
475		.event_ids
476		.into_iter()
477		.stream()
478		.for_each_concurrent(automatic_width(), async |event_id| {
479			let user_data = [(user_id.to_owned(), data.clone())];
480			let receipts = [(ReceiptType::Read, BTreeMap::from(user_data))];
481			let content = [(event_id.clone(), BTreeMap::from(receipts))];
482			services
483				.read_receipt
484				.readreceipt_update(user_id, room_id, &ReceiptEvent {
485					content: ReceiptEventContent(content.into()),
486					room_id: room_id.to_owned(),
487				})
488				.await;
489		})
490		.await;
491}
492
493async fn handle_edu_typing(
494	services: &Services,
495	_client: &IpAddr,
496	origin: &ServerName,
497	typing: TypingContent,
498) {
499	if typing.user_id.server_name() != origin {
500		debug_warn!(
501			%typing.user_id, %origin,
502			"received typing EDU for user not belonging to origin"
503		);
504		return;
505	}
506
507	if services
508		.event_handler
509		.acl_check(typing.user_id.server_name(), &typing.room_id)
510		.await
511		.is_err()
512	{
513		debug_warn!(
514			%typing.user_id, %typing.room_id, %origin,
515			"received typing EDU for ACL'd user's server"
516		);
517		return;
518	}
519
520	if !services
521		.state_cache
522		.is_joined(&typing.user_id, &typing.room_id)
523		.await
524	{
525		debug_warn!(
526			%typing.user_id, %typing.room_id, %origin,
527			"received typing EDU for user not in room"
528		);
529		return;
530	}
531
532	if typing.typing {
533		let secs = services.server.config.typing_federation_timeout_s;
534		let timeout = millis_since_unix_epoch().saturating_add(secs.saturating_mul(1000));
535
536		services
537			.typing
538			.typing_add(&typing.user_id, &typing.room_id, timeout)
539			.await
540			.log_err()
541			.ok();
542	} else {
543		services
544			.typing
545			.typing_remove(&typing.user_id, &typing.room_id)
546			.await
547			.log_err()
548			.ok();
549	}
550}
551
552async fn handle_edu_device_list_update(
553	services: &Services,
554	_client: &IpAddr,
555	origin: &ServerName,
556	content: DeviceListUpdateContent,
557) {
558	let DeviceListUpdateContent { user_id, .. } = content;
559
560	if user_id.server_name() != origin {
561		debug_warn!(
562			%user_id, %origin,
563			"received device list update EDU for user not belonging to origin"
564		);
565		return;
566	}
567
568	services
569		.users
570		.mark_device_key_update(&user_id)
571		.await;
572}
573
574async fn handle_edu_direct_to_device(
575	services: &Services,
576	_client: &IpAddr,
577	origin: &ServerName,
578	content: DirectDeviceContent,
579) {
580	let DirectDeviceContent {
581		ref sender,
582		ref ev_type,
583		ref message_id,
584		messages,
585	} = content;
586
587	if sender.server_name() != origin {
588		debug_warn!(
589			%sender, %origin,
590			"received direct to device EDU for user not belonging to origin"
591		);
592		return;
593	}
594
595	// Check if this is a new transaction id
596	if services
597		.transaction_ids
598		.existing_txnid(sender, None, message_id)
599		.await
600		.is_ok()
601	{
602		return;
603	}
604
605	// process messages concurrently for different users
606	let ev_type = ev_type.to_string();
607	messages
608		.into_iter()
609		.stream()
610		.for_each_concurrent(automatic_width(), |(target_user_id, map)| {
611			handle_edu_direct_to_device_user(services, target_user_id, sender, &ev_type, map)
612		})
613		.await;
614
615	// Save transaction id with empty data
616	services
617		.transaction_ids
618		.add_txnid(sender, None, message_id, &[]);
619}
620
621async fn handle_edu_direct_to_device_user<Event: Send + Sync>(
622	services: &Services,
623	target_user_id: OwnedUserId,
624	sender: &UserId,
625	ev_type: &str,
626	map: BTreeMap<DeviceIdOrAllDevices, Raw<Event>>,
627) {
628	for (target_device_id_maybe, event) in map {
629		let Ok(event) = event
630			.deserialize_as()
631			.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
632		else {
633			continue;
634		};
635
636		handle_edu_direct_to_device_event(
637			services,
638			&target_user_id,
639			sender,
640			target_device_id_maybe,
641			ev_type,
642			event,
643		)
644		.await;
645	}
646}
647
648async fn handle_edu_direct_to_device_event(
649	services: &Services,
650	target_user_id: &UserId,
651	sender: &UserId,
652	target_device_id_maybe: DeviceIdOrAllDevices,
653	ev_type: &str,
654	event: serde_json::Value,
655) {
656	match target_device_id_maybe {
657		| DeviceIdOrAllDevices::DeviceId(ref target_device_id) => {
658			services.users.add_to_device_event(
659				sender,
660				target_user_id,
661				target_device_id,
662				ev_type,
663				&event,
664			);
665		},
666
667		| DeviceIdOrAllDevices::AllDevices => {
668			services
669				.users
670				.all_device_ids(target_user_id)
671				.ready_for_each(|target_device_id| {
672					services.users.add_to_device_event(
673						sender,
674						target_user_id,
675						target_device_id,
676						ev_type,
677						&event,
678					);
679				})
680				.await;
681		},
682	}
683}
684
685async fn handle_edu_signing_key_update(
686	services: &Services,
687	_client: &IpAddr,
688	origin: &ServerName,
689	content: SigningKeyUpdateContent,
690) {
691	let SigningKeyUpdateContent { user_id, master_key, self_signing_key } = content;
692
693	if user_id.server_name() != origin {
694		debug_warn!(
695			%user_id, %origin,
696			"received signing key update EDU from server that does not belong to user's server"
697		);
698		return;
699	}
700
701	services
702		.users
703		.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)
704		.await
705		.log_err()
706		.ok();
707}