Skip to main content

tuwunel_service/rooms/state_cache/
update.rs

1use std::collections::HashSet;
2
3use futures::StreamExt;
4use ruma::{
5	OwnedServerName, RoomId, UserId,
6	events::{
7		AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
8		RoomAccountDataEventType, StateEventType,
9		direct::DirectEvent,
10		room::{
11			create::RoomCreateEventContent,
12			member::{MembershipState, RoomMemberEventContent},
13		},
14	},
15	serde::Raw,
16};
17use tuwunel_core::{Result, implement, is_not_empty, matrix::PduCount, utils::ReadyExt, warn};
18use tuwunel_database::{Json, serialize_key};
19
20/// Update current membership data.
21#[implement(super::Service)]
22#[tracing::instrument(
23		level = "debug",
24		skip_all,
25		fields(
26			%room_id,
27			%user_id,
28			%sender,
29			%count,
30			?membership_event,
31		),
32	)]
33#[expect(clippy::too_many_arguments)]
34pub async fn update_membership(
35	&self,
36	room_id: &RoomId,
37	user_id: &UserId,
38	membership_event: RoomMemberEventContent,
39	sender: &UserId,
40	last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
41	invite_via: Option<Vec<OwnedServerName>>,
42	update_joined_count: bool,
43	count: PduCount,
44) -> Result {
45	let membership = membership_event.membership;
46
47	// Keep track what remote users exist by adding them as "deactivated" users
48	//
49	// TODO: use futures to update remote profiles without blocking the membership
50	// update
51	#[expect(clippy::collapsible_if)]
52	if !self.services.globals.user_is_local(user_id) {
53		if !self.services.users.exists(user_id).await {
54			self.services
55				.users
56				.create(user_id, None, None)
57				.await?;
58		}
59	}
60
61	match &membership {
62		| MembershipState::Join => {
63			// Check if the user never joined this room
64			if !self.once_joined(user_id, room_id).await {
65				// Add the user ID to the join list then
66				self.mark_as_once_joined(user_id, room_id);
67
68				// Check if the room has a predecessor
69				if let Ok(Some(predecessor)) = self
70					.services
71					.state_accessor
72					.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
73					.await
74					.map(|content: RoomCreateEventContent| content.predecessor)
75				{
76					// Copy old tags to new room
77					if let Ok(tag_event) = self
78						.services
79						.account_data
80						.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
81						.await
82					{
83						self.services
84							.account_data
85							.update(
86								Some(room_id),
87								user_id,
88								RoomAccountDataEventType::Tag,
89								&tag_event,
90							)
91							.await
92							.ok();
93					}
94
95					// Copy direct chat flag
96					if let Ok(mut direct_event) = self
97						.services
98						.account_data
99						.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
100						.await
101					{
102						let mut room_ids_updated = false;
103						for room_ids in direct_event.content.0.values_mut() {
104							if room_ids.iter().any(|r| r == &predecessor.room_id) {
105								room_ids.push(room_id.to_owned());
106								room_ids_updated = true;
107							}
108						}
109
110						if room_ids_updated {
111							self.services
112								.account_data
113								.update(
114									None,
115									user_id,
116									GlobalAccountDataEventType::Direct
117										.to_string()
118										.into(),
119									&serde_json::to_value(&direct_event)
120										.expect("to json always works"),
121								)
122								.await?;
123						}
124					}
125				}
126			}
127
128			self.mark_as_joined(user_id, room_id, count);
129		},
130		| MembershipState::Invite => {
131			// We want to know if the sender is ignored by the receiver
132			if self
133				.services
134				.users
135				.user_is_ignored(sender, user_id)
136				.await
137			{
138				return Ok(());
139			}
140
141			self.mark_as_invited(user_id, room_id, count, last_state, invite_via)
142				.await;
143		},
144		| MembershipState::Leave | MembershipState::Ban => {
145			self.mark_as_left(user_id, room_id, count);
146
147			if self.services.globals.user_is_local(user_id)
148				&& (self.services.config.forget_forced_upon_leave
149					|| self.services.metadata.is_banned(room_id).await
150					|| self.services.metadata.is_disabled(room_id).await)
151			{
152				self.forget(room_id, user_id);
153			}
154		},
155		| _ => {},
156	}
157
158	if update_joined_count {
159		self.update_joined_count(room_id).await;
160	}
161
162	Ok(())
163}
164
165#[implement(super::Service)]
166#[tracing::instrument(level = "debug", skip(self))]
167pub async fn update_joined_count(&self, room_id: &RoomId) {
168	let mut joinedcount = 0_u64;
169	let mut invitedcount = 0_u64;
170	let mut knockedcount = 0_u64;
171	let mut joined_servers = HashSet::new();
172
173	self.room_members(room_id)
174		.ready_for_each(|joined| {
175			joined_servers.insert(joined.server_name().to_owned());
176			joinedcount = joinedcount.saturating_add(1);
177		})
178		.await;
179
180	invitedcount = invitedcount.saturating_add(
181		self.room_members_invited(room_id)
182			.count()
183			.await
184			.try_into()
185			.unwrap_or(0),
186	);
187
188	knockedcount = knockedcount.saturating_add(
189		self.room_members_knocked(room_id)
190			.count()
191			.await
192			.try_into()
193			.unwrap_or(0),
194	);
195
196	self.db
197		.roomid_joinedcount
198		.raw_put(room_id, joinedcount);
199	self.db
200		.roomid_invitedcount
201		.raw_put(room_id, invitedcount);
202	self.db
203		.roomid_knockedcount
204		.raw_put(room_id, knockedcount);
205
206	self.room_servers(room_id)
207		.ready_for_each(|old_joined_server| {
208			if joined_servers.remove(old_joined_server) {
209				return;
210			}
211
212			// Server not in room anymore
213			let roomserver_id = (room_id, old_joined_server);
214			let serverroom_id = (old_joined_server, room_id);
215
216			self.db.roomserverids.del(roomserver_id);
217			self.db.serverroomids.del(serverroom_id);
218		})
219		.await;
220
221	// Now only new servers are in joined_servers anymore
222	for server in &joined_servers {
223		let roomserver_id = (room_id, server);
224		let serverroom_id = (server, room_id);
225
226		self.db.roomserverids.put_raw(roomserver_id, []);
227		self.db.serverroomids.put_raw(serverroom_id, []);
228	}
229
230	self.appservice_in_room_cache
231		.write()
232		.expect("locked")
233		.remove(room_id);
234}
235
236/// Direct DB function to directly mark a user as joined. It is not
237/// recommended to use this directly. You most likely should use
238/// `update_membership` instead
239#[implement(super::Service)]
240#[tracing::instrument(skip(self), level = "debug")]
241pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
242	let userroom_id = (user_id, room_id);
243	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
244
245	let roomuser_id = (room_id, user_id);
246	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
247
248	self.db
249		.userroomid_joinedcount
250		.raw_aput::<8, _, _>(&userroom_id, count.into_unsigned());
251	self.db
252		.roomuserid_joinedcount
253		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
254
255	self.db
256		.userroomid_invitestate
257		.remove(&userroom_id);
258	self.db
259		.roomuserid_invitecount
260		.remove(&roomuser_id);
261
262	self.db.userroomid_leftstate.remove(&userroom_id);
263	self.db.roomuserid_leftcount.remove(&roomuser_id);
264
265	self.db
266		.userroomid_knockedstate
267		.remove(&userroom_id);
268	self.db
269		.roomuserid_knockedcount
270		.remove(&roomuser_id);
271}
272
273/// Direct DB function to directly mark a user as left. It is not
274/// recommended to use this directly. You most likely should use
275/// `update_membership` instead
276#[implement(super::Service)]
277#[tracing::instrument(skip(self), level = "debug")]
278pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
279	let userroom_id = (user_id, room_id);
280	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
281
282	let roomuser_id = (room_id, user_id);
283	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
284
285	// (timo) TODO
286	let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
287
288	self.db
289		.userroomid_leftstate
290		.raw_put(&userroom_id, Json(leftstate));
291	self.db
292		.roomuserid_leftcount
293		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
294
295	self.db
296		.userroomid_joinedcount
297		.remove(&userroom_id);
298	self.db
299		.roomuserid_joinedcount
300		.remove(&roomuser_id);
301
302	self.db
303		.userroomid_invitestate
304		.remove(&userroom_id);
305	self.db
306		.roomuserid_invitecount
307		.remove(&roomuser_id);
308
309	self.db
310		.userroomid_knockedstate
311		.remove(&userroom_id);
312	self.db
313		.roomuserid_knockedcount
314		.remove(&roomuser_id);
315}
316
317/// Direct DB function to directly mark a user as knocked. It is not
318/// recommended to use this directly. You most likely should use
319/// `update_membership` instead
320#[implement(super::Service)]
321#[tracing::instrument(skip(self), level = "debug")]
322pub(crate) fn _mark_as_knocked(
323	&self,
324	user_id: &UserId,
325	room_id: &RoomId,
326	count: PduCount,
327	knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
328) {
329	let userroom_id = (user_id, room_id);
330	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
331
332	let roomuser_id = (room_id, user_id);
333	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
334
335	self.db
336		.userroomid_knockedstate
337		.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
338	self.db
339		.roomuserid_knockedcount
340		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
341
342	self.db
343		.userroomid_joinedcount
344		.remove(&userroom_id);
345	self.db
346		.roomuserid_joinedcount
347		.remove(&roomuser_id);
348
349	self.db
350		.userroomid_invitestate
351		.remove(&userroom_id);
352	self.db
353		.roomuserid_invitecount
354		.remove(&roomuser_id);
355
356	self.db.userroomid_leftstate.remove(&userroom_id);
357	self.db.roomuserid_leftcount.remove(&roomuser_id);
358}
359
360/// Makes a user forget a room.
361#[implement(super::Service)]
362#[tracing::instrument(skip(self), level = "debug")]
363pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
364	let userroom_id = (user_id, room_id);
365	let roomuser_id = (room_id, user_id);
366
367	self.db.userroomid_leftstate.del(userroom_id);
368	self.db.roomuserid_leftcount.del(roomuser_id);
369}
370
371#[implement(super::Service)]
372#[tracing::instrument(level = "debug", skip(self))]
373fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
374	let key = (user_id, room_id);
375	self.db.roomuseroncejoinedids.put_raw(key, []);
376}
377
378#[implement(super::Service)]
379#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
380pub(crate) async fn mark_as_invited(
381	&self,
382	user_id: &UserId,
383	room_id: &RoomId,
384	count: PduCount,
385	last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
386	invite_via: Option<Vec<OwnedServerName>>,
387) {
388	let roomuser_id = (room_id, user_id);
389	let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
390
391	let userroom_id = (user_id, room_id);
392	let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
393
394	self.db
395		.userroomid_invitestate
396		.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
397	self.db
398		.roomuserid_invitecount
399		.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
400
401	self.db
402		.userroomid_joinedcount
403		.remove(&userroom_id);
404	self.db
405		.roomuserid_joinedcount
406		.remove(&roomuser_id);
407
408	self.db.userroomid_leftstate.remove(&userroom_id);
409	self.db.roomuserid_leftcount.remove(&roomuser_id);
410
411	self.db
412		.userroomid_knockedstate
413		.remove(&userroom_id);
414	self.db
415		.roomuserid_knockedcount
416		.remove(&roomuser_id);
417
418	if let Some(servers) = invite_via.filter(is_not_empty!()) {
419		self.add_servers_invite_via(room_id, servers)
420			.await;
421	}
422}