Skip to main content

tuwunel_service/rooms/threads/
mod.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use futures::{Stream, StreamExt, TryFutureExt};
4use ruma::{
5	CanonicalJsonValue, EventId, OwnedEventId, OwnedUserId, RoomId, UserId,
6	api::client::threads::get_threads::v1::IncludeThreads,
7	events::relation::{BundledThread, RelationType},
8	uint,
9};
10use serde::Deserialize;
11use serde_json::json;
12use tuwunel_core::{
13	Event, Result, err,
14	matrix::pdu::{PduCount, PduEvent, PduId, RawPduId},
15	trace,
16	utils::{
17		ReadyExt,
18		stream::{TryIgnore, WidebandExt},
19	},
20};
21use tuwunel_database::{Deserialized, Interfix, Map};
22
/// Upper bound used when following `m.relates_to` links to decide which
/// thread (if any) an event belongs to; chosen per the Matrix v1.4 spec
/// recommendation (see also MSC3771/MSC3773).
const MAX_THREAD_HOPS: usize = 3;
26
27#[derive(Deserialize)]
28struct ExtractThreadRelation {
29	#[serde(rename = "m.relates_to")]
30	relates_to: ThreadRelation,
31}
32
33#[derive(Deserialize)]
34struct ThreadRelation {
35	rel_type: RelationType,
36	event_id: OwnedEventId,
37}
38
39pub struct Service {
40	db: Data,
41	services: Arc<crate::services::OnceServices>,
42}
43
44pub(super) struct Data {
45	threadid_userids: Arc<Map>,
46}
47
48impl crate::Service for Service {
49	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
50		Ok(Arc::new(Self {
51			db: Data {
52				threadid_userids: args.db["threadid_userids"].clone(),
53			},
54			services: args.services.clone(),
55		}))
56	}
57
58	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
59}
60
61impl Service {
62	/// Resolves the thread root for `event` by walking up `m.relates_to`
63	/// links, bounded at `MAX_THREAD_HOPS`. Returns `None` for events
64	/// that belong to the main timeline.
65	pub async fn get_thread_id<E>(&self, event: &E) -> Option<OwnedEventId>
66	where
67		E: Event,
68	{
69		let mut relates_to = event
70			.get_content::<ExtractThreadRelation>()
71			.ok()?
72			.relates_to;
73
74		for _ in 0..MAX_THREAD_HOPS {
75			if relates_to.rel_type == RelationType::Thread {
76				return Some(relates_to.event_id);
77			}
78
79			relates_to = self
80				.services
81				.timeline
82				.get_pdu(&relates_to.event_id)
83				.await
84				.ok()?
85				.get_content::<ExtractThreadRelation>()
86				.ok()?
87				.relates_to;
88		}
89
90		None
91	}
92
93	/// `get_thread_id` for an event referenced by id; events missing
94	/// locally resolve to `None` (the main timeline).
95	pub async fn get_thread_id_for_event(&self, event_id: &EventId) -> Option<OwnedEventId> {
96		let pdu = self
97			.services
98			.timeline
99			.get_pdu(event_id)
100			.await
101			.ok()?;
102
103		self.get_thread_id(&pdu).await
104	}
105
106	pub async fn add_to_thread<E>(&self, root_event_id: &EventId, event: &E) -> Result
107	where
108		E: Event,
109	{
110		let root_id = self
111			.services
112			.timeline
113			.get_pdu_id(root_event_id)
114			.await
115			.map_err(|e| {
116				err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
117			})?;
118
119		let root_pdu = self
120			.services
121			.timeline
122			.get_pdu_from_id(&root_id)
123			.await
124			.map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?;
125
126		let mut root_pdu_json = self
127			.services
128			.timeline
129			.get_pdu_json_from_id(&root_id)
130			.await
131			.map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?;
132
133		if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
134			.entry("unsigned".into())
135			.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
136		{
137			if let Some(mut relations) = unsigned
138				.get("m.relations")
139				.and_then(|r| r.as_object())
140				.and_then(|r| r.get("m.thread"))
141				.and_then(|relations| {
142					serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
143				}) {
144				// Thread already existed
145				relations.count = relations.count.saturating_add(uint!(1));
146				relations.latest_event = event.to_format();
147
148				let content = serde_json::to_value(relations).expect("to_value always works");
149
150				unsigned.insert(
151					"m.relations".into(),
152					json!({ "m.thread": content })
153						.try_into()
154						.expect("thread is valid json"),
155				);
156			} else {
157				// New thread
158				let relations = BundledThread {
159					latest_event: event.to_format(),
160					count: uint!(1),
161					current_user_participated: true,
162				};
163
164				let content = serde_json::to_value(relations).expect("to_value always works");
165
166				unsigned.insert(
167					"m.relations".into(),
168					json!({ "m.thread": content })
169						.try_into()
170						.expect("thread is valid json"),
171				);
172			}
173
174			self.services
175				.timeline
176				.replace_pdu(&root_id, &root_pdu_json)
177				.await?;
178		}
179
180		let mut users = Vec::new();
181		match self.get_participants(&root_id).await {
182			| Ok(userids) => users.extend_from_slice(&userids),
183			| _ => users.push(root_pdu.sender().to_owned()),
184		}
185
186		users.push(event.sender().to_owned());
187		self.update_participants(&root_id, &users)
188	}
189
190	pub fn threads_until<'a>(
191		&'a self,
192		user_id: &'a UserId,
193		room_id: &'a RoomId,
194		count: PduCount,
195		_inc: &'a IncludeThreads,
196	) -> impl Stream<Item = Result<(PduCount, PduEvent)>> + Send {
197		self.services
198			.short
199			.get_shortroomid(room_id)
200			.map_ok(move |shortroomid| PduId {
201				shortroomid,
202				count: count.saturating_sub(1),
203			})
204			.map_ok(Into::into)
205			.map_ok(move |current: RawPduId| {
206				self.db
207					.threadid_userids
208					.rev_raw_keys_from(&current)
209					.ignore_err()
210					.map(RawPduId::from)
211					.map(move |pdu_id| (pdu_id, user_id))
212					.ready_take_while(move |(pdu_id, _)| {
213						pdu_id.shortroomid() == current.shortroomid()
214					})
215					.wide_filter_map(async |(raw_pdu_id, user_id)| {
216						let pdu_id: PduId = raw_pdu_id.into();
217						let mut pdu = self
218							.services
219							.timeline
220							.get_pdu_from_id(&raw_pdu_id)
221							.await
222							.ok()?;
223
224						if pdu.sender() != user_id {
225							pdu.as_mut_pdu().remove_transaction_id().ok();
226						}
227
228						Some((pdu_id.count, pdu))
229					})
230					.map(Ok)
231			})
232			.try_flatten_stream()
233	}
234
235	pub(super) fn update_participants(
236		&self,
237		root_id: &RawPduId,
238		participants: &[OwnedUserId],
239	) -> Result {
240		let users = participants
241			.iter()
242			.map(|user| user.as_bytes())
243			.collect::<Vec<_>>()
244			.join(&[0xFF][..]);
245
246		self.db.threadid_userids.insert(root_id, &users);
247
248		Ok(())
249	}
250
251	pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
252		self.db
253			.threadid_userids
254			.get(root_id)
255			.await
256			.deserialized()
257	}
258
259	pub(super) async fn delete_all_rooms_threads(&self, room_id: &RoomId) -> Result {
260		let prefix = (room_id, Interfix);
261
262		self.db
263			.threadid_userids
264			.keys_prefix_raw(&prefix)
265			.ignore_err()
266			.ready_for_each(|key| {
267				trace!("Removing key: {key:?}");
268				self.db.threadid_userids.remove(key);
269			})
270			.await;
271
272		Ok(())
273	}
274}