tuwunel_service/membership/
invite.rs1use futures::FutureExt;
2use ruma::{
3 OwnedServerName, RoomId, UserId,
4 api::federation::membership::create_invite,
5 events::room::member::{MembershipState, RoomMemberEventContent},
6};
7use tuwunel_core::{
8 Err, Result, at, err, implement, matrix::event::gen_event_id_canonical_json, pdu::PduBuilder,
9};
10
11use super::Service;
12
13#[implement(Service)]
14#[tracing::instrument(
15 level = "debug",
16 skip_all,
17 fields(%sender_user, %room_id, %user_id)
18)]
19pub async fn invite(
20 &self,
21 sender_user: &UserId,
22 user_id: &UserId,
23 room_id: &RoomId,
24 reason: Option<&String>,
25 is_direct: bool,
26) -> Result {
27 if self.services.globals.user_is_local(user_id) {
28 self.local_invite(sender_user, user_id, room_id, reason, is_direct)
29 .boxed()
30 .await?;
31 } else {
32 self.remote_invite(sender_user, user_id, room_id, reason, is_direct)
33 .boxed()
34 .await?;
35 }
36
37 Ok(())
38}
39
40#[implement(Service)]
41#[tracing::instrument(name = "remote", level = "debug", skip_all)]
42async fn remote_invite(
43 &self,
44 sender_user: &UserId,
45 user_id: &UserId,
46 room_id: &RoomId,
47 reason: Option<&String>,
48 is_direct: bool,
49) -> Result {
50 let (pdu, pdu_json, invite_room_state) = {
51 let state_lock = self.services.state.mutex.lock(room_id).await;
52
53 let content = RoomMemberEventContent {
54 displayname: self
55 .services
56 .users
57 .displayname(user_id)
58 .await
59 .ok(),
60 avatar_url: self.services.users.avatar_url(user_id).await.ok(),
61 is_direct: Some(is_direct),
62 reason: reason.cloned(),
63 ..RoomMemberEventContent::new(MembershipState::Invite)
64 };
65
66 let (pdu, pdu_json) = self
67 .services
68 .timeline
69 .create_hash_and_sign_event(
70 PduBuilder::state(user_id.to_string(), &content),
71 sender_user,
72 room_id,
73 &state_lock,
74 )
75 .await?;
76
77 let invite_room_state = self.services.state.summary_stripped(&pdu).await;
78
79 drop(state_lock);
80
81 (pdu, pdu_json, invite_room_state)
82 };
83
84 let room_version_id = self
85 .services
86 .state
87 .get_room_version(room_id)
88 .await?;
89
90 let response = self
91 .services
92 .federation
93 .execute(user_id.server_name(), create_invite::v2::Request {
94 room_id: room_id.to_owned(),
95 event_id: (*pdu.event_id).to_owned(),
96 room_version: room_version_id.clone(),
97 event: self
98 .services
99 .federation
100 .format_pdu_into(pdu_json.clone(), Some(&room_version_id))
101 .await,
102 invite_room_state: invite_room_state
103 .into_iter()
104 .map(Into::into)
105 .collect(),
106 via: self
107 .services
108 .state_cache
109 .servers_route_via(room_id)
110 .await
111 .ok(),
112 })
113 .await?;
114
115 let (event_id, value) = gen_event_id_canonical_json(&response.event, &room_version_id)
118 .map_err(|e| {
119 err!(Request(BadJson(warn!("Could not convert event to canonical JSON: {e}"))))
120 })?;
121
122 if pdu.event_id != event_id {
123 return Err!(Request(BadJson(warn!(
124 %pdu.event_id, %event_id,
125 "Server {} sent event with wrong event ID",
126 user_id.server_name()
127 ))));
128 }
129
130 let origin: OwnedServerName = serde_json::from_value(serde_json::to_value(
131 value
132 .get("origin")
133 .ok_or_else(|| err!(Request(BadJson("Event missing origin field."))))?,
134 )?)
135 .map_err(|e| {
136 err!(Request(BadJson(warn!("Origin field in event is not a valid server name: {e}"))))
137 })?;
138
139 let pdu_id = self
140 .services
141 .event_handler
142 .handle_incoming_pdu(&origin, room_id, &event_id, value, true)
143 .await?
144 .map(at!(0))
145 .ok_or_else(|| {
146 err!(Request(InvalidParam("Could not accept incoming PDU as timeline event.")))
147 })?;
148
149 self.services
150 .sending
151 .send_pdu_room(room_id, &pdu_id)
152 .await?;
153
154 Ok(())
155}
156
157#[implement(Service)]
158#[tracing::instrument(name = "local", level = "debug", skip_all)]
159async fn local_invite(
160 &self,
161 sender_user: &UserId,
162 user_id: &UserId,
163 room_id: &RoomId,
164 reason: Option<&String>,
165 is_direct: bool,
166) -> Result {
167 if self.services.users.invites_blocked(user_id).await {
168 return Err!(Request(InviteBlocked("{user_id} has blocked invites.")));
169 }
170
171 if !self
172 .services
173 .state_cache
174 .is_joined(sender_user, room_id)
175 .await
176 {
177 return Err!(Request(Forbidden(
178 "You must be joined in the room you are trying to invite from."
179 )));
180 }
181
182 let state_lock = self.services.state.mutex.lock(room_id).await;
183
184 let content = RoomMemberEventContent {
185 displayname: self
186 .services
187 .users
188 .displayname(user_id)
189 .await
190 .ok(),
191 avatar_url: self.services.users.avatar_url(user_id).await.ok(),
192 blurhash: self.services.users.blurhash(user_id).await.ok(),
193 is_direct: Some(is_direct),
194 reason: reason.cloned(),
195 ..RoomMemberEventContent::new(MembershipState::Invite)
196 };
197
198 self.services
199 .timeline
200 .build_and_append_pdu(
201 PduBuilder::state(user_id.to_string(), &content),
202 sender_user,
203 room_id,
204 &state_lock,
205 )
206 .await?;
207
208 drop(state_lock);
209
210 Ok(())
211}