Skip to main content

tuwunel_service/rooms/state_cache/
update.rs

1use std::collections::HashSet;
2
3use futures::StreamExt;
4use ruma::{
5	OwnedServerName, RoomId, UserId,
6	events::{
7		AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
8		RoomAccountDataEventType, StateEventType,
9		direct::DirectEvent,
10		room::{
11			create::RoomCreateEventContent,
12			member::{MembershipState, RoomMemberEventContent},
13		},
14	},
15	serde::Raw,
16};
17use tuwunel_core::{Result, implement, is_not_empty, matrix::PduCount, utils::ReadyExt, warn};
18use tuwunel_database::{Json, serialize_key};
19
20/// Update current membership data.
21#[implement(super::Service)]
22#[tracing::instrument(
23		level = "debug",
24		skip_all,
25		fields(
26			%room_id,
27			%user_id,
28			%sender,
29			%count,
30			?membership_event,
31		),
32	)]
33#[expect(clippy::too_many_arguments)]
34pub async fn update_membership(
35	&self,
36	room_id: &RoomId,
37	user_id: &UserId,
38	membership_event: RoomMemberEventContent,
39	sender: &UserId,
40	last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
41	invite_via: Option<Vec<OwnedServerName>>,
42	update_joined_count: bool,
43	count: PduCount,
44) -> Result {
45	let membership = membership_event.membership;
46
47	// Keep track what remote users exist by adding them as "deactivated" users
48	//
49	// TODO: use futures to update remote profiles without blocking the membership
50	// update
51	#[expect(clippy::collapsible_if)]
52	if !self.services.globals.user_is_local(user_id) {
53		if !self.services.users.exists(user_id).await {
54			self.services
55				.users
56				.create(user_id, None, None)
57				.await?;
58		}
59	}
60
61	match &membership {
62		| MembershipState::Join => {
63			// Check if the user never joined this room
64			if !self.once_joined(user_id, room_id).await {
65				// Add the user ID to the join list then
66				self.mark_as_once_joined(user_id, room_id);
67
68				// Check if the room has a predecessor
69				if let Ok(Some(predecessor)) = self
70					.services
71					.state_accessor
72					.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
73					.await
74					.map(|content: RoomCreateEventContent| content.predecessor)
75				{
76					// Copy old tags to new room
77					if let Ok(tag_event) = self
78						.services
79						.account_data
80						.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
81						.await
82					{
83						self.services
84							.account_data
85							.update(
86								Some(room_id),
87								user_id,
88								RoomAccountDataEventType::Tag,
89								&tag_event,
90							)
91							.await
92							.ok();
93					}
94
95					// Copy direct chat flag
96					if let Ok(mut direct_event) = self
97						.services
98						.account_data
99						.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
100						.await
101					{
102						let mut room_ids_updated = false;
103						for room_ids in direct_event.content.0.values_mut() {
104							if room_ids.iter().any(|r| r == &predecessor.room_id) {
105								room_ids.push(room_id.to_owned());
106								room_ids_updated = true;
107							}
108						}
109
110						if room_ids_updated {
111							self.services
112								.account_data
113								.update(
114									None,
115									user_id,
116									GlobalAccountDataEventType::Direct
117										.to_string()
118										.into(),
119									&serde_json::to_value(&direct_event)
120										.expect("to json always works"),
121								)
122								.await?;
123						}
124					}
125				}
126			}
127
128			self.mark_as_joined(user_id, room_id, count);
129		},
130		| MembershipState::Invite => {
131			// We want to know if the sender is ignored by the receiver
132			if self
133				.services
134				.users
135				.user_is_ignored(sender, user_id)
136				.await
137			{
138				return Ok(());
139			}
140
141			self.mark_as_invited(user_id, room_id, count, last_state, invite_via)
142				.await;
143		},
144		| MembershipState::Leave | MembershipState::Ban => {
145			self.mark_as_left(user_id, room_id, count);
146
147			if self.services.globals.user_is_local(user_id)
148				&& (self.services.config.forget_forced_upon_leave
149					|| self.services.metadata.is_banned(room_id).await
150					|| self.services.metadata.is_disabled(room_id).await)
151			{
152				self.forget(room_id, user_id);
153			}
154		},
155		| MembershipState::Knock => {
156			self.mark_as_knocked(user_id, room_id, count, last_state);
157		},
158		| _ => {},
159	}
160
161	if update_joined_count {
162		self.update_joined_count(room_id).await;
163	}
164
165	Ok(())
166}
167
168#[implement(super::Service)]
169#[tracing::instrument(level = "debug", skip(self))]
170pub async fn update_joined_count(&self, room_id: &RoomId) {
171	let mut joinedcount = 0_u64;
172	let mut invitedcount = 0_u64;
173	let mut knockedcount = 0_u64;
174	let mut joined_servers = HashSet::new();
175
176	self.room_members(room_id)
177		.ready_for_each(|joined| {
178			joined_servers.insert(joined.server_name().to_owned());
179			joinedcount = joinedcount.saturating_add(1);
180		})
181		.await;
182
183	invitedcount = invitedcount.saturating_add(
184		self.room_members_invited(room_id)
185			.count()
186			.await
187			.try_into()
188			.unwrap_or(0),
189	);
190
191	knockedcount = knockedcount.saturating_add(
192		self.room_members_knocked(room_id)
193			.count()
194			.await
195			.try_into()
196			.unwrap_or(0),
197	);
198
199	self.db
200		.roomid_joinedcount
201		.raw_put(room_id, joinedcount);
202	self.db
203		.roomid_invitedcount
204		.raw_put(room_id, invitedcount);
205	self.db
206		.roomid_knockedcount
207		.raw_put(room_id, knockedcount);
208
209	self.room_servers(room_id)
210		.ready_for_each(|old_joined_server| {
211			if joined_servers.remove(old_joined_server) {
212				return;
213			}
214
215			// Server not in room anymore
216			let roomserver_id = (room_id, old_joined_server);
217			let serverroom_id = (old_joined_server, room_id);
218
219			self.db.roomserverids.del(roomserver_id);
220			self.db.serverroomids.del(serverroom_id);
221		})
222		.await;
223
224	// Now only new servers are in joined_servers anymore
225	for server in &joined_servers {
226		let roomserver_id = (room_id, server);
227		let serverroom_id = (server, room_id);
228
229		self.db.roomserverids.put_raw(roomserver_id, []);
230		self.db.serverroomids.put_raw(serverroom_id, []);
231	}
232
233	self.appservice_in_room_cache
234		.write()
235		.expect("locked")
236		.remove(room_id);
237}
238
239/// Direct DB function to directly mark a user as joined. It is not
240/// recommended to use this directly. You most likely should use
241/// `update_membership` instead
242#[implement(super::Service)]
243#[tracing::instrument(skip(self), level = "debug")]
244pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
245	let userroom_id = (user_id, room_id);
246	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
247
248	let roomuser_id = (room_id, user_id);
249	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
250
251	self.db
252		.userroomid_joinedcount
253		.raw_aput::<8, _, _>(&userroom_id, count.into_unsigned());
254	self.db
255		.roomuserid_joinedcount
256		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
257
258	self.db
259		.userroomid_invitestate
260		.remove(&userroom_id);
261	self.db
262		.roomuserid_invitecount
263		.remove(&roomuser_id);
264
265	self.db.userroomid_leftstate.remove(&userroom_id);
266	self.db.roomuserid_leftcount.remove(&roomuser_id);
267
268	self.db
269		.userroomid_knockedstate
270		.remove(&userroom_id);
271	self.db
272		.roomuserid_knockedcount
273		.remove(&roomuser_id);
274}
275
276/// Direct DB function to directly mark a user as left. It is not
277/// recommended to use this directly. You most likely should use
278/// `update_membership` instead
279#[implement(super::Service)]
280#[tracing::instrument(skip(self), level = "debug")]
281pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
282	let userroom_id = (user_id, room_id);
283	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
284
285	let roomuser_id = (room_id, user_id);
286	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
287
288	// (timo) TODO
289	let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
290
291	self.db
292		.userroomid_leftstate
293		.raw_put(&userroom_id, Json(leftstate));
294	self.db
295		.roomuserid_leftcount
296		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
297
298	self.db
299		.userroomid_joinedcount
300		.remove(&userroom_id);
301	self.db
302		.roomuserid_joinedcount
303		.remove(&roomuser_id);
304
305	self.db
306		.userroomid_invitestate
307		.remove(&userroom_id);
308	self.db
309		.roomuserid_invitecount
310		.remove(&roomuser_id);
311
312	self.db
313		.userroomid_knockedstate
314		.remove(&userroom_id);
315	self.db
316		.roomuserid_knockedcount
317		.remove(&roomuser_id);
318}
319
320/// Direct DB function to directly mark a user as knocked. It is not
321/// recommended to use this directly. You most likely should use
322/// `update_membership` instead
323#[implement(super::Service)]
324#[tracing::instrument(skip(self), level = "debug")]
325pub(crate) fn mark_as_knocked(
326	&self,
327	user_id: &UserId,
328	room_id: &RoomId,
329	count: PduCount,
330	knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
331) {
332	let userroom_id = (user_id, room_id);
333	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
334
335	let roomuser_id = (room_id, user_id);
336	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
337
338	self.db
339		.userroomid_knockedstate
340		.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
341	self.db
342		.roomuserid_knockedcount
343		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
344
345	self.db
346		.userroomid_joinedcount
347		.remove(&userroom_id);
348	self.db
349		.roomuserid_joinedcount
350		.remove(&roomuser_id);
351
352	self.db
353		.userroomid_invitestate
354		.remove(&userroom_id);
355	self.db
356		.roomuserid_invitecount
357		.remove(&roomuser_id);
358
359	self.db.userroomid_leftstate.remove(&userroom_id);
360	self.db.roomuserid_leftcount.remove(&roomuser_id);
361}
362
363/// Makes a user forget a room.
364#[implement(super::Service)]
365#[tracing::instrument(skip(self), level = "debug")]
366pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
367	let userroom_id = (user_id, room_id);
368	let roomuser_id = (room_id, user_id);
369
370	self.db.userroomid_leftstate.del(userroom_id);
371	self.db.roomuserid_leftcount.del(roomuser_id);
372}
373
374#[implement(super::Service)]
375#[tracing::instrument(level = "debug", skip(self))]
376fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
377	let key = (user_id, room_id);
378	self.db.roomuseroncejoinedids.put_raw(key, []);
379}
380
381#[implement(super::Service)]
382#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
383pub(crate) async fn mark_as_invited(
384	&self,
385	user_id: &UserId,
386	room_id: &RoomId,
387	count: PduCount,
388	last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
389	invite_via: Option<Vec<OwnedServerName>>,
390) {
391	let roomuser_id = (room_id, user_id);
392	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
393
394	let userroom_id = (user_id, room_id);
395	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
396
397	self.db
398		.userroomid_invitestate
399		.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
400	self.db
401		.roomuserid_invitecount
402		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
403
404	self.db
405		.userroomid_joinedcount
406		.remove(&userroom_id);
407	self.db
408		.roomuserid_joinedcount
409		.remove(&roomuser_id);
410
411	self.db.userroomid_leftstate.remove(&userroom_id);
412	self.db.roomuserid_leftcount.remove(&roomuser_id);
413
414	self.db
415		.userroomid_knockedstate
416		.remove(&userroom_id);
417	self.db
418		.roomuserid_knockedcount
419		.remove(&roomuser_id);
420
421	if let Some(servers) = invite_via.filter(is_not_empty!()) {
422		self.add_servers_invite_via(room_id, servers)
423			.await;
424	}
425}