tuwunel_service/pusher/send.rs

1use futures::future::join;
2use ipaddress::IPAddress;
3use ruma::{
4	UInt, UserId,
5	api::{
6		client::push::{Pusher, PusherKind},
7		push_gateway::send_event_notification::{
8			self,
9			v1::{Device, Notification, NotificationCounts, NotificationPriority},
10		},
11	},
12	events::TimelineEventType,
13	push::{Action, PushFormat, Ruleset, Tweak},
14	uint,
15};
16use tuwunel_core::{Err, Result, err, implement, matrix::Event};
17
18#[implement(super::Service)]
19#[tracing::instrument(level = "debug", skip_all)]
20pub async fn send_push_notice<E>(
21	&self,
22	user_id: &UserId,
23	pusher: &Pusher,
24	ruleset: &Ruleset,
25	event: &E,
26) -> Result
27where
28	E: Event,
29{
30	let mut notify = None;
31	let mut tweaks = Vec::new();
32
33	let power_levels = self
34		.services
35		.state_accessor
36		.get_power_levels(event.room_id())
37		.await
38		.ok();
39
40	let serialized = event.to_format();
41	let actions = self
42		.get_actions(user_id, ruleset, power_levels.as_ref(), &serialized, event.room_id())
43		.await;
44
45	for action in actions {
46		let n = match action {
47			| Action::Notify => true,
48			| Action::SetTweak(tweak) => {
49				tweaks.push(tweak.clone());
50				continue;
51			},
52			| _ => false,
53		};
54
55		if notify.is_some() {
56			return Err!(Request(BadJson(
57				r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#
58			)));
59		}
60
61		notify = Some(n);
62	}
63
64	if notify == Some(true) || self.services.config.push_everything {
65		// MSC3771/MSC3773: badge count merges main and per-thread notifications.
66		let (main, threads) = join(
67			self.services
68				.pusher
69				.notification_count(user_id, event.room_id()),
70			self.services
71				.pusher
72				.thread_notification_counts(user_id, event.room_id()),
73		)
74		.await;
75
76		let thread_total: u64 = threads
77			.values()
78			.map(|(notifications, _)| *notifications)
79			.sum();
80
81		let unread: UInt = main
82			.saturating_add(thread_total)
83			.try_into()
84			.unwrap_or_else(|_| uint!(1));
85
86		self.send_notice(unread, pusher, tweaks, event)
87			.await?;
88	}
89
90	Ok(())
91}
92
93#[implement(super::Service)]
94#[tracing::instrument(level = "debug", skip_all)]
95async fn send_notice<Pdu: Event>(
96	&self,
97	unread: UInt,
98	pusher: &Pusher,
99	tweaks: Vec<Tweak>,
100	event: &Pdu,
101) -> Result {
102	// TODO: email
103	match &pusher.kind {
104		| PusherKind::Http(http) => {
105			let url = &http.url;
106			let url = url::Url::parse(&http.url).map_err(|e| {
107				err!(Request(InvalidParam(
108					warn!(%url, "HTTP pusher URL is not a valid URL: {e}")
109				)))
110			})?;
111
112			if ["http", "https"]
113				.iter()
114				.all(|&scheme| !scheme.eq_ignore_ascii_case(url.scheme()))
115			{
116				return Err!(Request(InvalidParam(
117					warn!(%url, "HTTP pusher URL is not a valid HTTP/HTTPS URL")
118				)));
119			}
120
121			if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated"))
122				&& !self.services.client.valid_cidr_range(&ip)
123			{
124				return Err!(Request(InvalidParam(
125					warn!(%url, "HTTP pusher URL is a forbidden remote address")
126				)));
127			}
128
129			// TODO (timo): can pusher/devices have conflicting formats
130			let event_id_only = http.format == Some(PushFormat::EventIdOnly);
131
132			let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
133			device.data.data.clone_from(&http.data);
134			device.data.format.clone_from(&http.format);
135
136			// Tweaks are only added if the format is NOT event_id_only
137			if !event_id_only {
138				device.tweaks.clone_from(&tweaks);
139			}
140
141			let d = vec![device];
142			let mut notify = Notification::new(d);
143
144			notify.event_id = Some(event.event_id().to_owned());
145			notify.room_id = Some(event.room_id().to_owned());
146			if http
147				.data
148				.get("org.matrix.msc4076.disable_badge_count")
149				.is_none() && http.data.get("disable_badge_count").is_none()
150			{
151				notify.counts = NotificationCounts::new(unread, uint!(0));
152			} else {
153				// counts will not be serialised if it's the default (0, 0)
154				// skip_serializing_if = "NotificationCounts::is_default"
155				notify.counts = NotificationCounts::default();
156			}
157
158			if !event_id_only {
159				if *event.kind() == TimelineEventType::RoomEncrypted
160					|| tweaks.iter().any(|t| {
161						matches!(
162							t,
163							Tweak::Highlight(ruma::push::HighlightTweakValue::Yes)
164								| Tweak::Sound(_)
165						)
166					}) {
167					notify.prio = NotificationPriority::High;
168				} else {
169					notify.prio = NotificationPriority::Low;
170				}
171				notify.sender = Some(event.sender().to_owned());
172				notify.event_type = Some(event.kind().to_owned());
173				notify.content = serde_json::value::to_raw_value(event.content()).ok();
174
175				if *event.kind() == TimelineEventType::RoomMember {
176					notify.user_is_target = event.state_key() == Some(event.sender().as_str());
177				}
178
179				notify.sender_display_name = self
180					.services
181					.profile
182					.displayname(event.sender())
183					.await
184					.ok();
185
186				notify.room_name = self
187					.services
188					.state_accessor
189					.get_name(event.room_id())
190					.await
191					.ok();
192
193				notify.room_alias = self
194					.services
195					.state_accessor
196					.get_canonical_alias(event.room_id())
197					.await
198					.ok();
199			}
200
201			self.send_request(&http.url, send_event_notification::v1::Request::new(notify))
202				.await?;
203
204			Ok(())
205		},
206		// TODO: Handle email
207		//PusherKind::Email(_) => Ok(()),
208		| _ => Ok(()),
209	}
210}