tuwunel_service/rooms/state_cache/
update.rs1use std::collections::HashSet;
2
3use futures::StreamExt;
4use ruma::{
5 OwnedServerName, RoomId, UserId,
6 events::{
7 AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
8 RoomAccountDataEventType, StateEventType,
9 direct::DirectEvent,
10 room::{
11 create::RoomCreateEventContent,
12 member::{MembershipState, RoomMemberEventContent},
13 },
14 },
15 serde::Raw,
16};
17use tuwunel_core::{Result, implement, is_not_empty, matrix::PduCount, utils::ReadyExt, warn};
18use tuwunel_database::{Json, serialize_key};
19
20#[implement(super::Service)]
22#[tracing::instrument(
23 level = "debug",
24 skip_all,
25 fields(
26 %room_id,
27 %user_id,
28 %sender,
29 %count,
30 ?membership_event,
31 ),
32 )]
33#[expect(clippy::too_many_arguments)]
34pub async fn update_membership(
35 &self,
36 room_id: &RoomId,
37 user_id: &UserId,
38 membership_event: RoomMemberEventContent,
39 sender: &UserId,
40 last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
41 invite_via: Option<Vec<OwnedServerName>>,
42 update_joined_count: bool,
43 count: PduCount,
44) -> Result {
45 let membership = membership_event.membership;
46
47 #[expect(clippy::collapsible_if)]
52 if !self.services.globals.user_is_local(user_id) {
53 if !self.services.users.exists(user_id).await {
54 self.services
55 .users
56 .create(user_id, None, None)
57 .await?;
58 }
59 }
60
61 match &membership {
62 | MembershipState::Join => {
63 if !self.once_joined(user_id, room_id).await {
65 self.mark_as_once_joined(user_id, room_id);
67
68 if let Ok(Some(predecessor)) = self
70 .services
71 .state_accessor
72 .room_state_get_content(room_id, &StateEventType::RoomCreate, "")
73 .await
74 .map(|content: RoomCreateEventContent| content.predecessor)
75 {
76 if let Ok(tag_event) = self
78 .services
79 .account_data
80 .get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
81 .await
82 {
83 self.services
84 .account_data
85 .update(
86 Some(room_id),
87 user_id,
88 RoomAccountDataEventType::Tag,
89 &tag_event,
90 )
91 .await
92 .ok();
93 }
94
95 if let Ok(mut direct_event) = self
97 .services
98 .account_data
99 .get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
100 .await
101 {
102 let mut room_ids_updated = false;
103 for room_ids in direct_event.content.0.values_mut() {
104 if room_ids.iter().any(|r| r == &predecessor.room_id) {
105 room_ids.push(room_id.to_owned());
106 room_ids_updated = true;
107 }
108 }
109
110 if room_ids_updated {
111 self.services
112 .account_data
113 .update(
114 None,
115 user_id,
116 GlobalAccountDataEventType::Direct
117 .to_string()
118 .into(),
119 &serde_json::to_value(&direct_event)
120 .expect("to json always works"),
121 )
122 .await?;
123 }
124 }
125 }
126 }
127
128 self.mark_as_joined(user_id, room_id, count);
129 },
130 | MembershipState::Invite => {
131 if self
133 .services
134 .users
135 .user_is_ignored(sender, user_id)
136 .await
137 {
138 return Ok(());
139 }
140
141 self.mark_as_invited(user_id, room_id, count, last_state, invite_via)
142 .await;
143 },
144 | MembershipState::Leave | MembershipState::Ban => {
145 self.mark_as_left(user_id, room_id, count);
146
147 if self.services.globals.user_is_local(user_id)
148 && (self.services.config.forget_forced_upon_leave
149 || self.services.metadata.is_banned(room_id).await
150 || self.services.metadata.is_disabled(room_id).await)
151 {
152 self.forget(room_id, user_id);
153 }
154 },
155 | MembershipState::Knock => {
156 self.mark_as_knocked(user_id, room_id, count, last_state);
157 },
158 | _ => {},
159 }
160
161 if update_joined_count {
162 self.update_joined_count(room_id).await;
163 }
164
165 Ok(())
166}
167
168#[implement(super::Service)]
169#[tracing::instrument(level = "debug", skip(self))]
170pub async fn update_joined_count(&self, room_id: &RoomId) {
171 let mut joinedcount = 0_u64;
172 let mut invitedcount = 0_u64;
173 let mut knockedcount = 0_u64;
174 let mut joined_servers = HashSet::new();
175
176 self.room_members(room_id)
177 .ready_for_each(|joined| {
178 joined_servers.insert(joined.server_name().to_owned());
179 joinedcount = joinedcount.saturating_add(1);
180 })
181 .await;
182
183 invitedcount = invitedcount.saturating_add(
184 self.room_members_invited(room_id)
185 .count()
186 .await
187 .try_into()
188 .unwrap_or(0),
189 );
190
191 knockedcount = knockedcount.saturating_add(
192 self.room_members_knocked(room_id)
193 .count()
194 .await
195 .try_into()
196 .unwrap_or(0),
197 );
198
199 self.db
200 .roomid_joinedcount
201 .raw_put(room_id, joinedcount);
202 self.db
203 .roomid_invitedcount
204 .raw_put(room_id, invitedcount);
205 self.db
206 .roomid_knockedcount
207 .raw_put(room_id, knockedcount);
208
209 self.room_servers(room_id)
210 .ready_for_each(|old_joined_server| {
211 if joined_servers.remove(old_joined_server) {
212 return;
213 }
214
215 let roomserver_id = (room_id, old_joined_server);
217 let serverroom_id = (old_joined_server, room_id);
218
219 self.db.roomserverids.del(roomserver_id);
220 self.db.serverroomids.del(serverroom_id);
221 })
222 .await;
223
224 for server in &joined_servers {
226 let roomserver_id = (room_id, server);
227 let serverroom_id = (server, room_id);
228
229 self.db.roomserverids.put_raw(roomserver_id, []);
230 self.db.serverroomids.put_raw(serverroom_id, []);
231 }
232
233 self.appservice_in_room_cache
234 .write()
235 .expect("locked")
236 .remove(room_id);
237}
238
239#[implement(super::Service)]
243#[tracing::instrument(skip(self), level = "debug")]
244pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
245 let userroom_id = (user_id, room_id);
246 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
247
248 let roomuser_id = (room_id, user_id);
249 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
250
251 self.db
252 .userroomid_joinedcount
253 .raw_aput::<8, _, _>(&userroom_id, count.into_unsigned());
254 self.db
255 .roomuserid_joinedcount
256 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
257
258 self.db
259 .userroomid_invitestate
260 .remove(&userroom_id);
261 self.db
262 .roomuserid_invitecount
263 .remove(&roomuser_id);
264
265 self.db.userroomid_leftstate.remove(&userroom_id);
266 self.db.roomuserid_leftcount.remove(&roomuser_id);
267
268 self.db
269 .userroomid_knockedstate
270 .remove(&userroom_id);
271 self.db
272 .roomuserid_knockedcount
273 .remove(&roomuser_id);
274}
275
276#[implement(super::Service)]
280#[tracing::instrument(skip(self), level = "debug")]
281pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
282 let userroom_id = (user_id, room_id);
283 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
284
285 let roomuser_id = (room_id, user_id);
286 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
287
288 let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
290
291 self.db
292 .userroomid_leftstate
293 .raw_put(&userroom_id, Json(leftstate));
294 self.db
295 .roomuserid_leftcount
296 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
297
298 self.db
299 .userroomid_joinedcount
300 .remove(&userroom_id);
301 self.db
302 .roomuserid_joinedcount
303 .remove(&roomuser_id);
304
305 self.db
306 .userroomid_invitestate
307 .remove(&userroom_id);
308 self.db
309 .roomuserid_invitecount
310 .remove(&roomuser_id);
311
312 self.db
313 .userroomid_knockedstate
314 .remove(&userroom_id);
315 self.db
316 .roomuserid_knockedcount
317 .remove(&roomuser_id);
318}
319
320#[implement(super::Service)]
324#[tracing::instrument(skip(self), level = "debug")]
325pub(crate) fn mark_as_knocked(
326 &self,
327 user_id: &UserId,
328 room_id: &RoomId,
329 count: PduCount,
330 knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
331) {
332 let userroom_id = (user_id, room_id);
333 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
334
335 let roomuser_id = (room_id, user_id);
336 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
337
338 self.db
339 .userroomid_knockedstate
340 .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
341 self.db
342 .roomuserid_knockedcount
343 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
344
345 self.db
346 .userroomid_joinedcount
347 .remove(&userroom_id);
348 self.db
349 .roomuserid_joinedcount
350 .remove(&roomuser_id);
351
352 self.db
353 .userroomid_invitestate
354 .remove(&userroom_id);
355 self.db
356 .roomuserid_invitecount
357 .remove(&roomuser_id);
358
359 self.db.userroomid_leftstate.remove(&userroom_id);
360 self.db.roomuserid_leftcount.remove(&roomuser_id);
361}
362
363#[implement(super::Service)]
365#[tracing::instrument(skip(self), level = "debug")]
366pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
367 let userroom_id = (user_id, room_id);
368 let roomuser_id = (room_id, user_id);
369
370 self.db.userroomid_leftstate.del(userroom_id);
371 self.db.roomuserid_leftcount.del(roomuser_id);
372}
373
374#[implement(super::Service)]
375#[tracing::instrument(level = "debug", skip(self))]
376fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
377 let key = (user_id, room_id);
378 self.db.roomuseroncejoinedids.put_raw(key, []);
379}
380
381#[implement(super::Service)]
382#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
383pub(crate) async fn mark_as_invited(
384 &self,
385 user_id: &UserId,
386 room_id: &RoomId,
387 count: PduCount,
388 last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
389 invite_via: Option<Vec<OwnedServerName>>,
390) {
391 let roomuser_id = (room_id, user_id);
392 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
393
394 let userroom_id = (user_id, room_id);
395 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
396
397 self.db
398 .userroomid_invitestate
399 .raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
400 self.db
401 .roomuserid_invitecount
402 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
403
404 self.db
405 .userroomid_joinedcount
406 .remove(&userroom_id);
407 self.db
408 .roomuserid_joinedcount
409 .remove(&roomuser_id);
410
411 self.db.userroomid_leftstate.remove(&userroom_id);
412 self.db.roomuserid_leftcount.remove(&roomuser_id);
413
414 self.db
415 .userroomid_knockedstate
416 .remove(&userroom_id);
417 self.db
418 .roomuserid_knockedcount
419 .remove(&roomuser_id);
420
421 if let Some(servers) = invite_via.filter(is_not_empty!()) {
422 self.add_servers_invite_via(room_id, servers)
423 .await;
424 }
425}