1mod append;
2mod backfill;
3mod build;
4mod create;
5mod pdus;
6mod redact;
7
8use std::{fmt::Write, sync::Arc};
9
10use async_trait::async_trait;
11use futures::{
12 TryFutureExt, TryStreamExt,
13 future::{
14 Either::{Left, Right},
15 select_ok,
16 },
17 pin_mut,
18};
19use ruma::{
20 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, RoomId,
21 UserId, api::Direction, events::room::encrypted::Relation,
22};
23use serde::Deserialize;
24pub use tuwunel_core::matrix::pdu::{PduId, RawPduId};
25use tuwunel_core::{
26 Err, Result, at, err, implement,
27 matrix::{
28 ShortEventId,
29 pdu::{PduCount, PduEvent},
30 },
31 utils::{
32 MutexMap, MutexMapGuard,
33 result::{LogErr, NotFound},
34 stream::TryReadyExt,
35 },
36 warn,
37};
38use tuwunel_database::{Database, Deserialized, Json, Map};
39
40pub use self::pdus::PdusIterItem;
41use crate::rooms::short::{ShortRoomId, ShortStateHash};
42
/// Timeline service: stores and retrieves room PDUs, both accepted
/// (timeline) PDUs and outliers.
pub struct Service {
	/// Shared handle to the other services; this module reaches the `short`
	/// and `state` services through it.
	services: Arc<crate::services::OnceServices>,
	/// Database map handles used by this service.
	db: Data,
	/// Per-room mutex map, keyed by room id. Presumably held by callers while
	/// inserting PDUs into a room — TODO confirm against its users.
	pub mutex_insert: RoomMutexMap,
}
48
/// Handles to the database maps this service reads and writes.
struct Data {
	/// Event id -> outlier PDU JSON.
	eventid_outlierpdu: Arc<Map>,
	/// Event id -> raw PDU id of an accepted PDU.
	eventid_pduid: Arc<Map>,
	/// Raw PDU id -> accepted PDU JSON.
	pduid_pdu: Arc<Map>,
	/// Room/timestamp -> PDU count index; not read in this chunk
	/// (presumably used by the timestamp lookups in a sibling module).
	roomid_ts_pducount: Arc<Map>,
	/// The database the maps above were taken from.
	db: Arc<Database>,
}
56
/// Deserialization helper that extracts only the `m.relates_to` field of
/// event content as a full `Relation`.
#[derive(Deserialize)]
struct ExtractRelatesTo {
	#[serde(rename = "m.relates_to")]
	relates_to: Relation,
}
63
/// Deserialization helper that extracts only an `event_id` field.
#[derive(Clone, Debug, Deserialize)]
struct ExtractEventId {
	event_id: OwnedEventId,
}
/// Deserialization helper that extracts just the `event_id` nested inside
/// the `m.relates_to` field, without parsing the whole relation.
#[derive(Clone, Debug, Deserialize)]
struct ExtractRelatesToEventId {
	#[serde(rename = "m.relates_to")]
	relates_to: ExtractEventId,
}
73
/// Deserialization helper that extracts the optional `body` field of event
/// content.
#[derive(Deserialize)]
struct ExtractBody {
	body: Option<String>,
}
78
/// Map of per-room mutexes keyed by room id (no payload data).
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
/// Guard returned when locking one room's entry in a `RoomMutexMap`.
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
81
82#[async_trait]
83impl crate::Service for Service {
84 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
85 Ok(Arc::new(Self {
86 services: args.services.clone(),
87 db: Data {
88 eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(),
89 eventid_pduid: args.db["eventid_pduid"].clone(),
90 pduid_pdu: args.db["pduid_pdu"].clone(),
91 roomid_ts_pducount: args.db["roomid_ts_pducount"].clone(),
92 db: args.db.clone(),
93 },
94 mutex_insert: RoomMutexMap::new(),
95 }))
96 }
97
98 async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
99 let mutex_insert = self.mutex_insert.len();
100 writeln!(out, "insert_mutex: {mutex_insert}")?;
101
102 Ok(())
103 }
104
105 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
106}
107
108#[implement(Service)]
110#[tracing::instrument(skip(self), level = "debug")]
111pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject) -> Result {
112 if self.db.pduid_pdu.get(pdu_id).await.is_not_found() {
113 return Err!(Request(NotFound("PDU does not exist.")));
114 }
115
116 self.db.pduid_pdu.raw_put(pdu_id, Json(pdu_json));
117
118 Ok(())
119}
120
121#[implement(Service)]
122#[tracing::instrument(skip(self, pdu), level = "debug")]
123pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) {
124 self.db
125 .eventid_outlierpdu
126 .raw_put(event_id, Json(pdu));
127}
128
129#[implement(Service)]
130#[tracing::instrument(skip(self), level = "debug")]
131pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<PduEvent> {
132 self.first_item_in_room(room_id).await.map(at!(1))
133}
134
135#[implement(Service)]
136#[tracing::instrument(skip(self), level = "debug")]
137#[inline]
138pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<PduEvent> {
139 self.latest_item_in_room(None, room_id).await
140}
141
142#[implement(Service)]
143#[tracing::instrument(skip(self), level = "debug")]
144pub async fn first_item_in_room(&self, room_id: &RoomId) -> Result<(PduCount, PduEvent)> {
145 let pdus = self.pdus(None, room_id, None);
146
147 pin_mut!(pdus);
148 pdus.try_next()
149 .await?
150 .ok_or_else(|| err!(Request(NotFound("No PDU found in room"))))
151}
152
153#[implement(Service)]
154#[tracing::instrument(skip(self), level = "debug")]
155pub async fn latest_item_in_room(
156 &self,
157 sender_user: Option<&UserId>,
158 room_id: &RoomId,
159) -> Result<PduEvent> {
160 let pdus_rev = self.pdus_rev(sender_user, room_id, None);
161
162 pin_mut!(pdus_rev);
163 pdus_rev
164 .try_next()
165 .await?
166 .map(at!(1))
167 .ok_or_else(|| err!(Request(NotFound("No PDU's found in room"))))
168}
169
170#[implement(Service)]
174#[tracing::instrument(skip(self), level = "debug")]
175pub async fn prev_shortstatehash(
176 &self,
177 room_id: &RoomId,
178 before: PduCount,
179) -> Result<ShortStateHash> {
180 let shortroomid: ShortRoomId = self
181 .services
182 .short
183 .get_shortroomid(room_id)
184 .await
185 .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?;
186
187 let before = PduId { shortroomid, count: before };
188
189 let prev = PduId {
190 shortroomid,
191 count: self.prev_timeline_count(&before).await?,
192 };
193
194 let shorteventid = self.get_shorteventid_from_pdu_id(&prev).await?;
195
196 self.services
197 .state
198 .get_shortstatehash(shorteventid)
199 .await
200}
201
202#[implement(Service)]
206#[tracing::instrument(skip(self), level = "debug")]
207pub async fn next_shortstatehash(
208 &self,
209 room_id: &RoomId,
210 after: PduCount,
211) -> Result<ShortStateHash> {
212 let shortroomid: ShortRoomId = self
213 .services
214 .short
215 .get_shortroomid(room_id)
216 .await
217 .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?;
218
219 let after = PduId { shortroomid, count: after };
220
221 let next = PduId {
222 shortroomid,
223 count: self.next_timeline_count(&after).await?,
224 };
225
226 let shorteventid = self.get_shorteventid_from_pdu_id(&next).await?;
227
228 self.services
229 .state
230 .get_shortstatehash(shorteventid)
231 .await
232}
233
234#[implement(Service)]
236#[tracing::instrument(skip(self), level = "debug")]
237pub async fn get_shortstatehash(
238 &self,
239 room_id: &RoomId,
240 count: PduCount,
241) -> Result<ShortStateHash> {
242 let shortroomid: ShortRoomId = self
243 .services
244 .short
245 .get_shortroomid(room_id)
246 .await
247 .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?;
248
249 let pdu_id = PduId { shortroomid, count };
250
251 let shorteventid = self.get_shorteventid_from_pdu_id(&pdu_id).await?;
252
253 self.services
254 .state
255 .get_shortstatehash(shorteventid)
256 .await
257}
258
259#[implement(Service)]
262#[tracing::instrument(skip(self), level = "debug")]
263pub async fn prev_timeline_count(&self, before: &PduId) -> Result<PduCount> {
264 let before = Self::pdu_count_to_id(before.shortroomid, before.count, Direction::Backward);
265
266 let pdu_ids = self
267 .db
268 .pduid_pdu
269 .rev_keys_raw_from(&before)
270 .ready_try_take_while(|pdu_id: &RawPduId| Ok(pdu_id.is_room_eq(before)))
271 .ready_and_then(|pdu_id: RawPduId| Ok(pdu_id.pdu_count()));
272
273 pin_mut!(pdu_ids);
274 pdu_ids
275 .try_next()
276 .await
277 .log_err()?
278 .ok_or_else(|| err!(Request(NotFound("No earlier PDU's found in room"))))
279}
280
281#[implement(Service)]
284#[tracing::instrument(skip(self), level = "debug")]
285pub async fn next_timeline_count(&self, after: &PduId) -> Result<PduCount> {
286 let after = Self::pdu_count_to_id(after.shortroomid, after.count, Direction::Forward);
287
288 let pdu_ids = self
289 .db
290 .pduid_pdu
291 .keys_raw_from(&after)
292 .ready_try_take_while(|pdu_id: &RawPduId| Ok(pdu_id.is_room_eq(after)))
293 .ready_and_then(|pdu_id: RawPduId| Ok(pdu_id.pdu_count()));
294
295 pin_mut!(pdu_ids);
296 pdu_ids
297 .try_next()
298 .await
299 .log_err()?
300 .ok_or(err!(Request(NotFound("No more PDU's found in room"))))
301}
302
303#[implement(Service)]
304#[tracing::instrument(skip(self), level = "debug")]
305pub async fn last_timeline_count(
306 &self,
307 sender_user: Option<&UserId>,
308 room_id: &RoomId,
309 upper_bound: Option<PduCount>,
310) -> Result<PduCount> {
311 let upper_bound = upper_bound.unwrap_or_else(PduCount::max);
312 let pdus_rev = self.pdus_rev(sender_user, room_id, None);
313
314 pin_mut!(pdus_rev);
315 let last_count = pdus_rev
316 .ready_try_skip_while(|&(pducount, _)| Ok(pducount > upper_bound))
317 .try_next()
318 .await?
319 .map(at!(0))
320 .filter(|&count| matches!(count, PduCount::Normal(_)))
321 .unwrap_or_else(PduCount::max);
322
323 Ok(last_count)
324}
325
326#[implement(Service)]
327pub async fn get_event_id_near_ts(
328 &self,
329 room_id: &RoomId,
330 ts: MilliSecondsSinceUnixEpoch,
331 dir: Direction,
332) -> Result<(MilliSecondsSinceUnixEpoch, OwnedEventId)> {
333 self.get_pdu_id_near_ts(room_id, ts, dir)
334 .and_then(async |(ts, pdu_id)| {
335 self.get_event_id_from_pdu_id(&pdu_id)
336 .map_ok(|event_id| (ts, event_id))
337 .await
338 })
339 .await
340}
341
342#[implement(Service)]
343pub async fn get_pdu_id_near_ts(
344 &self,
345 room_id: &RoomId,
346 ts: MilliSecondsSinceUnixEpoch,
347 dir: Direction,
348) -> Result<(MilliSecondsSinceUnixEpoch, PduId)> {
349 let pdu_ids = self.pdu_ids_near_ts(room_id, ts, dir);
350
351 pin_mut!(pdu_ids);
352 pdu_ids
353 .try_next()
354 .await?
355 .ok_or_else(|| err!(Request(NotFound("No event found near this timestamp."))))
356}
357
358#[implement(Service)]
359pub async fn get_pdu_near_ts(
360 &self,
361 _user_id: Option<&UserId>,
362 room_id: &RoomId,
363 ts: MilliSecondsSinceUnixEpoch,
364 dir: Direction,
365) -> Result<PdusIterItem> {
366 let pdus = self
367 .pdu_ids_near_ts(room_id, ts, dir)
368 .map_ok(|(ts, pdu_id)| (ts, pdu_id.into()))
369 .and_then(async |(_, pdu_id): (_, RawPduId)| {
370 self.get_pdu_from_id(&pdu_id)
371 .map_ok(|pdu| (pdu_id.pdu_count(), pdu))
372 .await
373 });
374
375 pin_mut!(pdus);
376 pdus.try_next()
377 .await?
378 .ok_or_else(|| err!(Request(NotFound("No event found near this timestamp."))))
379}
380
381#[implement(Service)]
382async fn count_to_id(
383 &self,
384 room_id: &RoomId,
385 count: PduCount,
386 dir: Direction,
387) -> Result<RawPduId> {
388 let shortroomid: ShortRoomId = self
389 .services
390 .short
391 .get_shortroomid(room_id)
392 .await
393 .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?;
394
395 Ok(Self::pdu_count_to_id(shortroomid, count, dir))
396}
397
398#[implement(Service)]
399fn pdu_count_to_id(shortroomid: ShortRoomId, count: PduCount, dir: Direction) -> RawPduId {
400 let pdu_id = PduId {
402 shortroomid,
403 count: count.saturating_inc(dir),
404 };
405
406 pdu_id.into()
407}
408
409#[implement(Service)]
412pub async fn get_pdu_from_shorteventid(&self, shorteventid: ShortEventId) -> Result<PduEvent> {
413 let event_id: OwnedEventId = self
414 .services
415 .short
416 .get_eventid_from_short(shorteventid)
417 .await?;
418
419 self.get_pdu(&event_id).await
420}
421
422#[implement(Service)]
425pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> { self.get(event_id).await }
426
427#[implement(Service)]
430pub async fn get_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
431 self.get_outlier(event_id).await
432}
433
434#[implement(Service)]
437pub async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
438 self.get_non_outlier(event_id).await
439}
440
441#[implement(Service)]
444pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> {
445 self.get_from_id(pdu_id).await
446}
447
448#[implement(Service)]
451pub async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
452 self.get(event_id).await
453}
454
455#[implement(Service)]
458pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
459 self.get_outlier(event_id).await
460}
461
462#[implement(Service)]
465pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
466 self.get_non_outlier(event_id).await
467}
468
469#[implement(Service)]
472pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
473 self.get_from_id(pdu_id).await
474}
475
476#[implement(Service)]
479#[inline]
480pub async fn get<T>(&self, event_id: &EventId) -> Result<T>
481where
482 T: for<'de> Deserialize<'de>,
483{
484 let accepted = self.get_non_outlier(event_id);
485 let outlier = self.get_outlier(event_id);
486
487 pin_mut!(accepted, outlier);
488 select_ok([Left(accepted), Right(outlier)])
489 .await
490 .map(at!(0))
491}
492
493#[implement(Service)]
496#[inline]
497pub async fn get_outlier<T>(&self, event_id: &EventId) -> Result<T>
498where
499 T: for<'de> Deserialize<'de>,
500{
501 self.db
502 .eventid_outlierpdu
503 .get(event_id)
504 .await
505 .deserialized()
506}
507
508#[implement(Service)]
511#[inline]
512pub async fn get_non_outlier<T>(&self, event_id: &EventId) -> Result<T>
513where
514 T: for<'de> Deserialize<'de>,
515{
516 let pdu_id = self.get_pdu_id(event_id).await?;
517
518 self.get_from_id(&pdu_id).await
519}
520
521#[implement(Service)]
524#[inline]
525pub async fn get_from_id<T>(&self, pdu_id: &RawPduId) -> Result<T>
526where
527 T: for<'de> Deserialize<'de>,
528{
529 self.db.pduid_pdu.get(pdu_id).await.deserialized()
530}
531
532#[implement(Service)]
535pub async fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> bool {
536 let non_outlier = self.non_outlier_pdu_exists(event_id);
537 let outlier = self.outlier_pdu_exists(event_id);
538
539 pin_mut!(non_outlier, outlier);
540 select_ok([Left(non_outlier), Right(outlier)])
541 .await
542 .map(at!(0))
543 .is_ok()
544}
545
546#[implement(Service)]
549pub async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result {
550 let pduid = self.get_pdu_id(event_id).await?;
551
552 self.db.pduid_pdu.exists(&pduid).await
553}
554
555#[implement(Service)]
558#[inline]
559pub async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result {
560 self.db.eventid_outlierpdu.exists(event_id).await
561}
562
563#[implement(Service)]
565pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
566 self.get_pdu_id(event_id)
567 .await
568 .map(RawPduId::pdu_count)
569}
570
571#[implement(Service)]
573pub async fn get_shorteventid_from_pdu_id(&self, pdu_id: &PduId) -> Result<ShortEventId> {
574 let event_id = self.get_event_id_from_pdu_id(pdu_id).await?;
575
576 self.services
577 .short
578 .get_shorteventid(&event_id)
579 .await
580}
581
582#[implement(Service)]
584pub async fn get_event_id_from_pdu_id(&self, pdu_id: &PduId) -> Result<OwnedEventId> {
585 let pdu_id: RawPduId = (*pdu_id).into();
586
587 self.get_pdu_from_id(&pdu_id)
588 .map_ok(|pdu| pdu.event_id)
589 .await
590}
591
592#[implement(Service)]
594pub async fn get_pdu_id_from_shorteventid(&self, shorteventid: ShortEventId) -> Result<RawPduId> {
595 let event_id: OwnedEventId = self
596 .services
597 .short
598 .get_eventid_from_short(shorteventid)
599 .await?;
600
601 self.get_pdu_id(&event_id).await
602}
603
604#[implement(Service)]
606pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
607 self.db
608 .eventid_pduid
609 .get(event_id)
610 .await
611 .map(|handle| RawPduId::from(&*handle))
612}