1mod append;
2mod notification;
3mod request;
4mod send;
5mod suppressed;
6#[cfg(test)]
7mod tests;
8
9use std::sync::Arc;
10
11use futures::{Stream, StreamExt, TryFutureExt, future::join};
12use ipaddress::IPAddress;
13use ruma::{
14 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
15 api::client::push::{Pusher, PusherKind, set_pusher},
16 events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels},
17 push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset},
18 serde::Raw,
19 uint,
20};
21use tuwunel_core::{
22 Err, Result, err, implement,
23 utils::{
24 MutexMap,
25 future::TryExtExt,
26 stream::{BroadbandExt, ReadyExt, TryIgnore},
27 },
28};
29use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Json, Map};
30
31pub use self::append::Notified;
32
/// Push service: stores and removes pushers, evaluates push rules for
/// events, and exposes per-user notification history.
pub struct Service {
	/// Handles to the other services this one calls into (e.g. `users`,
	/// `state_cache`, `client`, `sending`).
	services: Arc<crate::services::OnceServices>,
	/// Lock map keyed by `(room, user)`; presumably serializes updates to the
	/// per-room notification count for a user — TODO confirm in submodules.
	notification_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
	/// Lock map keyed by `(room, user)`; presumably serializes updates to the
	/// per-room highlight count for a user — TODO confirm in submodules.
	highlight_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
	/// Database column handles used by this service.
	db: Data,
	/// Queue owned by the `suppressed` submodule; its entries for a push key
	/// are cleared when that pusher is deleted (`delete_pusher`).
	suppressed: suppressed::SuppressedQueue,
}
40
/// Database handles backing the push service.
struct Data {
	/// The database the column handles below were taken from.
	db: Arc<Database>,
	/// `(user_id, pushkey)` -> JSON-serialized pusher, written by `set_pusher`.
	senderkey_pusher: Arc<Map>,
	/// `pushkey` -> ID of the device that registered the pusher.
	pushkey_deviceid: Arc<Map>,
	/// `(user_id, count)` -> notification record; read in descending `count`
	/// order by `get_notifications`.
	useridcount_notification: Arc<Map>,
	/// Not used in this file; presumably per-user/room highlight counters —
	/// confirm in submodules.
	userroomid_highlightcount: Arc<Map>,
	/// Not used in this file; presumably per-user/room notification counters —
	/// confirm in submodules.
	userroomid_notificationcount: Arc<Map>,
	/// Not used in this file; presumably the last point at which a user read
	/// notifications in a room — confirm in submodules.
	roomuserid_lastnotificationread: Arc<Map>,
}
50
51impl crate::Service for Service {
52 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
53 Ok(Arc::new(Self {
54 services: args.services.clone(),
55 notification_increment_mutex: MutexMap::new(),
56 highlight_increment_mutex: MutexMap::new(),
57 db: Data {
58 db: args.db.clone(),
59 senderkey_pusher: args.db["senderkey_pusher"].clone(),
60 pushkey_deviceid: args.db["pushkey_deviceid"].clone(),
61 useridcount_notification: args.db["useridcount_notification"].clone(),
62 userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
63 userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
64 roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
65 .clone(),
66 },
67 suppressed: suppressed::SuppressedQueue::default(),
68 }))
69 }
70
71 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
72}
73
74#[implement(Service)]
75pub async fn set_pusher(
76 &self,
77 sender: &UserId,
78 sender_device: &DeviceId,
79 pusher: &set_pusher::v3::PusherAction,
80) -> Result {
81 match pusher {
82 | set_pusher::v3::PusherAction::Post(data) => {
83 let pushkey = data.pusher.ids.pushkey.as_str();
84
85 if pushkey.len() > 512 {
86 return Err!(Request(InvalidParam(
87 "Push key length cannot be greater than 512 bytes."
88 )));
89 }
90
91 if data.pusher.ids.app_id.as_str().len() > 64 {
92 return Err!(Request(InvalidParam(
93 "App ID length cannot be greater than 64 bytes."
94 )));
95 }
96
97 let pusher_kind = &data.pusher.kind;
99 if let PusherKind::Http(http) = pusher_kind {
100 let url = &http.url;
101 let url = url::Url::parse(&http.url).map_err(|e| {
102 err!(Request(InvalidParam(
103 warn!(%url, "HTTP pusher URL is not a valid URL: {e}")
104 )))
105 })?;
106
107 if ["http", "https"]
108 .iter()
109 .all(|&scheme| !scheme.eq_ignore_ascii_case(url.scheme()))
110 {
111 return Err!(Request(InvalidParam(
112 warn!(%url, "HTTP pusher URL is not a valid HTTP/HTTPS URL")
113 )));
114 }
115
116 if let Ok(ip) =
117 IPAddress::parse(url.host_str().expect("URL previously validated"))
118 && !self.services.client.valid_cidr_range(&ip)
119 {
120 return Err!(Request(InvalidParam(
121 warn!(%url, "HTTP pusher URL is a forbidden remote address")
122 )));
123 }
124 }
125
126 let pushkey = data.pusher.ids.pushkey.as_str();
127 let key = (sender, pushkey);
128 self.db.senderkey_pusher.put(key, Json(pusher));
129 self.db
130 .pushkey_deviceid
131 .insert(pushkey, sender_device);
132 },
133 | set_pusher::v3::PusherAction::Delete(ids) => {
134 self.delete_pusher(sender, ids.pushkey.as_str())
135 .await;
136 },
137 }
138
139 Ok(())
140}
141
142#[implement(Service)]
143pub async fn delete_pusher(&self, sender: &UserId, pushkey: &str) {
144 let key = (sender, pushkey);
145 self.db.senderkey_pusher.del(key);
146 self.db.pushkey_deviceid.remove(pushkey);
147 self.clear_suppressed_pushkey(sender, pushkey);
148
149 self.services
150 .sending
151 .cleanup_events(None, Some(sender), Some(pushkey))
152 .await
153 .ok();
154}
155
156#[implement(Service)]
157pub async fn get_device_pushkeys(&self, sender: &UserId, device_id: &DeviceId) -> Vec<String> {
158 self.get_pushkeys(sender)
159 .map(ToOwned::to_owned)
160 .broad_filter_map(async |pushkey| {
161 self.get_pusher_device(&pushkey)
162 .await
163 .ok()
164 .as_ref()
165 .is_some_and(|pusher_device| pusher_device == device_id)
166 .then_some(pushkey)
167 })
168 .collect()
169 .await
170}
171
172#[implement(Service)]
173pub async fn get_pusher_device(&self, pushkey: &str) -> Result<OwnedDeviceId> {
174 self.db
175 .pushkey_deviceid
176 .get(pushkey)
177 .await
178 .deserialized()
179}
180
181#[implement(Service)]
182pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Pusher> {
183 let senderkey = (sender, pushkey);
184 self.db
185 .senderkey_pusher
186 .qry(&senderkey)
187 .await
188 .deserialized()
189}
190
191#[implement(Service)]
192pub async fn get_pushers(&self, sender: &UserId) -> Vec<Pusher> {
193 let prefix = (sender, Interfix);
194 self.db
195 .senderkey_pusher
196 .stream_prefix(&prefix)
197 .ignore_err()
198 .map(|(_, pusher): (Ignore, Pusher)| pusher)
199 .collect()
200 .await
201}
202
203#[implement(Service)]
204pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a {
205 let prefix = (sender, Interfix);
206 self.db
207 .senderkey_pusher
208 .keys_prefix(&prefix)
209 .ignore_err()
210 .map(|(_, pushkey): (Ignore, &str)| pushkey)
211}
212
213#[implement(Service)]
214#[tracing::instrument(level = "debug", skip_all)]
215pub fn get_notifications<'a>(
216 &'a self,
217 sender: &'a UserId,
218 from: Option<u64>,
219) -> impl Stream<Item = (u64, Notified)> + Send + 'a {
220 let from = from
221 .map(|from| from.saturating_sub(1))
222 .unwrap_or(u64::MAX);
223
224 self.db
225 .useridcount_notification
226 .rev_stream_from(&(sender, from))
227 .ignore_err()
228 .map(|item: ((&UserId, u64), _)| (item.0, item.1))
229 .ready_take_while(move |((user_id, _count), _)| sender == *user_id)
230 .map(|((_, count), notified)| (count, notified))
231}
232
233#[implement(Service)]
234#[tracing::instrument(level = "debug", skip_all)]
235pub async fn get_actions<'a>(
236 &self,
237 user: &UserId,
238 ruleset: &'a Ruleset,
239 power_levels: Option<&RoomPowerLevels>,
240 pdu: &Raw<AnySyncTimelineEvent>,
241 room_id: &RoomId,
242) -> &'a [Action] {
243 let user_display_name = self
244 .services
245 .users
246 .displayname(user)
247 .unwrap_or_else(|_| user.localpart().to_owned());
248
249 let room_joined_count = self
250 .services
251 .state_cache
252 .room_joined_count(room_id)
253 .map_ok(TryInto::try_into)
254 .map_ok(|res| res.unwrap_or_else(|_| uint!(1)))
255 .unwrap_or_default();
256
257 let (room_joined_count, user_display_name) = join(room_joined_count, user_display_name).await;
258
259 let power_levels = power_levels.map(|power_levels| PushConditionPowerLevelsCtx {
260 users: power_levels.users.clone(),
261 users_default: power_levels.users_default,
262 notifications: power_levels.notifications.clone(),
263 rules: power_levels.rules.clone(),
264 });
265
266 let ctx = PushConditionRoomCtx::new(
267 room_id.to_owned(),
268 room_joined_count,
269 user.to_owned(),
270 user_display_name,
271 );
272 let ctx = match power_levels {
273 | Some(pl) => ctx.with_power_levels(pl),
274 | None => ctx,
275 };
276
277 ruleset.get_actions(pdu, &ctx).await
278}