tuwunel_service/rooms/threads/
mod.rs1use std::{collections::BTreeMap, sync::Arc};
2
3use futures::{Stream, StreamExt, TryFutureExt};
4use ruma::{
5 CanonicalJsonValue, EventId, OwnedEventId, OwnedUserId, RoomId, UserId,
6 api::client::threads::get_threads::v1::IncludeThreads,
7 events::relation::{BundledThread, RelationType},
8 uint,
9};
10use serde::Deserialize;
11use serde_json::json;
12use tuwunel_core::{
13 Event, Result, err,
14 matrix::pdu::{PduCount, PduEvent, PduId, RawPduId},
15 trace,
16 utils::{
17 ReadyExt,
18 stream::{TryIgnore, WidebandExt},
19 },
20};
21use tuwunel_database::{Deserialized, Interfix, Map};
22
23const MAX_THREAD_HOPS: usize = 3;
26
27#[derive(Deserialize)]
28struct ExtractThreadRelation {
29 #[serde(rename = "m.relates_to")]
30 relates_to: ThreadRelation,
31}
32
33#[derive(Deserialize)]
34struct ThreadRelation {
35 rel_type: RelationType,
36 event_id: OwnedEventId,
37}
38
39pub struct Service {
40 db: Data,
41 services: Arc<crate::services::OnceServices>,
42}
43
44pub(super) struct Data {
45 threadid_userids: Arc<Map>,
46}
47
48impl crate::Service for Service {
49 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
50 Ok(Arc::new(Self {
51 db: Data {
52 threadid_userids: args.db["threadid_userids"].clone(),
53 },
54 services: args.services.clone(),
55 }))
56 }
57
58 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
59}
60
61impl Service {
62 pub async fn get_thread_id<E>(&self, event: &E) -> Option<OwnedEventId>
66 where
67 E: Event,
68 {
69 let mut relates_to = event
70 .get_content::<ExtractThreadRelation>()
71 .ok()?
72 .relates_to;
73
74 for _ in 0..MAX_THREAD_HOPS {
75 if relates_to.rel_type == RelationType::Thread {
76 return Some(relates_to.event_id);
77 }
78
79 relates_to = self
80 .services
81 .timeline
82 .get_pdu(&relates_to.event_id)
83 .await
84 .ok()?
85 .get_content::<ExtractThreadRelation>()
86 .ok()?
87 .relates_to;
88 }
89
90 None
91 }
92
93 pub async fn get_thread_id_for_event(&self, event_id: &EventId) -> Option<OwnedEventId> {
96 let pdu = self
97 .services
98 .timeline
99 .get_pdu(event_id)
100 .await
101 .ok()?;
102
103 self.get_thread_id(&pdu).await
104 }
105
106 pub async fn add_to_thread<E>(&self, root_event_id: &EventId, event: &E) -> Result
107 where
108 E: Event,
109 {
110 let root_id = self
111 .services
112 .timeline
113 .get_pdu_id(root_event_id)
114 .await
115 .map_err(|e| {
116 err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
117 })?;
118
119 let root_pdu = self
120 .services
121 .timeline
122 .get_pdu_from_id(&root_id)
123 .await
124 .map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?;
125
126 let mut root_pdu_json = self
127 .services
128 .timeline
129 .get_pdu_json_from_id(&root_id)
130 .await
131 .map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?;
132
133 if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
134 .entry("unsigned".into())
135 .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
136 {
137 if let Some(mut relations) = unsigned
138 .get("m.relations")
139 .and_then(|r| r.as_object())
140 .and_then(|r| r.get("m.thread"))
141 .and_then(|relations| {
142 serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
143 }) {
144 relations.count = relations.count.saturating_add(uint!(1));
146 relations.latest_event = event.to_format();
147
148 let content = serde_json::to_value(relations).expect("to_value always works");
149
150 unsigned.insert(
151 "m.relations".into(),
152 json!({ "m.thread": content })
153 .try_into()
154 .expect("thread is valid json"),
155 );
156 } else {
157 let relations = BundledThread {
159 latest_event: event.to_format(),
160 count: uint!(1),
161 current_user_participated: true,
162 };
163
164 let content = serde_json::to_value(relations).expect("to_value always works");
165
166 unsigned.insert(
167 "m.relations".into(),
168 json!({ "m.thread": content })
169 .try_into()
170 .expect("thread is valid json"),
171 );
172 }
173
174 self.services
175 .timeline
176 .replace_pdu(&root_id, &root_pdu_json)
177 .await?;
178 }
179
180 let mut users = Vec::new();
181 match self.get_participants(&root_id).await {
182 | Ok(userids) => users.extend_from_slice(&userids),
183 | _ => users.push(root_pdu.sender().to_owned()),
184 }
185
186 users.push(event.sender().to_owned());
187 self.update_participants(&root_id, &users)
188 }
189
190 pub fn threads_until<'a>(
191 &'a self,
192 user_id: &'a UserId,
193 room_id: &'a RoomId,
194 count: PduCount,
195 _inc: &'a IncludeThreads,
196 ) -> impl Stream<Item = Result<(PduCount, PduEvent)>> + Send {
197 self.services
198 .short
199 .get_shortroomid(room_id)
200 .map_ok(move |shortroomid| PduId {
201 shortroomid,
202 count: count.saturating_sub(1),
203 })
204 .map_ok(Into::into)
205 .map_ok(move |current: RawPduId| {
206 self.db
207 .threadid_userids
208 .rev_raw_keys_from(¤t)
209 .ignore_err()
210 .map(RawPduId::from)
211 .map(move |pdu_id| (pdu_id, user_id))
212 .ready_take_while(move |(pdu_id, _)| {
213 pdu_id.shortroomid() == current.shortroomid()
214 })
215 .wide_filter_map(async |(raw_pdu_id, user_id)| {
216 let pdu_id: PduId = raw_pdu_id.into();
217 let mut pdu = self
218 .services
219 .timeline
220 .get_pdu_from_id(&raw_pdu_id)
221 .await
222 .ok()?;
223
224 if pdu.sender() != user_id {
225 pdu.as_mut_pdu().remove_transaction_id().ok();
226 }
227
228 Some((pdu_id.count, pdu))
229 })
230 .map(Ok)
231 })
232 .try_flatten_stream()
233 }
234
235 pub(super) fn update_participants(
236 &self,
237 root_id: &RawPduId,
238 participants: &[OwnedUserId],
239 ) -> Result {
240 let users = participants
241 .iter()
242 .map(|user| user.as_bytes())
243 .collect::<Vec<_>>()
244 .join(&[0xFF][..]);
245
246 self.db.threadid_userids.insert(root_id, &users);
247
248 Ok(())
249 }
250
251 pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
252 self.db
253 .threadid_userids
254 .get(root_id)
255 .await
256 .deserialized()
257 }
258
259 pub(super) async fn delete_all_rooms_threads(&self, room_id: &RoomId) -> Result {
260 let prefix = (room_id, Interfix);
261
262 self.db
263 .threadid_userids
264 .keys_prefix_raw(&prefix)
265 .ignore_err()
266 .ready_for_each(|key| {
267 trace!("Removing key: {key:?}");
268 self.db.threadid_userids.remove(key);
269 })
270 .await;
271
272 Ok(())
273 }
274}