Skip to main content

tuwunel_service/membership/
invite.rs

1use futures::FutureExt;
2use ruma::{
3	OwnedServerName, RoomId, UserId,
4	api::federation::membership::create_invite,
5	events::room::member::{MembershipState, RoomMemberEventContent},
6};
7use tuwunel_core::{
8	Err, Result, at, err, implement, matrix::event::gen_event_id_canonical_json, pdu::PduBuilder,
9};
10
11use super::Service;
12
13#[implement(Service)]
14#[tracing::instrument(
15    level = "debug",
16    skip_all,
17    fields(%sender_user, %room_id, %user_id)
18)]
19pub async fn invite(
20	&self,
21	sender_user: &UserId,
22	user_id: &UserId,
23	room_id: &RoomId,
24	reason: Option<&String>,
25	is_direct: bool,
26) -> Result {
27	if self.services.globals.user_is_local(user_id) {
28		self.local_invite(sender_user, user_id, room_id, reason, is_direct)
29			.boxed()
30			.await?;
31	} else {
32		self.remote_invite(sender_user, user_id, room_id, reason, is_direct)
33			.boxed()
34			.await?;
35	}
36
37	Ok(())
38}
39
40#[implement(Service)]
41#[tracing::instrument(name = "remote", level = "debug", skip_all)]
42async fn remote_invite(
43	&self,
44	sender_user: &UserId,
45	user_id: &UserId,
46	room_id: &RoomId,
47	reason: Option<&String>,
48	is_direct: bool,
49) -> Result {
50	let (pdu, pdu_json, invite_room_state) = {
51		let state_lock = self.services.state.mutex.lock(room_id).await;
52
53		let content = RoomMemberEventContent {
54			displayname: self
55				.services
56				.users
57				.displayname(user_id)
58				.await
59				.ok(),
60			avatar_url: self.services.users.avatar_url(user_id).await.ok(),
61			is_direct: Some(is_direct),
62			reason: reason.cloned(),
63			..RoomMemberEventContent::new(MembershipState::Invite)
64		};
65
66		let (pdu, pdu_json) = self
67			.services
68			.timeline
69			.create_hash_and_sign_event(
70				PduBuilder::state(user_id.to_string(), &content),
71				sender_user,
72				room_id,
73				&state_lock,
74			)
75			.await?;
76
77		let invite_room_state = self.services.state.summary_stripped(&pdu).await;
78
79		drop(state_lock);
80
81		(pdu, pdu_json, invite_room_state)
82	};
83
84	let room_version_id = self
85		.services
86		.state
87		.get_room_version(room_id)
88		.await?;
89
90	let response = self
91		.services
92		.federation
93		.execute(user_id.server_name(), create_invite::v2::Request {
94			room_id: room_id.to_owned(),
95			event_id: (*pdu.event_id).to_owned(),
96			room_version: room_version_id.clone(),
97			event: self
98				.services
99				.federation
100				.format_pdu_into(pdu_json.clone(), Some(&room_version_id))
101				.await,
102			invite_room_state: invite_room_state
103				.into_iter()
104				.map(Into::into)
105				.collect(),
106			via: self
107				.services
108				.state_cache
109				.servers_route_via(room_id)
110				.await
111				.ok(),
112		})
113		.await?;
114
115	// We do not add the event_id field to the pdu here because of signature and
116	// hashes checks
117	let (event_id, value) = gen_event_id_canonical_json(&response.event, &room_version_id)
118		.map_err(|e| {
119			err!(Request(BadJson(warn!("Could not convert event to canonical JSON: {e}"))))
120		})?;
121
122	if pdu.event_id != event_id {
123		return Err!(Request(BadJson(warn!(
124			%pdu.event_id, %event_id,
125			"Server {} sent event with wrong event ID",
126			user_id.server_name()
127		))));
128	}
129
130	let origin: OwnedServerName = serde_json::from_value(serde_json::to_value(
131		value
132			.get("origin")
133			.ok_or_else(|| err!(Request(BadJson("Event missing origin field."))))?,
134	)?)
135	.map_err(|e| {
136		err!(Request(BadJson(warn!("Origin field in event is not a valid server name: {e}"))))
137	})?;
138
139	let pdu_id = self
140		.services
141		.event_handler
142		.handle_incoming_pdu(&origin, room_id, &event_id, value, true)
143		.await?
144		.map(at!(0))
145		.ok_or_else(|| {
146			err!(Request(InvalidParam("Could not accept incoming PDU as timeline event.")))
147		})?;
148
149	self.services
150		.sending
151		.send_pdu_room(room_id, &pdu_id)
152		.await?;
153
154	Ok(())
155}
156
157#[implement(Service)]
158#[tracing::instrument(name = "local", level = "debug", skip_all)]
159async fn local_invite(
160	&self,
161	sender_user: &UserId,
162	user_id: &UserId,
163	room_id: &RoomId,
164	reason: Option<&String>,
165	is_direct: bool,
166) -> Result {
167	if self.services.users.invites_blocked(user_id).await {
168		return Err!(Request(InviteBlocked("{user_id} has blocked invites.")));
169	}
170
171	if !self
172		.services
173		.state_cache
174		.is_joined(sender_user, room_id)
175		.await
176	{
177		return Err!(Request(Forbidden(
178			"You must be joined in the room you are trying to invite from."
179		)));
180	}
181
182	let state_lock = self.services.state.mutex.lock(room_id).await;
183
184	let content = RoomMemberEventContent {
185		displayname: self
186			.services
187			.users
188			.displayname(user_id)
189			.await
190			.ok(),
191		avatar_url: self.services.users.avatar_url(user_id).await.ok(),
192		blurhash: self.services.users.blurhash(user_id).await.ok(),
193		is_direct: Some(is_direct),
194		reason: reason.cloned(),
195		..RoomMemberEventContent::new(MembershipState::Invite)
196	};
197
198	self.services
199		.timeline
200		.build_and_append_pdu(
201			PduBuilder::state(user_id.to_string(), &content),
202			sender_user,
203			room_id,
204			&state_lock,
205		)
206		.await?;
207
208	drop(state_lock);
209
210	Ok(())
211}