1use std::collections::BTreeMap;
2
3use axum::extract::State;
4use ruma::{
5 MilliSecondsSinceUnixEpoch,
6 api::client::{read_marker::set_read_marker, receipt::create_receipt},
7 events::{
8 RoomAccountDataEventType,
9 fully_read::{FullyReadEvent, FullyReadEventContent},
10 receipt::{Receipt, ReceiptEvent, ReceiptEventContent, ReceiptThread, ReceiptType},
11 },
12 presence::PresenceState,
13};
14use tuwunel_core::{Err, PduCount, Result, err};
15
16use crate::{ClientIp, Ruma};
17
18pub(crate) async fn set_read_marker_route(
26 State(services): State<crate::State>,
27 ClientIp(client): ClientIp,
28 body: Ruma<set_read_marker::v3::Request>,
29) -> Result<set_read_marker::v3::Response> {
30 let sender_user = body.sender_user();
31
32 if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
33 services
36 .pusher
37 .reset_notification_counts_for_thread(
38 sender_user,
39 &body.room_id,
40 &ReceiptThread::Unthreaded,
41 )
42 .await;
43 }
44
45 if let Some(event) = &body.fully_read {
46 let fully_read_event = FullyReadEvent {
47 content: FullyReadEventContent { event_id: event.clone() },
48 };
49
50 services
51 .account_data
52 .update(
53 Some(&body.room_id),
54 sender_user,
55 RoomAccountDataEventType::FullyRead,
56 &serde_json::to_value(fully_read_event)?,
57 )
58 .await
59 .ok();
60 }
61
62 if let Some(event) = &body.private_read_receipt {
63 let count = services
64 .timeline
65 .get_pdu_count(event)
66 .await
67 .map_err(|_| err!(Request(NotFound("Event not found."))))?;
68
69 let PduCount::Normal(count) = count else {
70 return Err!(Request(InvalidParam(
71 "Event is a backfilled PDU and cannot be marked as read."
72 )));
73 };
74
75 services
76 .read_receipt
77 .private_read_set(&body.room_id, sender_user, count, &ReceiptThread::Unthreaded)
78 .await;
79 }
80
81 if let Some(event) = &body.read_receipt {
82 let receipt_content = BTreeMap::from_iter([(
83 event.to_owned(),
84 BTreeMap::from_iter([(
85 ReceiptType::Read,
86 BTreeMap::from_iter([(sender_user.to_owned(), Receipt {
87 ts: Some(MilliSecondsSinceUnixEpoch::now()),
88 thread: ReceiptThread::Unthreaded,
89 })]),
90 )]),
91 )]);
92
93 services
94 .read_receipt
95 .readreceipt_update(sender_user, &body.room_id, &ReceiptEvent {
96 content: ReceiptEventContent(receipt_content),
97 room_id: body.room_id.clone(),
98 })
99 .await;
100
101 services
102 .presence
103 .maybe_ping_presence(
104 sender_user,
105 body.sender_device.as_deref(),
106 Some(client),
107 &PresenceState::Online,
108 )
109 .await
110 .ok();
111 }
112
113 Ok(set_read_marker::v3::Response {})
114}
115
116pub(crate) async fn create_receipt_route(
120 State(services): State<crate::State>,
121 ClientIp(client): ClientIp,
122 body: Ruma<create_receipt::v3::Request>,
123) -> Result<create_receipt::v3::Response> {
124 let sender_user = body.sender_user();
125
126 if matches!(&body.receipt_type, create_receipt::v3::ReceiptType::FullyRead)
128 && !matches!(body.thread, ReceiptThread::Unthreaded)
129 {
130 return Err!(Request(InvalidParam(
131 "thread_id must not be set for m.fully_read receipts"
132 )));
133 }
134
135 if body.thread.as_str() == Some("") {
137 return Err!(Request(InvalidParam("thread_id must be a non-empty string")));
138 }
139
140 if !matches!(
143 &body.thread,
144 ReceiptThread::Unthreaded | ReceiptThread::Main | ReceiptThread::Thread(_)
145 ) {
146 return Err!(Request(InvalidParam(
147 "thread_id must be either \"main\" or a thread root event id"
148 )));
149 }
150
151 if matches!(&body.thread, ReceiptThread::Main | ReceiptThread::Thread(_)) {
153 let resolved = services
154 .threads
155 .get_thread_id_for_event(&body.event_id)
156 .await;
157
158 let in_thread = match (&body.thread, resolved.as_deref()) {
159 | (ReceiptThread::Main, None) => true,
160 | (ReceiptThread::Thread(root), Some(parent)) => &**root == parent,
161 | (ReceiptThread::Thread(root), None) => **root == *body.event_id,
162 | _ => false,
163 };
164
165 if !in_thread {
166 return Err!(Request(InvalidParam("event_id is not related to the given thread_id")));
167 }
168 }
169
170 if matches!(
171 &body.receipt_type,
172 create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
173 ) {
174 services
175 .pusher
176 .reset_notification_counts_for_thread(sender_user, &body.room_id, &body.thread)
177 .await;
178 }
179
180 match body.receipt_type {
181 | create_receipt::v3::ReceiptType::FullyRead => {
182 let fully_read_event = FullyReadEvent {
183 content: FullyReadEventContent { event_id: body.event_id.clone() },
184 };
185 services
186 .account_data
187 .update(
188 Some(&body.room_id),
189 sender_user,
190 RoomAccountDataEventType::FullyRead,
191 &serde_json::to_value(fully_read_event)?,
192 )
193 .await?;
194 },
195 | create_receipt::v3::ReceiptType::Read => {
196 let receipt_content = BTreeMap::from_iter([(
197 body.event_id.clone(),
198 BTreeMap::from_iter([(
199 ReceiptType::Read,
200 BTreeMap::from_iter([(sender_user.to_owned(), Receipt {
201 ts: Some(MilliSecondsSinceUnixEpoch::now()),
202 thread: body.thread.clone(),
203 })]),
204 )]),
205 )]);
206
207 services
208 .read_receipt
209 .readreceipt_update(sender_user, &body.room_id, &ReceiptEvent {
210 content: ReceiptEventContent(receipt_content),
211 room_id: body.room_id.clone(),
212 })
213 .await;
214
215 services
216 .presence
217 .maybe_ping_presence(
218 sender_user,
219 body.sender_device.as_deref(),
220 Some(client),
221 &PresenceState::Online,
222 )
223 .await
224 .ok();
225 },
226 | create_receipt::v3::ReceiptType::ReadPrivate => {
227 let count = services
228 .timeline
229 .get_pdu_count(&body.event_id)
230 .await
231 .map_err(|_| err!(Request(NotFound("Event not found."))))?;
232
233 let PduCount::Normal(count) = count else {
234 return Err!(Request(InvalidParam(
235 "Event is a backfilled PDU and cannot be marked as read."
236 )));
237 };
238
239 services
240 .read_receipt
241 .private_read_set(&body.room_id, sender_user, count, &body.thread)
242 .await;
243 },
244 | _ => {
245 return Err!(Request(InvalidParam(warn!(
246 "Received unknown read receipt type: {}",
247 &body.receipt_type
248 ))));
249 },
250 }
251
252 Ok(create_receipt::v3::Response {})
253}