tuwunel_service/rooms/state_cache/
update.rs1use std::collections::HashSet;
2
3use futures::StreamExt;
4use ruma::{
5 OwnedServerName, RoomId, UserId,
6 events::{
7 AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
8 RoomAccountDataEventType, StateEventType,
9 direct::DirectEvent,
10 room::{
11 create::RoomCreateEventContent,
12 member::{MembershipState, RoomMemberEventContent},
13 },
14 },
15 serde::Raw,
16};
17use tuwunel_core::{Result, implement, is_not_empty, matrix::PduCount, utils::ReadyExt, warn};
18use tuwunel_database::{Json, serialize_key};
19
20#[implement(super::Service)]
22#[tracing::instrument(
23 level = "debug",
24 skip_all,
25 fields(
26 %room_id,
27 %user_id,
28 %sender,
29 %count,
30 ?membership_event,
31 ),
32 )]
33#[expect(clippy::too_many_arguments)]
34pub async fn update_membership(
35 &self,
36 room_id: &RoomId,
37 user_id: &UserId,
38 membership_event: RoomMemberEventContent,
39 sender: &UserId,
40 last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
41 invite_via: Option<Vec<OwnedServerName>>,
42 update_joined_count: bool,
43 count: PduCount,
44) -> Result {
45 let membership = membership_event.membership;
46
47 #[expect(clippy::collapsible_if)]
52 if !self.services.globals.user_is_local(user_id) {
53 if !self.services.users.exists(user_id).await {
54 self.services
55 .users
56 .create(user_id, None, None)
57 .await?;
58 }
59 }
60
61 match &membership {
62 | MembershipState::Join => {
63 if !self.once_joined(user_id, room_id).await {
65 self.mark_as_once_joined(user_id, room_id);
67
68 if let Ok(Some(predecessor)) = self
70 .services
71 .state_accessor
72 .room_state_get_content(room_id, &StateEventType::RoomCreate, "")
73 .await
74 .map(|content: RoomCreateEventContent| content.predecessor)
75 {
76 if let Ok(tag_event) = self
78 .services
79 .account_data
80 .get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
81 .await
82 {
83 self.services
84 .account_data
85 .update(
86 Some(room_id),
87 user_id,
88 RoomAccountDataEventType::Tag,
89 &tag_event,
90 )
91 .await
92 .ok();
93 }
94
95 if let Ok(mut direct_event) = self
97 .services
98 .account_data
99 .get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
100 .await
101 {
102 let mut room_ids_updated = false;
103 for room_ids in direct_event.content.0.values_mut() {
104 if room_ids.iter().any(|r| r == &predecessor.room_id) {
105 room_ids.push(room_id.to_owned());
106 room_ids_updated = true;
107 }
108 }
109
110 if room_ids_updated {
111 self.services
112 .account_data
113 .update(
114 None,
115 user_id,
116 GlobalAccountDataEventType::Direct
117 .to_string()
118 .into(),
119 &serde_json::to_value(&direct_event)
120 .expect("to json always works"),
121 )
122 .await?;
123 }
124 }
125 }
126 }
127
128 self.mark_as_joined(user_id, room_id, count);
129 },
130 | MembershipState::Invite => {
131 if self
133 .services
134 .users
135 .user_is_ignored(sender, user_id)
136 .await
137 {
138 return Ok(());
139 }
140
141 self.mark_as_invited(user_id, room_id, count, last_state, invite_via)
142 .await;
143 },
144 | MembershipState::Leave | MembershipState::Ban => {
145 self.mark_as_left(user_id, room_id, count);
146
147 if self.services.globals.user_is_local(user_id)
148 && (self.services.config.forget_forced_upon_leave
149 || self.services.metadata.is_banned(room_id).await
150 || self.services.metadata.is_disabled(room_id).await)
151 {
152 self.forget(room_id, user_id);
153 }
154 },
155 | _ => {},
156 }
157
158 if update_joined_count {
159 self.update_joined_count(room_id).await;
160 }
161
162 Ok(())
163}
164
165#[implement(super::Service)]
166#[tracing::instrument(level = "debug", skip(self))]
167pub async fn update_joined_count(&self, room_id: &RoomId) {
168 let mut joinedcount = 0_u64;
169 let mut invitedcount = 0_u64;
170 let mut knockedcount = 0_u64;
171 let mut joined_servers = HashSet::new();
172
173 self.room_members(room_id)
174 .ready_for_each(|joined| {
175 joined_servers.insert(joined.server_name().to_owned());
176 joinedcount = joinedcount.saturating_add(1);
177 })
178 .await;
179
180 invitedcount = invitedcount.saturating_add(
181 self.room_members_invited(room_id)
182 .count()
183 .await
184 .try_into()
185 .unwrap_or(0),
186 );
187
188 knockedcount = knockedcount.saturating_add(
189 self.room_members_knocked(room_id)
190 .count()
191 .await
192 .try_into()
193 .unwrap_or(0),
194 );
195
196 self.db
197 .roomid_joinedcount
198 .raw_put(room_id, joinedcount);
199 self.db
200 .roomid_invitedcount
201 .raw_put(room_id, invitedcount);
202 self.db
203 .roomid_knockedcount
204 .raw_put(room_id, knockedcount);
205
206 self.room_servers(room_id)
207 .ready_for_each(|old_joined_server| {
208 if joined_servers.remove(old_joined_server) {
209 return;
210 }
211
212 let roomserver_id = (room_id, old_joined_server);
214 let serverroom_id = (old_joined_server, room_id);
215
216 self.db.roomserverids.del(roomserver_id);
217 self.db.serverroomids.del(serverroom_id);
218 })
219 .await;
220
221 for server in &joined_servers {
223 let roomserver_id = (room_id, server);
224 let serverroom_id = (server, room_id);
225
226 self.db.roomserverids.put_raw(roomserver_id, []);
227 self.db.serverroomids.put_raw(serverroom_id, []);
228 }
229
230 self.appservice_in_room_cache
231 .write()
232 .expect("locked")
233 .remove(room_id);
234}
235
236#[implement(super::Service)]
240#[tracing::instrument(skip(self), level = "debug")]
241pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
242 let userroom_id = (user_id, room_id);
243 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
244
245 let roomuser_id = (room_id, user_id);
246 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
247
248 self.db
249 .userroomid_joinedcount
250 .raw_aput::<8, _, _>(&userroom_id, count.into_unsigned());
251 self.db
252 .roomuserid_joinedcount
253 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
254
255 self.db
256 .userroomid_invitestate
257 .remove(&userroom_id);
258 self.db
259 .roomuserid_invitecount
260 .remove(&roomuser_id);
261
262 self.db.userroomid_leftstate.remove(&userroom_id);
263 self.db.roomuserid_leftcount.remove(&roomuser_id);
264
265 self.db
266 .userroomid_knockedstate
267 .remove(&userroom_id);
268 self.db
269 .roomuserid_knockedcount
270 .remove(&roomuser_id);
271}
272
273#[implement(super::Service)]
277#[tracing::instrument(skip(self), level = "debug")]
278pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
279 let userroom_id = (user_id, room_id);
280 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
281
282 let roomuser_id = (room_id, user_id);
283 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
284
285 let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
287
288 self.db
289 .userroomid_leftstate
290 .raw_put(&userroom_id, Json(leftstate));
291 self.db
292 .roomuserid_leftcount
293 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
294
295 self.db
296 .userroomid_joinedcount
297 .remove(&userroom_id);
298 self.db
299 .roomuserid_joinedcount
300 .remove(&roomuser_id);
301
302 self.db
303 .userroomid_invitestate
304 .remove(&userroom_id);
305 self.db
306 .roomuserid_invitecount
307 .remove(&roomuser_id);
308
309 self.db
310 .userroomid_knockedstate
311 .remove(&userroom_id);
312 self.db
313 .roomuserid_knockedcount
314 .remove(&roomuser_id);
315}
316
317#[implement(super::Service)]
321#[tracing::instrument(skip(self), level = "debug")]
322pub(crate) fn _mark_as_knocked(
323 &self,
324 user_id: &UserId,
325 room_id: &RoomId,
326 count: PduCount,
327 knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
328) {
329 let userroom_id = (user_id, room_id);
330 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
331
332 let roomuser_id = (room_id, user_id);
333 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
334
335 self.db
336 .userroomid_knockedstate
337 .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
338 self.db
339 .roomuserid_knockedcount
340 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
341
342 self.db
343 .userroomid_joinedcount
344 .remove(&userroom_id);
345 self.db
346 .roomuserid_joinedcount
347 .remove(&roomuser_id);
348
349 self.db
350 .userroomid_invitestate
351 .remove(&userroom_id);
352 self.db
353 .roomuserid_invitecount
354 .remove(&roomuser_id);
355
356 self.db.userroomid_leftstate.remove(&userroom_id);
357 self.db.roomuserid_leftcount.remove(&roomuser_id);
358}
359
360#[implement(super::Service)]
362#[tracing::instrument(skip(self), level = "debug")]
363pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
364 let userroom_id = (user_id, room_id);
365 let roomuser_id = (room_id, user_id);
366
367 self.db.userroomid_leftstate.del(userroom_id);
368 self.db.roomuserid_leftcount.del(roomuser_id);
369}
370
371#[implement(super::Service)]
372#[tracing::instrument(level = "debug", skip(self))]
373fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
374 let key = (user_id, room_id);
375 self.db.roomuseroncejoinedids.put_raw(key, []);
376}
377
378#[implement(super::Service)]
379#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
380pub(crate) async fn mark_as_invited(
381 &self,
382 user_id: &UserId,
383 room_id: &RoomId,
384 count: PduCount,
385 last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
386 invite_via: Option<Vec<OwnedServerName>>,
387) {
388 let roomuser_id = (room_id, user_id);
389 let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
390
391 let userroom_id = (user_id, room_id);
392 let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
393
394 self.db
395 .userroomid_invitestate
396 .raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
397 self.db
398 .roomuserid_invitecount
399 .raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
400
401 self.db
402 .userroomid_joinedcount
403 .remove(&userroom_id);
404 self.db
405 .roomuserid_joinedcount
406 .remove(&roomuser_id);
407
408 self.db.userroomid_leftstate.remove(&userroom_id);
409 self.db.roomuserid_leftcount.remove(&roomuser_id);
410
411 self.db
412 .userroomid_knockedstate
413 .remove(&userroom_id);
414 self.db
415 .roomuserid_knockedcount
416 .remove(&roomuser_id);
417
418 if let Some(servers) = invite_via.filter(is_not_empty!()) {
419 self.add_servers_invite_via(room_id, servers)
420 .await;
421 }
422}