// tuwunel_service/rooms/read_receipt/mod.rs

1mod data;
2#[cfg(test)]
3mod tests;
4
5use std::{collections::BTreeMap, sync::Arc};
6
7use futures::{Stream, StreamExt};
8use ruma::{
9	OwnedEventId, OwnedUserId, RoomId, UserId,
10	api::appservice::event::push_events::v1::EphemeralData,
11	events::{
12		AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent,
13		receipt::{
14			Receipt, ReceiptEvent, ReceiptEventContent, ReceiptThread, ReceiptType, Receipts,
15		},
16	},
17	serde::Raw,
18};
19use tuwunel_core::{
20	Result, debug, err,
21	matrix::{
22		Event,
23		pdu::{PduCount, PduId, RawPduId},
24	},
25	smallstr::SmallString,
26	smallvec::SmallVec,
27	trace,
28	utils::IterStream,
29	warn,
30};
31
32use self::data::{Data, ReceiptItem};
33
34/// Private read receipts surfaced by `private_read_get`. One legacy
35/// unthreaded row plus zero or more per-thread rows; inline-1 catches the
36/// dominant case (a single unthreaded marker) without a heap alloc.
37pub type PrivateReadEvents = SmallVec<[Raw<AnySyncEphemeralRoomEvent>; 1]>;
38
39/// Stored thread-kind tag: `""` for `Unthreaded`, `"main"` for `Main`, or
40/// the event-id string for `Thread(...)`. v3+ event ids are 44 bytes
41/// including the leading `$`; 48 bytes inline matches the project's
42/// `StateKey` budget and stays inline for every realistic thread root.
43type ThreadKind = SmallString<[u8; 48]>;
44
45pub struct Service {
46	services: Arc<crate::services::OnceServices>,
47	db: Data,
48}
49
50impl crate::Service for Service {
51	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
52		Ok(Arc::new(Self {
53			services: args.services.clone(),
54			db: Data::new(args),
55		}))
56	}
57
58	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
59}
60
61impl Service {
62	/// Replaces the previous read receipt.
63	#[tracing::instrument(skip(self), level = "debug", name = "set_receipt")]
64	pub async fn readreceipt_update(
65		&self,
66		user_id: &UserId,
67		room_id: &RoomId,
68		event: &ReceiptEvent,
69	) {
70		// update local
71		self.db
72			.readreceipt_update(user_id, room_id, event)
73			.await;
74
75		// update appservices
76		self.services
77			.sending
78			.send_edu_room_appservices(room_id, |buf| {
79				let edu = EphemeralData::Receipt(ReceiptEvent {
80					content: event.content.clone(),
81					room_id: room_id.to_owned(),
82				});
83
84				Ok(serde_json::to_writer(buf, &edu)?)
85			})
86			.await
87			.expect("edu serialization or flush failed");
88
89		// update federation
90		if self.services.globals.user_is_local(user_id) {
91			self.services
92				.sending
93				.flush_room(room_id)
94				.await
95				.expect("room flush failed");
96		}
97	}
98
99	/// Gets every stored private read receipt for `(room, user)`. Returns
100	/// one ephemeral event per stored row (legacy unthreaded plus per-thread
101	/// rows). An empty result means no marker is set.
102	#[tracing::instrument(skip(self), level = "debug", name = "get_private")]
103	pub async fn private_read_get(
104		&self,
105		room_id: &RoomId,
106		user_id: &UserId,
107	) -> Result<PrivateReadEvents> {
108		let shortroomid = self
109			.services
110			.short
111			.get_shortroomid(room_id)
112			.await
113			.map_err(|e| {
114				err!(Database(warn!(
115					"Short room ID does not exist in database for {room_id}: {e}"
116				)))
117			})?;
118
119		let legacy = self
120			.private_read_get_count(room_id, user_id)
121			.await
122			.ok()
123			.map(|count| (ThreadKind::new(), count));
124
125		let threaded: SmallVec<[(ThreadKind, u64); 1]> = self
126			.db
127			.private_read_threaded_stream(room_id, user_id)
128			.collect()
129			.await;
130
131		let events = legacy
132			.into_iter()
133			.chain(threaded)
134			.stream()
135			.filter_map(|(kind, count)| async move {
136				self.build_private_read_event(shortroomid, count, user_id, &kind)
137					.await
138			})
139			.collect()
140			.await;
141
142		Ok(events)
143	}
144
145	async fn build_private_read_event(
146		&self,
147		shortroomid: u64,
148		count: u64,
149		user_id: &UserId,
150		thread_kind: &str,
151	) -> Option<Raw<AnySyncEphemeralRoomEvent>> {
152		let pdu_id: RawPduId = PduId {
153			shortroomid,
154			count: PduCount::Normal(count),
155		}
156		.into();
157		let pdu = self
158			.services
159			.timeline
160			.get_pdu_from_id(&pdu_id)
161			.await
162			.ok()?;
163
164		let thread = thread_kind_to_receipt(thread_kind);
165		let event_id: OwnedEventId = pdu.event_id().to_owned();
166		let user_id: OwnedUserId = user_id.to_owned();
167		let content: BTreeMap<OwnedEventId, Receipts> = BTreeMap::from_iter([(
168			event_id,
169			BTreeMap::from_iter([(
170				ReceiptType::ReadPrivate,
171				BTreeMap::from_iter([(user_id, Receipt { ts: None, thread })]),
172			)]),
173		)]);
174
175		let receipt_event_content = ReceiptEventContent(content);
176		let receipt_sync_event = SyncEphemeralRoomEvent { content: receipt_event_content };
177		let event = serde_json::value::to_raw_value(&receipt_sync_event)
178			.expect("receipt created manually");
179
180		Some(Raw::from_json(event))
181	}
182
183	/// Returns an iterator over the most recent read_receipts in a room that
184	/// happened after the event with id `since`.
185	#[tracing::instrument(skip(self), level = "debug")]
186	pub fn readreceipts_since<'a>(
187		&'a self,
188		room_id: &'a RoomId,
189		since: u64,
190		to: Option<u64>,
191	) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
192		self.db.readreceipts_since(room_id, since, to)
193	}
194
195	/// Sets a private read marker at PDU `count` for the given thread.
196	/// Unthreaded writes supersede prior per-thread rows so the room-wide
197	/// receipt subsumes thread state.
198	#[tracing::instrument(skip(self), level = "debug", name = "set_private")]
199	pub async fn private_read_set(
200		&self,
201		room_id: &RoomId,
202		user_id: &UserId,
203		count: u64,
204		thread: &ReceiptThread,
205	) {
206		self.db
207			.private_read_set(room_id, user_id, count, thread)
208			.await;
209	}
210
211	/// Returns the private read marker PDU count.
212	#[tracing::instrument(
213		name = "get_private_count",
214		level = "debug",
215		skip(self),
216		ret(level = "trace")
217	)]
218	pub async fn private_read_get_count(
219		&self,
220		room_id: &RoomId,
221		user_id: &UserId,
222	) -> Result<u64> {
223		self.db
224			.private_read_get_count(room_id, user_id)
225			.await
226	}
227
228	/// Returns the PDU count of the last typing update in this room.
229	#[tracing::instrument(
230		name = "get_private_last",
231		level = "debug",
232		skip(self),
233		ret(level = "trace")
234	)]
235	pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
236		self.db
237			.last_privateread_update(user_id, room_id)
238			.await
239	}
240
241	#[tracing::instrument(
242		name = "get_receipt_last",
243		level = "debug",
244		skip(self),
245		ret(level = "trace")
246	)]
247	pub async fn last_receipt_count(
248		&self,
249		room_id: &RoomId,
250		user_id: Option<&UserId>,
251		since: Option<u64>,
252	) -> Result<u64> {
253		self.db
254			.last_receipt_count(room_id, since, user_id)
255			.await
256	}
257
258	pub async fn delete_all_read_receipts(&self, room_id: &RoomId) -> Result {
259		self.db.delete_all_read_receipts(room_id).await
260	}
261}
262
263/// Reverse of `ReceiptThread::as_str`: parse a stored thread tag into the
264/// enum. Empty string maps to `Unthreaded`; `"main"` to `Main`; anything
265/// starting with `$` to `Thread(event_id)` if parseable.
266fn thread_kind_to_receipt(thread_kind: &str) -> ReceiptThread {
267	match thread_kind {
268		| "" => ReceiptThread::Unthreaded,
269		| "main" => ReceiptThread::Main,
270		| _ => OwnedEventId::try_from(thread_kind)
271			.map(ReceiptThread::Thread)
272			.unwrap_or(ReceiptThread::Unthreaded),
273	}
274}
275
276#[must_use]
277pub fn pack_receipts<I>(receipts: I) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>>
278where
279	I: Iterator<Item = Raw<AnySyncEphemeralRoomEvent>>,
280{
281	let mut json = BTreeMap::new();
282	for value in receipts {
283		let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(
284			value.json().get(),
285		);
286		match receipt {
287			| Ok(value) =>
288				for (event, receipt) in value.content {
289					json.insert(event, receipt);
290				},
291			| _ => {
292				debug!("failed to parse receipt: {:?}", receipt);
293			},
294		}
295	}
296
297	let content = ReceiptEventContent::from_iter(json);
298
299	trace!(?content);
300	Raw::from_json(
301		serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
302			.expect("received valid json"),
303	)
304}