Skip to main content

tuwunel_service/rooms/pdu_metadata/
mod.rs

1use std::sync::Arc;
2
3use futures::{Stream, StreamExt, TryFutureExt, future::Either};
4use ruma::{
5	EventId, RoomId, UserId,
6	api::Direction,
7	events::{reaction::ReactionEventContent, relation::RelationType},
8};
9use tuwunel_core::{
10	PduId, Result,
11	arrayvec::ArrayVec,
12	implement, is_equal_to,
13	matrix::{Event, Pdu, PduCount, RawPduId, event::RelationTypeEqual},
14	result::LogErr,
15	trace,
16	utils::{
17		stream::{ReadyExt, TryIgnore, WidebandExt},
18		u64_from_u8,
19	},
20};
21use tuwunel_database::{Interfix, Map};
22
23use crate::rooms::short::ShortRoomId;
24
/// PDU metadata service.
///
/// Records relations between PDUs (`add_relation` / `get_relations`), events
/// referenced by other events within a room (`mark_as_referenced`), and event
/// ids flagged as soft-failed (`mark_event_soft_failed`).
pub struct Service {
	/// Handle to the other services; used here to reach the timeline service.
	services: Arc<crate::services::OnceServices>,
	/// Database maps owned by this service.
	db: Data,
}
29
/// Database maps backing the PDU metadata service.
struct Data {
	/// Keys are `[to, from]` pairs of `u64` PDU counts with empty values, as
	/// written by `add_relation`. Relations targeting one event share the
	/// `to` prefix, which is what `get_relations` scans on.
	tofrom_relation: Arc<Map>,
	/// Keys are `(room_id, event_id)` with empty values (`mark_as_referenced`).
	referencedevents: Arc<Map>,
	/// Keys are event ids with empty values (`mark_event_soft_failed`).
	softfailedeventids: Arc<Map>,
}
35
36impl crate::Service for Service {
37	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
38		Ok(Arc::new(Self {
39			services: args.services.clone(),
40			db: Data {
41				tofrom_relation: args.db["tofrom_relation"].clone(),
42				referencedevents: args.db["referencedevents"].clone(),
43				softfailedeventids: args.db["softfailedeventids"].clone(),
44			},
45		}))
46	}
47
48	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
49}
50
51#[implement(Service)]
52#[tracing::instrument(skip(self, from, to), level = "debug")]
53pub fn add_relation(&self, from: PduCount, to: PduCount) {
54	const BUFSIZE: usize = size_of::<u64>() * 2;
55
56	match (from, to) {
57		| (PduCount::Normal(from), PduCount::Normal(to)) => {
58			let key: &[u64] = &[to, from];
59			self.db
60				.tofrom_relation
61				.aput_raw::<BUFSIZE, _, _>(key, []);
62		},
63		| _ => {}, // TODO: Relations with backfilled pdus
64	}
65}
66
67/// Query relations of an event to determine if matching any of the trailing
68/// arguments. When all criteria are None the mere presence of a relation causes
69/// this function to return true.
70#[implement(Service)]
71pub async fn event_has_relation(
72	&self,
73	event_id: &EventId,
74	user_id: Option<&UserId>,
75	rel_type: Option<&RelationType>,
76	key: Option<&str>,
77) -> bool {
78	let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await else {
79		return false;
80	};
81
82	self.has_relation(pdu_id.into(), user_id, rel_type, key)
83		.await
84}
85
86/// Query relations of an event by PduId to determine if matching any of the
87/// trailing arguments. When all criteria are None the mere presence of a
88/// relation causes this function to return true.
89#[implement(Service)]
90pub async fn has_relation(
91	&self,
92	target: PduId,
93	user_id: Option<&UserId>,
94	rel_type: Option<&RelationType>,
95	key: Option<&str>,
96) -> bool {
97	self.get_relations(target.shortroomid, target.count, None, Direction::Forward, None)
98		.ready_filter(|(_, pdu)| user_id.is_none_or(is_equal_to!(pdu.sender())))
99		.ready_filter(|(_, pdu)| {
100			debug_assert!(
101				key.is_none() || rel_type.is_none_or(is_equal_to!(&RelationType::Annotation)),
102				"key argument only applies to Annotation type relations."
103			);
104
105			// When key is supplied we don't need to double-parse the content here and below.
106			key.is_some() || rel_type
107				.is_none_or(|rel_type| rel_type.relation_type_equal(&pdu))
108		})
109		.ready_filter(|(_, pdu)| {
110			key.is_none_or(|key| {
111				pdu.get_content::<ReactionEventContent>()
112					.map(|content| content.relates_to.key == key)
113					.unwrap_or(false)
114			})
115		})
116		.ready_any(|_| true) // first match or false
117		.await
118}
119
120#[implement(Service)]
121pub fn get_relations<'a>(
122	&'a self,
123	shortroomid: ShortRoomId,
124	target: PduCount,
125	from: Option<PduCount>,
126	dir: Direction,
127	user_id: Option<&'a UserId>,
128) -> impl Stream<Item = (PduCount, Pdu)> + Send + '_ {
129	let target = target.to_be_bytes();
130	let from = from
131		.map(|from| from.saturating_inc(dir))
132		.unwrap_or_else(|| match dir {
133			| Direction::Backward => PduCount::max(),
134			| Direction::Forward => PduCount::default(),
135		})
136		.to_be_bytes();
137
138	let mut buf = ArrayVec::<u8, 16>::new();
139	let start = {
140		buf.extend(target);
141		buf.extend(from);
142		buf.as_slice()
143	};
144
145	match dir {
146		| Direction::Backward => Either::Left(self.db.tofrom_relation.rev_raw_keys_from(start)),
147		| Direction::Forward => Either::Right(self.db.tofrom_relation.raw_keys_from(start)),
148	}
149	.ignore_err()
150	.ready_take_while(move |key| key.starts_with(&target))
151	.map(|to_from| u64_from_u8(&to_from[8..16]))
152	.map(PduCount::from_unsigned)
153	.map(move |count| (user_id, shortroomid, count))
154	.wide_filter_map(async |(user_id, shortroomid, count)| {
155		let pdu_id: RawPduId = PduId { shortroomid, count }.into();
156		self.services
157			.timeline
158			.get_pdu_from_id(&pdu_id)
159			.map_ok(move |mut pdu| {
160				if user_id.is_none_or(|user_id| pdu.sender() != user_id) {
161					pdu.as_mut_pdu()
162						.remove_transaction_id()
163						.log_err()
164						.ok();
165				}
166
167				(count, pdu)
168			})
169			.await
170			.ok()
171	})
172}
173
174#[implement(Service)]
175#[tracing::instrument(skip_all, level = "debug")]
176pub fn mark_as_referenced<'a, I>(&self, room_id: &RoomId, event_ids: I)
177where
178	I: Iterator<Item = &'a EventId>,
179{
180	for prev in event_ids {
181		let key = (room_id, prev);
182		self.db.referencedevents.put_raw(key, []);
183	}
184}
185
186#[implement(Service)]
187#[tracing::instrument(skip(self), level = "debug", ret)]
188pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
189	let key = (room_id, event_id);
190	self.db.referencedevents.qry(&key).await.is_ok()
191}
192
193#[implement(Service)]
194#[tracing::instrument(skip(self), level = "debug")]
195pub fn mark_event_soft_failed(&self, event_id: &EventId) {
196	self.db.softfailedeventids.insert(event_id, []);
197}
198
199#[implement(Service)]
200#[tracing::instrument(skip(self), level = "debug", ret)]
201pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
202	self.db
203		.softfailedeventids
204		.get(event_id)
205		.await
206		.is_ok()
207}
208
209#[implement(Service)]
210#[tracing::instrument(skip(self), level = "debug")]
211pub async fn delete_all_referenced_for_room(&self, room_id: &RoomId) -> Result {
212	let prefix = (room_id, Interfix);
213
214	self.db
215		.referencedevents
216		.keys_prefix_raw(&prefix)
217		.ignore_err()
218		.ready_for_each(|key| {
219			trace!(?key, "Removing key");
220			self.db.referencedevents.remove(key);
221		})
222		.await;
223
224	Ok(())
225}