1use std::iter::once;
2
3use axum::extract::State;
4use futures::{
5 FutureExt, StreamExt, TryFutureExt,
6 future::try_join3,
7 stream::{select_all, unfold},
8};
9use ruma::{
10 EventId, RoomId, UInt, UserId,
11 api::{
12 Direction,
13 client::relations::{
14 get_relating_events, get_relating_events_with_rel_type,
15 get_relating_events_with_rel_type_and_event_type,
16 },
17 },
18 events::{TimelineEventType, relation::RelationType},
19};
20use tuwunel_core::{
21 Err, Error, Result, at, err,
22 matrix::{
23 event::{Event, RelationTypeEqual},
24 pdu::{PduCount, PduId},
25 },
26 utils::{
27 BoolExt,
28 result::FlatOk,
29 stream::{ReadyExt, WidebandExt},
30 },
31};
32use tuwunel_service::Services;
33
34use crate::{Ruma, client::is_ignored_pdu};
35
36pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
38 State(services): State<crate::State>,
39 body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
40) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
41 paginate_relations_with_filter(
42 &services,
43 body.sender_user(),
44 &body.room_id,
45 &body.event_id,
46 body.event_type.clone().into(),
47 body.rel_type.clone().into(),
48 body.from.as_deref(),
49 body.to.as_deref(),
50 body.limit,
51 body.recurse,
52 body.dir,
53 )
54 .await
55 .map(|res| get_relating_events_with_rel_type_and_event_type::v1::Response {
56 chunk: res.chunk,
57 next_batch: res.next_batch,
58 prev_batch: res.prev_batch,
59 recursion_depth: res.recursion_depth,
60 })
61}
62
63pub(crate) async fn get_relating_events_with_rel_type_route(
65 State(services): State<crate::State>,
66 body: Ruma<get_relating_events_with_rel_type::v1::Request>,
67) -> Result<get_relating_events_with_rel_type::v1::Response> {
68 paginate_relations_with_filter(
69 &services,
70 body.sender_user(),
71 &body.room_id,
72 &body.event_id,
73 None,
74 body.rel_type.clone().into(),
75 body.from.as_deref(),
76 body.to.as_deref(),
77 body.limit,
78 body.recurse,
79 body.dir,
80 )
81 .await
82 .map(|res| get_relating_events_with_rel_type::v1::Response {
83 chunk: res.chunk,
84 next_batch: res.next_batch,
85 prev_batch: res.prev_batch,
86 recursion_depth: res.recursion_depth,
87 })
88}
89
90pub(crate) async fn get_relating_events_route(
92 State(services): State<crate::State>,
93 body: Ruma<get_relating_events::v1::Request>,
94) -> Result<get_relating_events::v1::Response> {
95 paginate_relations_with_filter(
96 &services,
97 body.sender_user(),
98 &body.room_id,
99 &body.event_id,
100 None,
101 None,
102 body.from.as_deref(),
103 body.to.as_deref(),
104 body.limit,
105 body.recurse,
106 body.dir,
107 )
108 .await
109}
110
111#[expect(clippy::too_many_arguments)]
112#[tracing::instrument(
113 name = "relations",
114 level = "debug",
115 skip_all,
116 fields(room_id, target, from, to, dir, limit, recurse)
117)]
118async fn paginate_relations_with_filter(
119 services: &Services,
120 sender_user: &UserId,
121 room_id: &RoomId,
122 target: &EventId,
123 filter_event_type: Option<TimelineEventType>,
124 filter_rel_type: Option<RelationType>,
125 from: Option<&str>,
126 to: Option<&str>,
127 limit: Option<UInt>,
128 recurse: bool,
129 dir: Direction,
130) -> Result<get_relating_events::v1::Response> {
131 let from: Option<PduCount> = from.map(str::parse).transpose()?;
132
133 let to: Option<PduCount> = to.map(str::parse).flat_ok();
134
135 let max_depth: usize = if recurse { 3 } else { 0 };
137
138 let limit: usize = limit
139 .map(TryInto::try_into)
140 .flat_ok()
141 .unwrap_or(30)
142 .min(100);
143
144 let target_event_id: &EventId = target;
145
146 let target = services
147 .timeline
148 .get_pdu_id(target)
149 .map_ok(PduId::from)
150 .map_ok(Ok::<_, Error>);
151
152 let visible = services
153 .state_accessor
154 .user_can_see_state_events(sender_user, room_id)
155 .map(|visible| {
156 visible.ok_or_else(|| err!(Request(Forbidden("You cannot view this room."))))
157 });
158
159 let shortroomid = services.short.get_shortroomid(room_id);
160
161 let (shortroomid, target, ()) = try_join3(shortroomid, target, visible).await?;
162
163 let Ok(target) = target else {
164 return Ok(get_relating_events::v1::Response::new(Vec::new()));
165 };
166
167 if shortroomid != target.shortroomid {
168 return Err!(Request(NotFound("Event not found in room.")));
169 }
170
171 if let PduCount::Backfilled(_) = target.count {
172 return Ok(get_relating_events::v1::Response::new(Vec::new()));
173 }
174
175 if let Ok(target_pdu) = services.timeline.get_pdu(target_event_id).await
176 && is_ignored_pdu(services, &target_pdu, sender_user).await
177 {
178 return Err!(HttpJson(NOT_FOUND, {
179 "errcode": "M_SENDER_IGNORED",
180 "error": "You have ignored the user that sent this event",
181 "sender": target_pdu.sender().as_str(),
182 }));
183 }
184
185 let fetch = |depth: usize, count: PduCount| {
186 services
187 .pdu_metadata
188 .get_relations(shortroomid, count, from, dir, Some(sender_user))
189 .map(move |(count, pdu)| (depth, count, pdu))
190 .ready_filter(|(_, count, _)| matches!(count, PduCount::Normal(_)))
191 .boxed()
192 };
193
194 let events = unfold(select_all(once(fetch(0, target.count))), async |mut relations| {
195 let (depth, count, pdu) = relations.next().await?;
196
197 if depth < max_depth {
198 relations.push(fetch(depth.saturating_add(1), count));
199 }
200
201 Some(((depth, count, pdu), relations))
202 })
203 .ready_take_while(|&(_, count, _)| Some(count) != to)
204 .ready_filter(|(_, _, pdu)| {
205 filter_event_type
206 .as_ref()
207 .is_none_or(|kind| kind == pdu.kind())
208 })
209 .ready_filter(|(_, _, pdu)| {
210 filter_rel_type
211 .as_ref()
212 .is_none_or(|rel_type| rel_type.relation_type_equal(pdu))
213 })
214 .wide_filter_map(async |(depth, count, pdu)| {
215 services
216 .state_accessor
217 .user_can_see_event(sender_user, pdu.room_id(), pdu.event_id())
218 .await
219 .then_some((depth, count, pdu))
220 })
221 .take(limit)
222 .collect::<Vec<_>>()
223 .await;
224
225 Ok(get_relating_events::v1::Response {
226 recursion_depth: max_depth
227 .gt(&0)
228 .then(|| events.iter().map(at!(0)))
229 .into_iter()
230 .flatten()
231 .max()
232 .map(TryInto::try_into)
233 .transpose()?,
234
235 next_batch: events
236 .last()
237 .map(at!(1))
238 .as_ref()
239 .map(ToString::to_string),
240
241 prev_batch: events
242 .first()
243 .map(at!(1))
244 .or(from)
245 .as_ref()
246 .map(ToString::to_string),
247
248 chunk: events
249 .into_iter()
250 .map(at!(2))
251 .map(Event::into_format)
252 .collect(),
253 })
254}