Skip to main content

tuwunel_service/pusher/
mod.rs

1mod append;
2mod notification;
3mod request;
4mod send;
5mod suppressed;
6#[cfg(test)]
7mod tests;
8
9use std::sync::Arc;
10
11use futures::{Stream, StreamExt, TryFutureExt, future::join};
12use ipaddress::IPAddress;
13use ruma::{
14	DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
15	api::client::push::{Pusher, PusherKind, set_pusher},
16	events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels},
17	push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset},
18	serde::Raw,
19	uint,
20};
21use tuwunel_core::{
22	Err, Result, err, implement,
23	utils::{
24		MutexMap,
25		future::TryExtExt,
26		stream::{BroadbandExt, ReadyExt, TryIgnore},
27	},
28};
29use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Json, Map};
30
31pub use self::append::Notified;
32
/// Pusher service: stores and validates pushers, exposes a user's stored
/// notification records, and builds the push-rule evaluation context used
/// to compute the actions for an event.
pub struct Service {
	/// Handles to the other services; this file uses `client`, `sending`,
	/// `users` and `state_cache` through it.
	services: Arc<crate::services::OnceServices>,
	/// Lock map keyed by `(room, user)`. NOTE(review): presumably serializes
	/// notification-count increments; it is not used in this file — confirm.
	notification_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
	/// Lock map keyed by `(room, user)`. NOTE(review): presumably serializes
	/// highlight-count increments; it is not used in this file — confirm.
	highlight_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
	/// Database maps backing this service.
	db: Data,
	/// Queue defined in the `suppressed` module. NOTE(review): presumably
	/// what `clear_suppressed_pushkey` prunes when a pusher is deleted — confirm.
	suppressed: suppressed::SuppressedQueue,
}
40
/// Handles to the database maps used by the pusher service.
struct Data {
	/// The database the maps below were taken from.
	db: Arc<Database>,
	/// `(sender user ID, pushkey)` -> pusher, written as JSON by `set_pusher`
	/// and read back as `Pusher` by `get_pusher`/`get_pushers`.
	senderkey_pusher: Arc<Map>,
	/// `pushkey` -> device ID that registered the pusher.
	pushkey_deviceid: Arc<Map>,
	/// `(user ID, count)` -> `Notified`; walked newest-first by
	/// `get_notifications`.
	useridcount_notification: Arc<Map>,
	/// Presumably `(user, room)` -> highlight count; not accessed in this
	/// file — TODO confirm.
	userroomid_highlightcount: Arc<Map>,
	/// Presumably `(user, room)` -> notification count; not accessed in this
	/// file — TODO confirm.
	userroomid_notificationcount: Arc<Map>,
	/// Presumably `(room, user)` -> last-read notification marker; not
	/// accessed in this file — TODO confirm.
	roomuserid_lastnotificationread: Arc<Map>,
}
50
51impl crate::Service for Service {
52	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
53		Ok(Arc::new(Self {
54			services: args.services.clone(),
55			notification_increment_mutex: MutexMap::new(),
56			highlight_increment_mutex: MutexMap::new(),
57			db: Data {
58				db: args.db.clone(),
59				senderkey_pusher: args.db["senderkey_pusher"].clone(),
60				pushkey_deviceid: args.db["pushkey_deviceid"].clone(),
61				useridcount_notification: args.db["useridcount_notification"].clone(),
62				userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
63				userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
64				roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
65					.clone(),
66			},
67			suppressed: suppressed::SuppressedQueue::default(),
68		}))
69	}
70
71	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
72}
73
74#[implement(Service)]
75pub async fn set_pusher(
76	&self,
77	sender: &UserId,
78	sender_device: &DeviceId,
79	pusher: &set_pusher::v3::PusherAction,
80) -> Result {
81	match pusher {
82		| set_pusher::v3::PusherAction::Post(data) => {
83			let pushkey = data.pusher.ids.pushkey.as_str();
84
85			if pushkey.len() > 512 {
86				return Err!(Request(InvalidParam(
87					"Push key length cannot be greater than 512 bytes."
88				)));
89			}
90
91			if data.pusher.ids.app_id.as_str().len() > 64 {
92				return Err!(Request(InvalidParam(
93					"App ID length cannot be greater than 64 bytes."
94				)));
95			}
96
97			// add some validation to the pusher URL
98			let pusher_kind = &data.pusher.kind;
99			if let PusherKind::Http(http) = pusher_kind {
100				let url = &http.url;
101				let url = url::Url::parse(&http.url).map_err(|e| {
102					err!(Request(InvalidParam(
103						warn!(%url, "HTTP pusher URL is not a valid URL: {e}")
104					)))
105				})?;
106
107				if ["http", "https"]
108					.iter()
109					.all(|&scheme| !scheme.eq_ignore_ascii_case(url.scheme()))
110				{
111					return Err!(Request(InvalidParam(
112						warn!(%url, "HTTP pusher URL is not a valid HTTP/HTTPS URL")
113					)));
114				}
115
116				if let Ok(ip) =
117					IPAddress::parse(url.host_str().expect("URL previously validated"))
118					&& !self.services.client.valid_cidr_range(&ip)
119				{
120					return Err!(Request(InvalidParam(
121						warn!(%url, "HTTP pusher URL is a forbidden remote address")
122					)));
123				}
124			}
125
126			let pushkey = data.pusher.ids.pushkey.as_str();
127			let key = (sender, pushkey);
128			self.db.senderkey_pusher.put(key, Json(pusher));
129			self.db
130				.pushkey_deviceid
131				.insert(pushkey, sender_device);
132		},
133		| set_pusher::v3::PusherAction::Delete(ids) => {
134			self.delete_pusher(sender, ids.pushkey.as_str())
135				.await;
136		},
137	}
138
139	Ok(())
140}
141
142#[implement(Service)]
143pub async fn delete_pusher(&self, sender: &UserId, pushkey: &str) {
144	let key = (sender, pushkey);
145	self.db.senderkey_pusher.del(key);
146	self.db.pushkey_deviceid.remove(pushkey);
147	self.clear_suppressed_pushkey(sender, pushkey);
148
149	self.services
150		.sending
151		.cleanup_events(None, Some(sender), Some(pushkey))
152		.await
153		.ok();
154}
155
156#[implement(Service)]
157pub async fn get_device_pushkeys(&self, sender: &UserId, device_id: &DeviceId) -> Vec<String> {
158	self.get_pushkeys(sender)
159		.map(ToOwned::to_owned)
160		.broad_filter_map(async |pushkey| {
161			self.get_pusher_device(&pushkey)
162				.await
163				.ok()
164				.as_ref()
165				.is_some_and(|pusher_device| pusher_device == device_id)
166				.then_some(pushkey)
167		})
168		.collect()
169		.await
170}
171
172#[implement(Service)]
173pub async fn get_pusher_device(&self, pushkey: &str) -> Result<OwnedDeviceId> {
174	self.db
175		.pushkey_deviceid
176		.get(pushkey)
177		.await
178		.deserialized()
179}
180
181#[implement(Service)]
182pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Pusher> {
183	let senderkey = (sender, pushkey);
184	self.db
185		.senderkey_pusher
186		.qry(&senderkey)
187		.await
188		.deserialized()
189}
190
191#[implement(Service)]
192pub async fn get_pushers(&self, sender: &UserId) -> Vec<Pusher> {
193	let prefix = (sender, Interfix);
194	self.db
195		.senderkey_pusher
196		.stream_prefix(&prefix)
197		.ignore_err()
198		.map(|(_, pusher): (Ignore, Pusher)| pusher)
199		.collect()
200		.await
201}
202
203#[implement(Service)]
204pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a {
205	let prefix = (sender, Interfix);
206	self.db
207		.senderkey_pusher
208		.keys_prefix(&prefix)
209		.ignore_err()
210		.map(|(_, pushkey): (Ignore, &str)| pushkey)
211}
212
213#[implement(Service)]
214#[tracing::instrument(level = "debug", skip_all)]
215pub fn get_notifications<'a>(
216	&'a self,
217	sender: &'a UserId,
218	from: Option<u64>,
219) -> impl Stream<Item = (u64, Notified)> + Send + 'a {
220	let from = from
221		.map(|from| from.saturating_sub(1))
222		.unwrap_or(u64::MAX);
223
224	self.db
225		.useridcount_notification
226		.rev_stream_from(&(sender, from))
227		.ignore_err()
228		.map(|item: ((&UserId, u64), _)| (item.0, item.1))
229		.ready_take_while(move |((user_id, _count), _)| sender == *user_id)
230		.map(|((_, count), notified)| (count, notified))
231}
232
233#[implement(Service)]
234#[tracing::instrument(level = "debug", skip_all)]
235pub async fn get_actions<'a>(
236	&self,
237	user: &UserId,
238	ruleset: &'a Ruleset,
239	power_levels: Option<&RoomPowerLevels>,
240	pdu: &Raw<AnySyncTimelineEvent>,
241	room_id: &RoomId,
242) -> &'a [Action] {
243	let user_display_name = self
244		.services
245		.users
246		.displayname(user)
247		.unwrap_or_else(|_| user.localpart().to_owned());
248
249	let room_joined_count = self
250		.services
251		.state_cache
252		.room_joined_count(room_id)
253		.map_ok(TryInto::try_into)
254		.map_ok(|res| res.unwrap_or_else(|_| uint!(1)))
255		.unwrap_or_default();
256
257	let (room_joined_count, user_display_name) = join(room_joined_count, user_display_name).await;
258
259	let power_levels = power_levels.map(|power_levels| PushConditionPowerLevelsCtx {
260		users: power_levels.users.clone(),
261		users_default: power_levels.users_default,
262		notifications: power_levels.notifications.clone(),
263		rules: power_levels.rules.clone(),
264	});
265
266	let ctx = PushConditionRoomCtx::new(
267		room_id.to_owned(),
268		room_joined_count,
269		user.to_owned(),
270		user_display_name,
271	);
272	let ctx = match power_levels {
273		| Some(pl) => ctx.with_power_levels(pl),
274		| None => ctx,
275	};
276
277	ruleset.get_actions(pdu, &ctx).await
278}