1use std::sync::Arc;
2
3use futures::{Stream, StreamExt, TryFutureExt, future::Either, pin_mut};
4use ruma::{
5 CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, RoomId, UserId,
6 api::Direction,
7 events::{reaction::ReactionEventContent, relation::RelationType, room::encrypted::Relation},
8};
9use serde::Deserialize;
10use tuwunel_core::{
11 PduId, Result,
12 arrayvec::ArrayVec,
13 implement, is_equal_to,
14 matrix::{Event, Pdu, PduCount, RawPduId, event::RelationTypeEqual},
15 result::LogErr,
16 trace,
17 utils::{
18 BoolExt,
19 stream::{ReadyExt, TryIgnore, WidebandExt, automatic_width},
20 u64_from_u8,
21 },
22};
23use tuwunel_database::{Interfix, Map};
24
25use crate::rooms::short::ShortRoomId;
26
27#[cfg(test)]
28mod tests;
29
/// Relation and reference bookkeeping for room events.
///
/// Maintains the plain `tofrom_relation` edge index, the typed
/// (`m.replace` / `m.reference`) `relatesto_typed` index that the bundled
/// aggregations read, and the `referencedevents` / `softfailedeventids`
/// flag sets.
pub struct Service {
	services: Arc<crate::services::OnceServices>,
	db: Data,
}

/// Database column handles owned by [`Service`].
struct Data {
	/// Key: `[to, from]` PduCounts (written by `add_relation`); value: empty.
	tofrom_relation: Arc<Map>,
	/// Key: `shortroomid | parent count | tag | child origin_server_ts |
	/// child count` (built by `typed_relation_key`); value: the child's
	/// shorteventid bytes.
	relatesto_typed: Arc<Map>,
	/// Key: `(room_id, event_id)`; value: empty. Presence marks the event as
	/// referenced.
	referencedevents: Arc<Map>,
	/// Key: event_id; value: empty. Presence marks the event as soft-failed.
	softfailedeventids: Arc<Map>,
}
41
/// One-byte relation kind stored in `relatesto_typed` keys directly after the
/// parent count. Only edits and references are indexed there.
#[derive(Clone, Copy)]
enum RelTag {
	/// `m.replace`
	Replace = 0x01,
	/// `m.reference`
	Reference = 0x02,
}

impl From<RelTag> for u8 {
	/// On-disk byte for `tag`; must stay equal to the explicit discriminants
	/// declared on the enum.
	#[inline]
	fn from(tag: RelTag) -> Self {
		use RelTag::{Reference, Replace};

		match tag {
			| Replace => 0x01,
			| Reference => 0x02,
		}
	}
}
60
/// Bytes in a typed-relation key prefix: shortroomid (u64), parent count
/// (u64) and the one-byte relation tag.
const TYPED_PREFIX_LEN: usize = 2 * size_of::<u64>() + size_of::<u8>();

/// Bytes in a full typed-relation key: the prefix followed by the child's
/// origin_server_ts (u64) and the child's count (u64).
const TYPED_KEY_LEN: usize = TYPED_PREFIX_LEN + 2 * size_of::<u64>();

/// Offset of the trailing child count within a full typed-relation key.
const TYPED_CHILD_COUNT_OFFSET: usize = TYPED_KEY_LEN - size_of::<u64>();

/// Upper bound on the number of referencing event ids bundled into one event.
const REFERENCE_BUNDLE_MAX: usize = 100;
72
/// Minimal view of an event's content that only pulls out `m.relates_to`;
/// `index_pdu_relations` uses it while re-indexing stored PDUs.
#[derive(Deserialize)]
struct ExtractRelatesTo {
	#[serde(rename = "m.relates_to")]
	relates_to: Relation,
}
78
79impl crate::Service for Service {
80 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
81 Ok(Arc::new(Self {
82 services: args.services.clone(),
83 db: Data {
84 tofrom_relation: args.db["tofrom_relation"].clone(),
85 relatesto_typed: args.db["relatesto_typed"].clone(),
86 referencedevents: args.db["referencedevents"].clone(),
87 softfailedeventids: args.db["softfailedeventids"].clone(),
88 },
89 }))
90 }
91
92 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
93}
94
95#[implement(Service)]
96#[tracing::instrument(skip(self, from, to), level = "debug")]
97pub fn add_relation(&self, from: PduCount, to: PduCount) {
98 const BUFSIZE: usize = size_of::<u64>() * 2;
99
100 match (from, to) {
101 | (PduCount::Normal(from), PduCount::Normal(to)) => {
102 let key: &[u64] = &[to, from];
103 self.db
104 .tofrom_relation
105 .aput_raw::<BUFSIZE, _, _>(key, []);
106 },
107 | _ => {}, }
109}
110
111#[implement(Service)]
116#[tracing::instrument(skip(self, child), level = "debug")]
117pub async fn add_typed_relation<E: Event>(
118 &self,
119 shortroomid: ShortRoomId,
120 child_count: PduCount,
121 parent: &EventId,
122 child: &E,
123 rel_type: RelationType,
124) {
125 let Some(tag) = rel_type_tag(&rel_type) else {
126 return;
127 };
128
129 let Ok(parent_count) = self.services.timeline.get_pdu_count(parent).await else {
130 return;
131 };
132
133 let (PduCount::Normal(_), PduCount::Normal(_)) = (parent_count, child_count) else {
134 return; };
136
137 let child_short = self
138 .services
139 .short
140 .get_or_create_shorteventid(child.event_id())
141 .await;
142
143 let child_ts = u64::from(child.origin_server_ts().get());
144 let key = typed_relation_key(shortroomid, parent_count, tag, child_ts, child_count);
145
146 self.db
147 .relatesto_typed
148 .aput_raw::<TYPED_KEY_LEN, _, _>(key.as_slice(), child_short.to_be_bytes());
149}
150
151#[implement(Service)]
155pub async fn event_has_relation(
156 &self,
157 event_id: &EventId,
158 user_id: Option<&UserId>,
159 rel_type: Option<&RelationType>,
160 key: Option<&str>,
161) -> bool {
162 let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await else {
163 return false;
164 };
165
166 self.has_relation(pdu_id.into(), user_id, rel_type, key)
167 .await
168}
169
170#[implement(Service)]
174pub async fn has_relation(
175 &self,
176 target: PduId,
177 user_id: Option<&UserId>,
178 rel_type: Option<&RelationType>,
179 key: Option<&str>,
180) -> bool {
181 self.get_relations(target.shortroomid, target.count, None, Direction::Forward, None)
182 .ready_filter(|(_, pdu)| user_id.is_none_or(is_equal_to!(pdu.sender())))
183 .ready_filter(|(_, pdu)| {
184 debug_assert!(
185 key.is_none() || rel_type.is_none_or(is_equal_to!(&RelationType::Annotation)),
186 "key argument only applies to Annotation type relations."
187 );
188
189 key.is_some() || rel_type
191 .is_none_or(|rel_type| rel_type.relation_type_equal(&pdu))
192 })
193 .ready_filter(|(_, pdu)| {
194 key.is_none_or(|key| {
195 pdu.get_content::<ReactionEventContent>()
196 .map(|content| content.relates_to.key == key)
197 .unwrap_or(false)
198 })
199 })
200 .ready_any(|_| true) .await
202}
203
204#[implement(Service)]
205pub fn get_relations<'a>(
206 &'a self,
207 shortroomid: ShortRoomId,
208 target: PduCount,
209 from: Option<PduCount>,
210 dir: Direction,
211 user_id: Option<&'a UserId>,
212) -> impl Stream<Item = (PduCount, Pdu)> + Send + '_ {
213 let target = target.to_be_bytes();
214 let from = from
215 .map(|from| from.saturating_inc(dir))
216 .unwrap_or_else(|| match dir {
217 | Direction::Backward => PduCount::max(),
218 | Direction::Forward => PduCount::default(),
219 })
220 .to_be_bytes();
221
222 let mut buf = ArrayVec::<u8, 16>::new();
223 let start = {
224 buf.extend(target);
225 buf.extend(from);
226 buf.as_slice()
227 };
228
229 match dir {
230 | Direction::Backward => Either::Left(self.db.tofrom_relation.rev_raw_keys_from(start)),
231 | Direction::Forward => Either::Right(self.db.tofrom_relation.raw_keys_from(start)),
232 }
233 .ignore_err()
234 .ready_take_while(move |key| key.starts_with(&target))
235 .map(|to_from| u64_from_u8(&to_from[8..16]))
236 .map(PduCount::from_unsigned)
237 .map(move |count| (user_id, shortroomid, count))
238 .wide_filter_map(async |(user_id, shortroomid, count)| {
239 let pdu_id: RawPduId = PduId { shortroomid, count }.into();
240 self.services
241 .timeline
242 .get_pdu_from_id(&pdu_id)
243 .map_ok(move |mut pdu| {
244 if user_id.is_none_or(|user_id| pdu.sender() != user_id) {
245 pdu.as_mut_pdu()
246 .remove_transaction_id()
247 .log_err()
248 .ok();
249 }
250
251 (count, pdu)
252 })
253 .await
254 .ok()
255 })
256}
257
258#[implement(Service)]
268pub async fn bundle_aggregations(&self, sender_user: &UserId, mut pdu: Pdu) -> Pdu {
269 let has_thread = pdu
270 .unsigned()
271 .is_some_and(|unsigned| unsigned.get().contains("m.thread"));
272
273 if has_thread {
274 let participated = self
275 .services
276 .threads
277 .user_participated(pdu.event_id(), sender_user)
278 .await;
279
280 pdu.set_thread_participated(participated)
281 .log_err()
282 .ok();
283 }
284
285 let replacement = self
286 .services
287 .server
288 .config
289 .bundle_edit_relations
290 .then_async(|| self.newest_replacement(&pdu))
291 .await
292 .flatten();
293
294 if let Some(replacement) = replacement {
295 pdu.set_replacement_bundle(&replacement.to_format())
296 .log_err()
297 .ok();
298 }
299
300 let references = self
301 .services
302 .server
303 .config
304 .bundle_reference_relations
305 .then_async(|| self.references(&pdu))
306 .await
307 .unwrap_or_default();
308
309 if !references.is_empty() {
310 pdu.set_reference_bundle(&references)
311 .log_err()
312 .ok();
313 }
314
315 pdu
316}
317
318#[implement(Service)]
323async fn newest_replacement(&self, parent: &Pdu) -> Option<Pdu> {
324 if parent.is_redacted() {
325 return None;
326 }
327
328 let parent_id: PduId = self
329 .services
330 .timeline
331 .get_pdu_id(parent.event_id())
332 .map_ok(Into::into)
333 .await
334 .ok()?;
335
336 let replacements = self.replacement_children(parent, parent_id);
337
338 pin_mut!(replacements);
339 replacements.next().await
340}
341
342#[implement(Service)]
346fn replacement_children<'a>(
347 &'a self,
348 parent: &'a Pdu,
349 parent_id: PduId,
350) -> impl Stream<Item = Pdu> + Send + 'a {
351 let shortroomid = parent_id.shortroomid;
352 let prefix = typed_relation_prefix(shortroomid, parent_id.count, RelTag::Replace);
353
354 let mut seek = ArrayVec::<u8, TYPED_KEY_LEN>::new();
355 seek.extend(prefix.iter().copied());
356 seek.extend([u8::MAX; size_of::<u64>() * 2]);
357
358 self.db
359 .relatesto_typed
360 .rev_raw_keys_from(seek.as_slice())
361 .ignore_err()
362 .ready_take_while(move |key| key.starts_with(&prefix))
363 .map(|key| u64_from_u8(&key[TYPED_CHILD_COUNT_OFFSET..TYPED_KEY_LEN]))
364 .map(PduCount::from_unsigned)
365 .map(move |count| (shortroomid, count))
366 .filter_map(async |(shortroomid, count)| {
367 let child_id: RawPduId = PduId { shortroomid, count }.into();
368 self.services
369 .timeline
370 .get_pdu_from_id(&child_id)
371 .await
372 .ok()
373 .filter(|child| !child.is_redacted())
374 .filter(|child| child.sender() == parent.sender())
375 .filter(|child| child.kind() == parent.kind())
376 })
377}
378
379#[implement(Service)]
389async fn references(&self, parent: &Pdu) -> Vec<OwnedEventId> {
390 if parent.is_redacted() {
391 return Vec::new();
392 }
393
394 let Ok(parent_id) = self
395 .services
396 .timeline
397 .get_pdu_id(parent.event_id())
398 .map_ok(PduId::from)
399 .await
400 else {
401 return Vec::new();
402 };
403
404 self.referenced_children(parent_id)
405 .take(REFERENCE_BUNDLE_MAX)
406 .collect()
407 .await
408}
409
410#[implement(Service)]
414fn referenced_children<'a>(
415 &'a self,
416 parent_id: PduId,
417) -> impl Stream<Item = OwnedEventId> + Send + 'a {
418 let prefix = typed_relation_prefix(parent_id.shortroomid, parent_id.count, RelTag::Reference);
419 let seek = prefix.clone();
420
421 self.db
422 .relatesto_typed
423 .raw_stream_from(seek.as_slice())
424 .ignore_err()
425 .ready_take_while(move |(key, _)| key.starts_with(&prefix))
426 .map(|(_, val)| u64_from_u8(val))
427 .wide_filter_map(async |short| {
428 self.services
429 .short
430 .get_eventid_from_short(short)
431 .await
432 .ok()
433 })
434}
435
436#[implement(Service)]
437#[tracing::instrument(skip_all, level = "debug")]
438pub fn mark_as_referenced<'a, I>(&self, room_id: &RoomId, event_ids: I)
439where
440 I: Iterator<Item = &'a EventId>,
441{
442 for prev in event_ids {
443 let key = (room_id, prev);
444 self.db.referencedevents.put_raw(key, []);
445 }
446}
447
448#[implement(Service)]
449#[tracing::instrument(skip(self), level = "debug", ret)]
450pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
451 let key = (room_id, event_id);
452 self.db.referencedevents.qry(&key).await.is_ok()
453}
454
455#[implement(Service)]
456#[tracing::instrument(skip(self), level = "debug")]
457pub fn mark_event_soft_failed(&self, event_id: &EventId) {
458 self.db.softfailedeventids.insert(event_id, []);
459}
460
461#[implement(Service)]
462#[tracing::instrument(skip(self), level = "debug", ret)]
463pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
464 self.db
465 .softfailedeventids
466 .get(event_id)
467 .await
468 .is_ok()
469}
470
471#[implement(Service)]
472#[tracing::instrument(skip(self), level = "debug")]
473pub async fn delete_all_referenced_for_room(&self, room_id: &RoomId) -> Result {
474 let prefix = (room_id, Interfix);
475
476 self.db
477 .referencedevents
478 .keys_prefix_raw(&prefix)
479 .ignore_err()
480 .ready_for_each(|key| {
481 trace!(?key, "Removing key");
482 self.db.referencedevents.remove(key);
483 })
484 .await;
485
486 Ok(())
487}
488
489#[implement(Service)]
494#[tracing::instrument(skip_all, level = "debug")]
495pub async fn delete_typed_relation(&self, child_id: &RawPduId, child: &CanonicalJsonObject) {
496 let Some(relates_to) = child
497 .get("content")
498 .and_then(CanonicalJsonValue::as_object)
499 .and_then(|content| content.get("m.relates_to"))
500 .and_then(CanonicalJsonValue::as_object)
501 else {
502 return;
503 };
504
505 let tag = match relates_to
506 .get("rel_type")
507 .and_then(CanonicalJsonValue::as_str)
508 {
509 | Some("m.replace") => RelTag::Replace,
510 | Some("m.reference") => RelTag::Reference,
511 | _ => return,
512 };
513
514 let Some(parent) = relates_to
515 .get("event_id")
516 .and_then(CanonicalJsonValue::as_str)
517 .and_then(|parent| EventId::parse(parent).ok())
518 else {
519 return;
520 };
521
522 let Some(child_ts) = child
523 .get("origin_server_ts")
524 .and_then(CanonicalJsonValue::as_integer)
525 .and_then(|ts| u64::try_from(i64::from(ts)).ok())
526 else {
527 return;
528 };
529
530 let child_count = child_id.pdu_count();
531 let shortroomid = u64_from_u8(&child_id.shortroomid());
532
533 let Ok(parent_count) = self
534 .services
535 .timeline
536 .get_pdu_count(&parent)
537 .await
538 else {
539 return;
540 };
541
542 let (PduCount::Normal(_), PduCount::Normal(_)) = (parent_count, child_count) else {
543 return;
544 };
545
546 let key = typed_relation_key(shortroomid, parent_count, tag, child_ts, child_count);
547
548 self.db.relatesto_typed.remove(key.as_slice());
549}
550
551#[implement(Service)]
552#[tracing::instrument(skip(self), level = "debug")]
553pub async fn delete_all_relatesto_typed_for_room(&self, room_id: &RoomId) -> Result {
554 let Ok(shortroomid) = self.services.short.get_shortroomid(room_id).await else {
555 return Ok(());
556 };
557
558 self.db
559 .relatesto_typed
560 .keys_prefix_raw(&shortroomid)
561 .ignore_err()
562 .ready_for_each(|key| {
563 self.db.relatesto_typed.remove(key);
564 })
565 .await;
566
567 Ok(())
568}
569
570#[implement(Service)]
574pub async fn rebuild_typed_relations(&self) -> Result {
575 self.db.relatesto_typed.clear().await;
576
577 let pduid_pdu = self.services.db["pduid_pdu"].clone();
578
579 pduid_pdu
580 .raw_stream()
581 .ignore_err()
582 .ready_filter_map(|(key, value)| {
583 let pdu_id = RawPduId::from(key);
584 let pdu = serde_json::from_slice::<Pdu>(value).ok()?;
585
586 Some((pdu_id, pdu))
587 })
588 .for_each_concurrent(automatic_width(), async |(pdu_id, pdu)| {
589 self.index_pdu_relations(pdu_id, &pdu).await;
590 })
591 .await;
592
593 Ok(())
594}
595
596#[implement(Service)]
597async fn index_pdu_relations(&self, pdu_id: RawPduId, pdu: &Pdu) {
598 let Ok(content) = pdu.get_content::<ExtractRelatesTo>() else {
599 return;
600 };
601
602 let (rel_type, parent) = match content.relates_to {
603 | Relation::Replacement(replacement) => (RelationType::Replacement, replacement.event_id),
604 | Relation::Reference(reference) => (RelationType::Reference, reference.event_id),
605 | _ => return,
606 };
607
608 let shortroomid = u64_from_u8(&pdu_id.shortroomid());
609
610 self.add_typed_relation(shortroomid, pdu_id.pdu_count(), &parent, pdu, rel_type)
611 .await;
612}
613
614fn rel_type_tag(rel_type: &RelationType) -> Option<RelTag> {
615 match rel_type {
616 | RelationType::Replacement => Some(RelTag::Replace),
617 | RelationType::Reference => Some(RelTag::Reference),
618 | _ => None,
619 }
620}
621
622fn typed_relation_prefix(
623 shortroomid: ShortRoomId,
624 parent: PduCount,
625 tag: RelTag,
626) -> ArrayVec<u8, TYPED_PREFIX_LEN> {
627 let mut buf = ArrayVec::new();
628 buf.extend(shortroomid.to_be_bytes());
629 buf.extend(parent.to_be_bytes());
630 buf.push(u8::from(tag));
631 buf
632}
633
634fn typed_relation_key(
635 shortroomid: ShortRoomId,
636 parent: PduCount,
637 tag: RelTag,
638 child_ts: u64,
639 child: PduCount,
640) -> ArrayVec<u8, TYPED_KEY_LEN> {
641 let mut buf = ArrayVec::new();
642 buf.extend(shortroomid.to_be_bytes());
643 buf.extend(parent.to_be_bytes());
644 buf.push(u8::from(tag));
645 buf.extend(child_ts.to_be_bytes());
646 buf.extend(child.to_be_bytes());
647 buf
648}