1use std::{
2 collections::BTreeMap,
3 net::IpAddr,
4 sync::atomic::{AtomicBool, Ordering},
5 time::{Duration, Instant},
6};
7
8use axum::extract::State;
9use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
10use ruma::{
11 CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName,
12 TransactionId, UserId,
13 api::{
14 error::ErrorKind,
15 federation::transactions::{
16 edu::{
17 DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
18 PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, SigningKeyUpdateContent,
19 TypingContent,
20 },
21 send_transaction_message,
22 },
23 },
24 events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
25 serde::Raw,
26 to_device::DeviceIdOrAllDevices,
27};
28use tuwunel_core::{
29 Err, Error, Result, debug,
30 debug::INFO_SPAN_LEVEL,
31 debug_warn, defer, err, error,
32 itertools::Itertools,
33 result::LogErr,
34 smallvec::SmallVec,
35 trace,
36 utils::{
37 debug::str_truncated,
38 future::TryExtExt,
39 millis_since_unix_epoch,
40 stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, automatic_width},
41 },
42 warn,
43};
44use tuwunel_service::{
45 Services,
46 sending::{EDU_LIMIT, PDU_LIMIT},
47};
48
49use crate::{ClientIp, Ruma};
50
51type ResolvedMap = BTreeMap<OwnedEventId, Result>;
52type RoomsPdus = SmallVec<[RoomPdus; 1]>;
53type RoomPdus = (OwnedRoomId, TxnPdus);
54type TxnPdus = SmallVec<[(usize, Pdu); 1]>;
55type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
56
57#[tracing::instrument(
61 name = "txn",
62 level = INFO_SPAN_LEVEL,
63 skip_all,
64 fields(
65 txn = str_truncated(body.transaction_id.as_str(), 20),
66 origin = body.origin().as_str(),
67 %client,
68 ),
69)]
70pub(crate) async fn send_transaction_message_route(
71 State(services): State<crate::State>,
72 ClientIp(client): ClientIp,
73 body: Ruma<send_transaction_message::v1::Request>,
74) -> Result<send_transaction_message::v1::Response> {
75 if body.origin() != body.body.origin {
76 return Err!(Request(Forbidden(
77 "Not allowed to send transactions on behalf of other servers"
78 )));
79 }
80
81 if body.pdus.len() > PDU_LIMIT {
82 return Err!(Request(Forbidden(
83 "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
84 )));
85 }
86
87 if body.edus.len() > EDU_LIMIT {
88 return Err!(Request(Forbidden(
89 "Not allowed to send more than {EDU_LIMIT} EDUs in one transaction"
90 )));
91 }
92
93 let txn_start_time = Instant::now();
94 trace!(
95 pdus = body.pdus.len(),
96 edus = body.edus.len(),
97 elapsed = ?txn_start_time.elapsed(),
98 "Starting txn",
99 );
100
101 let pdus = body
102 .pdus
103 .iter()
104 .stream()
105 .enumerate()
106 .broad_filter_map(|(i, pdu)| {
107 services
108 .event_handler
109 .parse_incoming_pdu(pdu)
110 .inspect_err(move |e| debug_warn!("Could not parse PDU[{i}]: {e}"))
111 .map_ok(move |pdu| (i, pdu))
112 .ok()
113 });
114
115 let edus = body
116 .edus
117 .iter()
118 .stream()
119 .enumerate()
120 .ready_filter_map(|(i, edu)| {
121 serde_json::from_str(edu.json().get())
122 .inspect_err(|e| debug_warn!("Could not parse EDU[{i}]: {e}"))
123 .map(|edu| (i, edu))
124 .ok()
125 });
126
127 let results = handle(
128 &services,
129 &client,
130 body.origin(),
131 &body.transaction_id,
132 txn_start_time,
133 pdus,
134 edus,
135 )
136 .await?;
137
138 debug!(
139 pdus = body.pdus.len(),
140 edus = body.edus.len(),
141 elapsed = ?txn_start_time.elapsed(),
142 "Finished txn",
143 );
144
145 for (id, result) in &results {
146 if let Err(e) = result
147 && matches!(e, Error::BadRequest(ErrorKind::NotFound, _))
148 {
149 warn!("Incoming PDU failed {id}: {e:?}");
150 }
151 }
152
153 Ok(send_transaction_message::v1::Response {
154 pdus: results
155 .into_iter()
156 .map(|(e, r)| (e, r.map_err(error::sanitized_message)))
157 .collect(),
158 })
159}
160
161async fn handle(
162 services: &Services,
163 client: &IpAddr,
164 origin: &ServerName,
165 txn_id: &TransactionId,
166 started: Instant,
167 pdus: impl Stream<Item = (usize, Pdu)> + Send,
168 edus: impl Stream<Item = (usize, Edu)> + Send,
169) -> Result<ResolvedMap> {
170 let results = handle_pdus(services, client, origin, txn_id, started, pdus).await?;
171
172 handle_edus(services, client, origin, txn_id, edus).await?;
173
174 Ok(results)
175}
176
177async fn handle_pdus(
178 services: &Services,
179 client: &IpAddr,
180 origin: &ServerName,
181 txn_id: &TransactionId,
182 started: Instant,
183 pdus: impl Stream<Item = (usize, Pdu)> + Send,
184) -> Result<ResolvedMap> {
185 pdus.collect()
186 .map(Ok)
187 .map_ok(|pdus: TxnPdus| {
188 pdus.into_iter()
189 .sorted_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b))
190 .into_grouping_map_by(|(_, (room_id, ..))| room_id.clone())
191 .collect()
192 .into_iter()
193 .try_stream()
194 })
195 .try_flatten_stream()
196 .try_collect::<RoomsPdus>()
197 .map_ok(IntoIterator::into_iter)
198 .map_ok(IterStream::try_stream)
199 .try_flatten_stream()
200 .broad_and_then(async |(room_id, pdus)| {
201 handle_room(services, client, origin, txn_id, started, room_id, pdus.into_iter())
202 .map_ok(ResolvedMap::into_iter)
203 .map_ok(IterStream::try_stream)
204 .await
205 })
206 .try_flatten()
207 .try_collect()
208 .await
209}
210
211#[tracing::instrument(
212 name = "room",
213 level = INFO_SPAN_LEVEL,
214 skip_all,
215 fields(%room_id)
216)]
217async fn handle_room(
218 services: &Services,
219 _client: &IpAddr,
220 origin: &ServerName,
221 txn_id: &TransactionId,
222 txn_start_time: Instant,
223 ref room_id: OwnedRoomId,
224 pdus: impl Iterator<Item = (usize, Pdu)> + Send,
225) -> Result<ResolvedMap> {
226 services
227 .event_handler
228 .mutex_federation
229 .lock(room_id)
230 .then(async |_lock| {
231 pdus.enumerate()
232 .try_stream()
233 .and_then(async |pdu| {
234 services.server.check_running().map(|()| pdu) })
236 .and_then(|(ri, (ti, (room_id, event_id, value)))| {
237 let meta = (origin, txn_id, txn_start_time, ti);
238 let pdu = (ri, (room_id, event_id, value));
239 handle_pdu(services, meta, pdu).map(Ok)
240 })
241 .try_collect()
242 .await
243 })
244 .await
245}
246
247#[tracing::instrument(
248 name = "pdu",
249 level = INFO_SPAN_LEVEL,
250 skip_all,
251 fields(%event_id, %ti, %ri)
252)]
253async fn handle_pdu(
254 services: &Services,
255 (origin, txn_id, txn_start_time, ti): (&ServerName, &TransactionId, Instant, usize),
256 (ri, (ref room_id, event_id, value)): (usize, Pdu),
257) -> (OwnedEventId, Result) {
258 let pdu_start_time = Instant::now();
259 let completed: AtomicBool = Default::default();
260 defer! {{
261 if completed.load(Ordering::Acquire) {
262 return;
263 }
264
265 if pdu_start_time.elapsed() >= Duration::from_secs(services.config.client_request_timeout) {
266 error!(
267 %origin, %txn_id, %room_id, %event_id, %ri, %ti,
268 elapsed = ?pdu_start_time.elapsed(),
269 "Incoming transaction processing timed out.",
270 );
271 } else {
272 debug_warn!(
273 %origin, %txn_id, %room_id, %event_id, %ri, %ti,
274 elapsed = ?pdu_start_time.elapsed(),
275 "Incoming transaction processing interrupted.",
276 );
277 }
278 }}
279
280 let result = services
281 .event_handler
282 .handle_incoming_pdu(origin, room_id, &event_id, value, true)
283 .map_ok(|_| ())
284 .boxed()
285 .await;
286
287 completed.store(true, Ordering::Release);
288 debug!(
289 %event_id, ri, ti,
290 pdu_elapsed = ?pdu_start_time.elapsed(),
291 txn_elapsed = ?txn_start_time.elapsed(),
292 "Finished PDU",
293 );
294
295 (event_id.clone(), result)
296}
297
298#[tracing::instrument(name = "edus", level = "debug", skip_all)]
299async fn handle_edus(
300 services: &Services,
301 client: &IpAddr,
302 origin: &ServerName,
303 txn_id: &TransactionId,
304 edus: impl Stream<Item = (usize, Edu)> + Send,
305) -> Result {
306 edus.for_each_concurrent(automatic_width(), |(i, edu)| {
307 handle_edu(services, client, origin, txn_id, i, edu)
308 })
309 .await;
310
311 Ok(())
312}
313
314#[tracing::instrument(
315 name = "edu",
316 level = "debug",
317 skip_all,
318 fields(%i),
319)]
320async fn handle_edu(
321 services: &Services,
322 client: &IpAddr,
323 origin: &ServerName,
324 _txn_id: &TransactionId,
325 i: usize,
326 edu: Edu,
327) {
328 match edu {
329 | Edu::Presence(presence) if services.server.config.allow_incoming_presence =>
330 handle_edu_presence(services, client, origin, presence).await,
331
332 | Edu::Receipt(receipt)
333 if services
334 .server
335 .config
336 .allow_incoming_read_receipts =>
337 handle_edu_receipt(services, client, origin, receipt).await,
338
339 | Edu::Typing(typing) if services.server.config.allow_incoming_typing =>
340 handle_edu_typing(services, client, origin, typing).await,
341
342 | Edu::DeviceListUpdate(content) =>
343 handle_edu_device_list_update(services, client, origin, content).await,
344
345 | Edu::DirectToDevice(content) =>
346 handle_edu_direct_to_device(services, client, origin, content).await,
347
348 | Edu::SigningKeyUpdate(content) =>
349 handle_edu_signing_key_update(services, client, origin, content).await,
350
351 | Edu::_Custom(ref _custom) => debug_warn!(?i, ?edu, "received custom/unknown EDU"),
352
353 | _ => trace!(?i, ?edu, "skipped"),
354 }
355}
356
357async fn handle_edu_presence(
358 services: &Services,
359 _client: &IpAddr,
360 origin: &ServerName,
361 presence: PresenceContent,
362) {
363 presence
364 .push
365 .into_iter()
366 .stream()
367 .for_each_concurrent(automatic_width(), |update| {
368 handle_edu_presence_update(services, origin, update)
369 })
370 .await;
371}
372
373async fn handle_edu_presence_update(
374 services: &Services,
375 origin: &ServerName,
376 update: PresenceUpdate,
377) {
378 if update.user_id.server_name() != origin {
379 debug_warn!(
380 %update.user_id, %origin,
381 "received presence EDU for user not belonging to origin"
382 );
383 return;
384 }
385
386 services
387 .presence
388 .set_presence_from_federation(
389 &update.user_id,
390 &update.presence,
391 update.currently_active,
392 update.last_active_ago,
393 update.status_msg.clone(),
394 )
395 .await
396 .log_err()
397 .ok();
398}
399
400async fn handle_edu_receipt(
401 services: &Services,
402 _client: &IpAddr,
403 origin: &ServerName,
404 receipt: ReceiptContent,
405) {
406 receipt
407 .receipts
408 .into_iter()
409 .stream()
410 .for_each_concurrent(automatic_width(), |(room_id, room_updates)| {
411 handle_edu_receipt_room(services, origin, room_id, room_updates)
412 })
413 .await;
414}
415
416async fn handle_edu_receipt_room(
417 services: &Services,
418 origin: &ServerName,
419 room_id: OwnedRoomId,
420 room_updates: ReceiptMap,
421) {
422 if services
423 .event_handler
424 .acl_check(origin, &room_id)
425 .await
426 .is_err()
427 {
428 debug_warn!(
429 %origin, %room_id,
430 "received read receipt EDU from ACL'd server"
431 );
432 return;
433 }
434
435 let room_id = &room_id;
436 room_updates
437 .read
438 .into_iter()
439 .stream()
440 .for_each_concurrent(automatic_width(), async |(user_id, user_updates)| {
441 handle_edu_receipt_room_user(services, origin, room_id, &user_id, user_updates).await;
442 })
443 .await;
444}
445
446async fn handle_edu_receipt_room_user(
447 services: &Services,
448 origin: &ServerName,
449 room_id: &RoomId,
450 user_id: &UserId,
451 user_updates: ReceiptData,
452) {
453 if user_id.server_name() != origin {
454 debug_warn!(
455 %user_id, %origin,
456 "received read receipt EDU for user not belonging to origin"
457 );
458 return;
459 }
460
461 if !services
462 .state_cache
463 .server_in_room(origin, room_id)
464 .await
465 {
466 debug_warn!(
467 %user_id, %room_id, %origin,
468 "received read receipt EDU from server who does not have a member in the room",
469 );
470 return;
471 }
472
473 let data = &user_updates.data;
474 user_updates
475 .event_ids
476 .into_iter()
477 .stream()
478 .for_each_concurrent(automatic_width(), async |event_id| {
479 let user_data = [(user_id.to_owned(), data.clone())];
480 let receipts = [(ReceiptType::Read, BTreeMap::from(user_data))];
481 let content = [(event_id.clone(), BTreeMap::from(receipts))];
482 services
483 .read_receipt
484 .readreceipt_update(user_id, room_id, &ReceiptEvent {
485 content: ReceiptEventContent(content.into()),
486 room_id: room_id.to_owned(),
487 })
488 .await;
489 })
490 .await;
491}
492
493async fn handle_edu_typing(
494 services: &Services,
495 _client: &IpAddr,
496 origin: &ServerName,
497 typing: TypingContent,
498) {
499 if typing.user_id.server_name() != origin {
500 debug_warn!(
501 %typing.user_id, %origin,
502 "received typing EDU for user not belonging to origin"
503 );
504 return;
505 }
506
507 if services
508 .event_handler
509 .acl_check(typing.user_id.server_name(), &typing.room_id)
510 .await
511 .is_err()
512 {
513 debug_warn!(
514 %typing.user_id, %typing.room_id, %origin,
515 "received typing EDU for ACL'd user's server"
516 );
517 return;
518 }
519
520 if !services
521 .state_cache
522 .is_joined(&typing.user_id, &typing.room_id)
523 .await
524 {
525 debug_warn!(
526 %typing.user_id, %typing.room_id, %origin,
527 "received typing EDU for user not in room"
528 );
529 return;
530 }
531
532 if typing.typing {
533 let secs = services.server.config.typing_federation_timeout_s;
534 let timeout = millis_since_unix_epoch().saturating_add(secs.saturating_mul(1000));
535
536 services
537 .typing
538 .typing_add(&typing.user_id, &typing.room_id, timeout)
539 .await
540 .log_err()
541 .ok();
542 } else {
543 services
544 .typing
545 .typing_remove(&typing.user_id, &typing.room_id)
546 .await
547 .log_err()
548 .ok();
549 }
550}
551
552async fn handle_edu_device_list_update(
553 services: &Services,
554 _client: &IpAddr,
555 origin: &ServerName,
556 content: DeviceListUpdateContent,
557) {
558 let DeviceListUpdateContent { user_id, .. } = content;
559
560 if user_id.server_name() != origin {
561 debug_warn!(
562 %user_id, %origin,
563 "received device list update EDU for user not belonging to origin"
564 );
565 return;
566 }
567
568 services
569 .users
570 .mark_device_key_update(&user_id)
571 .await;
572}
573
574async fn handle_edu_direct_to_device(
575 services: &Services,
576 _client: &IpAddr,
577 origin: &ServerName,
578 content: DirectDeviceContent,
579) {
580 let DirectDeviceContent {
581 ref sender,
582 ref ev_type,
583 ref message_id,
584 messages,
585 } = content;
586
587 if sender.server_name() != origin {
588 debug_warn!(
589 %sender, %origin,
590 "received direct to device EDU for user not belonging to origin"
591 );
592 return;
593 }
594
595 if services
597 .transaction_ids
598 .existing_txnid(sender, None, message_id)
599 .await
600 .is_ok()
601 {
602 return;
603 }
604
605 let ev_type = ev_type.to_string();
607 messages
608 .into_iter()
609 .stream()
610 .for_each_concurrent(automatic_width(), |(target_user_id, map)| {
611 handle_edu_direct_to_device_user(services, target_user_id, sender, &ev_type, map)
612 })
613 .await;
614
615 services
617 .transaction_ids
618 .add_txnid(sender, None, message_id, &[]);
619}
620
621async fn handle_edu_direct_to_device_user<Event: Send + Sync>(
622 services: &Services,
623 target_user_id: OwnedUserId,
624 sender: &UserId,
625 ev_type: &str,
626 map: BTreeMap<DeviceIdOrAllDevices, Raw<Event>>,
627) {
628 for (target_device_id_maybe, event) in map {
629 let Ok(event) = event
630 .deserialize_as()
631 .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
632 else {
633 continue;
634 };
635
636 handle_edu_direct_to_device_event(
637 services,
638 &target_user_id,
639 sender,
640 target_device_id_maybe,
641 ev_type,
642 event,
643 )
644 .await;
645 }
646}
647
648async fn handle_edu_direct_to_device_event(
649 services: &Services,
650 target_user_id: &UserId,
651 sender: &UserId,
652 target_device_id_maybe: DeviceIdOrAllDevices,
653 ev_type: &str,
654 event: serde_json::Value,
655) {
656 match target_device_id_maybe {
657 | DeviceIdOrAllDevices::DeviceId(ref target_device_id) => {
658 services.users.add_to_device_event(
659 sender,
660 target_user_id,
661 target_device_id,
662 ev_type,
663 &event,
664 );
665 },
666
667 | DeviceIdOrAllDevices::AllDevices => {
668 services
669 .users
670 .all_device_ids(target_user_id)
671 .ready_for_each(|target_device_id| {
672 services.users.add_to_device_event(
673 sender,
674 target_user_id,
675 target_device_id,
676 ev_type,
677 &event,
678 );
679 })
680 .await;
681 },
682 }
683}
684
685async fn handle_edu_signing_key_update(
686 services: &Services,
687 _client: &IpAddr,
688 origin: &ServerName,
689 content: SigningKeyUpdateContent,
690) {
691 let SigningKeyUpdateContent { user_id, master_key, self_signing_key } = content;
692
693 if user_id.server_name() != origin {
694 debug_warn!(
695 %user_id, %origin,
696 "received signing key update EDU from server that does not belong to user's server"
697 );
698 return;
699 }
700
701 services
702 .users
703 .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)
704 .await
705 .log_err()
706 .ok();
707}