// tuwunel_service/rooms/threads/mod.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use futures::{Stream, StreamExt, TryFutureExt};
4use ruma::{
5	CanonicalJsonValue, EventId, OwnedEventId, OwnedUserId, RoomId, UserId,
6	api::client::threads::get_threads::v1::IncludeThreads,
7	events::{
8		TimelineEventType,
9		relation::{BundledThread, RelationType},
10	},
11	uint,
12};
13use serde::Deserialize;
14use serde_json::json;
15use tuwunel_core::{
16	Event, Result, err,
17	matrix::pdu::{PduCount, PduEvent, PduId, RawPduId},
18	trace,
19	utils::{
20		ReadyExt,
21		stream::{TryIgnore, WidebandExt},
22	},
23};
24use tuwunel_database::{Deserialized, Interfix, Map};
25
26/// Maximum relation hops walked when resolving thread membership, per
27/// the Matrix v1.4 spec recommendation (also MSC3771/MSC3773).
28const MAX_THREAD_HOPS: usize = 3;
29
30#[derive(Deserialize)]
31struct ExtractThreadRelation {
32	#[serde(rename = "m.relates_to")]
33	relates_to: ThreadRelation,
34}
35
36#[derive(Deserialize)]
37struct ThreadRelation {
38	rel_type: RelationType,
39	event_id: OwnedEventId,
40}
41
42pub struct Service {
43	db: Data,
44	services: Arc<crate::services::OnceServices>,
45}
46
47pub(super) struct Data {
48	threadid_userids: Arc<Map>,
49}
50
51impl crate::Service for Service {
52	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
53		Ok(Arc::new(Self {
54			db: Data {
55				threadid_userids: args.db["threadid_userids"].clone(),
56			},
57			services: args.services.clone(),
58		}))
59	}
60
61	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
62}
63
64impl Service {
65	/// Resolves the thread root for `event` by walking up `m.relates_to`
66	/// links, bounded at `MAX_THREAD_HOPS`. Returns `None` for events
67	/// that belong to the main timeline. Redaction events carry no
68	/// `m.relates_to` of their own; their thread is resolved from the
69	/// redacted target event per MSC3771/MSC3773.
70	pub async fn get_thread_id<E>(&self, event: &E) -> Option<OwnedEventId>
71	where
72		E: Event,
73	{
74		let initial = match event.get_content::<ExtractThreadRelation>() {
75			| Ok(t) => Some(t.relates_to),
76			| Err(_) => self.relates_to_via_redaction_target(event).await,
77		};
78
79		let mut relates_to = initial?;
80
81		for _ in 0..MAX_THREAD_HOPS {
82			if relates_to.rel_type == RelationType::Thread {
83				return Some(relates_to.event_id);
84			}
85
86			relates_to = self
87				.services
88				.timeline
89				.get_pdu(&relates_to.event_id)
90				.await
91				.ok()?
92				.get_content::<ExtractThreadRelation>()
93				.ok()?
94				.relates_to;
95		}
96
97		None
98	}
99
100	/// Resolve a redaction event's thread by looking through to the
101	/// redacted target. Returns `None` for non-redaction events and for
102	/// redactions whose target is unknown or carries no thread relation.
103	async fn relates_to_via_redaction_target<E>(&self, event: &E) -> Option<ThreadRelation>
104	where
105		E: Event,
106	{
107		if *event.kind() != TimelineEventType::RoomRedaction {
108			return None;
109		}
110
111		let room_rules = self
112			.services
113			.state
114			.get_room_version_rules(event.room_id())
115			.await
116			.ok()?;
117
118		let target_id = event.redacts_id(&room_rules)?;
119
120		self.services
121			.timeline
122			.get_pdu(&target_id)
123			.await
124			.ok()?
125			.get_content::<ExtractThreadRelation>()
126			.ok()
127			.map(|t| t.relates_to)
128	}
129
130	/// `get_thread_id` for an event referenced by id; events missing
131	/// locally resolve to `None` (the main timeline).
132	pub async fn get_thread_id_for_event(&self, event_id: &EventId) -> Option<OwnedEventId> {
133		let pdu = self
134			.services
135			.timeline
136			.get_pdu(event_id)
137			.await
138			.ok()?;
139
140		self.get_thread_id(&pdu).await
141	}
142
143	pub async fn add_to_thread<E>(&self, root_event_id: &EventId, event: &E) -> Result
144	where
145		E: Event,
146	{
147		let root_id = self
148			.services
149			.timeline
150			.get_pdu_id(root_event_id)
151			.await
152			.map_err(|e| {
153				err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
154			})?;
155
156		let root_pdu = self
157			.services
158			.timeline
159			.get_pdu_from_id(&root_id)
160			.await
161			.map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?;
162
163		let mut root_pdu_json = self
164			.services
165			.timeline
166			.get_pdu_json_from_id(&root_id)
167			.await
168			.map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?;
169
170		let mut users = self
171			.get_participants(&root_id)
172			.await
173			.unwrap_or_else(|_| vec![root_pdu.sender().to_owned()]);
174
175		users.push(event.sender().to_owned());
176
177		// Record participants before the bundle so a concurrent read never sees the
178		// bundle with a stale participant set (MSC3816 current_user_participated).
179		self.update_participants(&root_id, &users)?;
180
181		if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
182			.entry("unsigned".into())
183			.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
184		{
185			if let Some(mut relations) = unsigned
186				.get("m.relations")
187				.and_then(|r| r.as_object())
188				.and_then(|r| r.get("m.thread"))
189				.and_then(|relations| {
190					serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
191				}) {
192				// Thread already existed
193				relations.count = relations.count.saturating_add(uint!(1));
194				relations.latest_event = event.to_format();
195
196				let content = serde_json::to_value(relations).expect("to_value always works");
197
198				unsigned.insert(
199					"m.relations".into(),
200					json!({ "m.thread": content })
201						.try_into()
202						.expect("thread is valid json"),
203				);
204			} else {
205				// New thread
206				let relations = BundledThread {
207					latest_event: event.to_format(),
208					count: uint!(1),
209					current_user_participated: true,
210				};
211
212				let content = serde_json::to_value(relations).expect("to_value always works");
213
214				unsigned.insert(
215					"m.relations".into(),
216					json!({ "m.thread": content })
217						.try_into()
218						.expect("thread is valid json"),
219				);
220			}
221
222			self.services
223				.timeline
224				.replace_pdu(&root_id, &root_pdu_json)
225				.await?;
226		}
227
228		Ok(())
229	}
230
231	pub fn threads_until<'a>(
232		&'a self,
233		user_id: &'a UserId,
234		room_id: &'a RoomId,
235		count: PduCount,
236		_inc: &'a IncludeThreads,
237	) -> impl Stream<Item = Result<(PduCount, PduEvent)>> + Send {
238		self.services
239			.short
240			.get_shortroomid(room_id)
241			.map_ok(move |shortroomid| PduId {
242				shortroomid,
243				count: count.saturating_sub(1),
244			})
245			.map_ok(Into::into)
246			.map_ok(move |current: RawPduId| {
247				self.db
248					.threadid_userids
249					.rev_raw_keys_from(&current)
250					.ignore_err()
251					.map(RawPduId::from)
252					.map(move |pdu_id| (pdu_id, user_id))
253					.ready_take_while(move |(pdu_id, _)| {
254						pdu_id.shortroomid() == current.shortroomid()
255					})
256					.wide_filter_map(async |(raw_pdu_id, user_id)| {
257						let pdu_id: PduId = raw_pdu_id.into();
258						let mut pdu = self
259							.services
260							.timeline
261							.get_pdu_from_id(&raw_pdu_id)
262							.await
263							.ok()?;
264
265						if pdu.sender() != user_id {
266							pdu.as_mut_pdu().remove_transaction_id().ok();
267						}
268
269						Some((pdu_id.count, pdu))
270					})
271					.map(Ok)
272			})
273			.try_flatten_stream()
274	}
275
276	pub(super) fn update_participants(
277		&self,
278		root_id: &RawPduId,
279		participants: &[OwnedUserId],
280	) -> Result {
281		let users = participants
282			.iter()
283			.map(|user| user.as_bytes())
284			.collect::<Vec<_>>()
285			.join(&[0xFF][..]);
286
287		self.db.threadid_userids.insert(root_id, &users);
288
289		Ok(())
290	}
291
292	pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
293		self.db
294			.threadid_userids
295			.get(root_id)
296			.await
297			.deserialized()
298	}
299
300	/// MSC3816: whether `user_id` has participated in the thread rooted at
301	/// `root_event_id`, having sent the root event or a threaded reply to it.
302	pub async fn user_participated(&self, root_event_id: &EventId, user_id: &UserId) -> bool {
303		let Ok(root_id) = self
304			.services
305			.timeline
306			.get_pdu_id(root_event_id)
307			.await
308		else {
309			return false;
310		};
311
312		self.get_participants(&root_id)
313			.await
314			.is_ok_and(|users| users.iter().any(|user| user == user_id))
315	}
316
317	pub(super) async fn delete_all_rooms_threads(&self, room_id: &RoomId) -> Result {
318		let prefix = (room_id, Interfix);
319
320		self.db
321			.threadid_userids
322			.keys_prefix_raw(&prefix)
323			.ignore_err()
324			.ready_for_each(|key| {
325				trace!("Removing key: {key:?}");
326				self.db.threadid_userids.remove(key);
327			})
328			.await;
329
330		Ok(())
331	}
332}