Skip to main content

tuwunel_service/membership/
knock.rs

1use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5	CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, RoomId,
6	RoomOrAliasId, RoomVersionId, UserId,
7	api::federation::{self, membership::RawStrippedState},
8	canonical_json::to_canonical_value,
9	events::{
10		StateEventType,
11		room::member::{MembershipState, RoomMemberEventContent},
12	},
13};
14use tuwunel_core::{
15	Err, Event, PduCount, Result, at, debug, debug_info, debug_warn, err, extract_variant,
16	implement, info,
17	matrix::event::gen_event_id,
18	pdu::{PduBuilder, PduEvent},
19	trace, utils, warn,
20};
21
22use super::Service;
23use crate::{
24	membership::join::get_servers_for_room,
25	rooms::{
26		state::RoomMutexGuard,
27		state_compressor::{CompressedState, HashSetCompressStateEvent},
28	},
29};
30
31#[implement(Service)]
32#[tracing::instrument(
33	level = "debug",
34	skip_all,
35	fields(%sender_user, %room_id)
36)]
37pub async fn knock(
38	&self,
39	sender_user: &UserId,
40	room_id: &RoomId,
41	orig_server_name: Option<&RoomOrAliasId>,
42	reason: Option<String>,
43	servers: &[OwnedServerName],
44	state_lock: &RoomMutexGuard,
45) -> Result {
46	let servers =
47		get_servers_for_room(&self.services, sender_user, room_id, orig_server_name, servers)
48			.await?;
49
50	if self
51		.services
52		.state_cache
53		.is_invited(sender_user, room_id)
54		.await
55	{
56		debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock");
57		return Err!(Request(Forbidden(
58			"You cannot knock on a room you are already invited/accepted to."
59		)));
60	}
61
62	if self
63		.services
64		.state_cache
65		.is_joined(sender_user, room_id)
66		.await
67	{
68		debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock");
69		return Err!(Request(Forbidden("You cannot knock on a room you are already joined in.")));
70	}
71
72	if self
73		.services
74		.state_cache
75		.is_knocked(sender_user, room_id)
76		.await
77	{
78		debug_warn!("{sender_user} is already knocked in {room_id}");
79		return Ok(());
80	}
81
82	if let Ok(membership) = self
83		.services
84		.state_accessor
85		.get_member(room_id, sender_user)
86		.await && membership.membership == MembershipState::Ban
87	{
88		debug_warn!("{sender_user} is banned from {room_id} but attempted to knock");
89		return Err!(Request(Forbidden("You cannot knock on a room you are banned from.")));
90	}
91
92	let server_in_room = self
93		.services
94		.state_cache
95		.server_in_room(self.services.globals.server_name(), room_id)
96		.await;
97
98	let local_knock = server_in_room
99		|| servers.is_empty()
100		|| (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]));
101
102	if local_knock {
103		self.knock_room_helper_local(sender_user, room_id, reason, &servers, state_lock)
104			.boxed()
105			.await
106	} else {
107		self.knock_room_helper_remote(sender_user, room_id, reason, &servers, state_lock)
108			.boxed()
109			.await
110	}
111}
112
113#[implement(Service)]
114#[expect(
115	deprecated,
116	reason = "Matrix 1.16 still permits receiving the legacy stripped variant for backwards \
117	          compatibility."
118)]
119async fn knock_room_helper_local(
120	&self,
121	sender_user: &UserId,
122	room_id: &RoomId,
123	reason: Option<String>,
124	servers: &[OwnedServerName],
125	state_lock: &RoomMutexGuard,
126) -> Result {
127	debug_info!("We can knock locally");
128
129	let room_version_id = self
130		.services
131		.state
132		.get_room_version(room_id)
133		.await?;
134
135	if matches!(
136		room_version_id,
137		RoomVersionId::V1
138			| RoomVersionId::V2
139			| RoomVersionId::V3
140			| RoomVersionId::V4
141			| RoomVersionId::V5
142			| RoomVersionId::V6
143	) {
144		return Err!(Request(Forbidden("This room does not support knocking.")));
145	}
146
147	let content = RoomMemberEventContent {
148		displayname: self
149			.services
150			.users
151			.displayname(sender_user)
152			.await
153			.ok(),
154		avatar_url: self
155			.services
156			.users
157			.avatar_url(sender_user)
158			.await
159			.ok(),
160		blurhash: self
161			.services
162			.users
163			.blurhash(sender_user)
164			.await
165			.ok(),
166		reason: reason.clone(),
167		..RoomMemberEventContent::new(MembershipState::Knock)
168	};
169
170	// Try normal knock first
171	let Err(error) = self
172		.services
173		.timeline
174		.build_and_append_pdu(
175			PduBuilder::state(sender_user.to_string(), &content),
176			sender_user,
177			room_id,
178			state_lock,
179		)
180		.await
181	else {
182		return Ok(());
183	};
184
185	if servers.is_empty()
186		|| (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]))
187	{
188		return Err(error);
189	}
190
191	warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock");
192
193	let (make_knock_response, remote_server) = self
194		.make_knock_request(sender_user, room_id, servers)
195		.await?;
196
197	info!("make_knock finished");
198
199	let room_version_id = make_knock_response.room_version;
200
201	if !self
202		.services
203		.config
204		.supported_room_version(&room_version_id)
205	{
206		return Err!(BadServerResponse(
207			"Remote room version {room_version_id} is not supported by tuwunel"
208		));
209	}
210
211	let mut knock_event_stub = serde_json::from_str::<CanonicalJsonObject>(
212		make_knock_response.event.get(),
213	)
214	.map_err(|e| {
215		err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
216	})?;
217
218	knock_event_stub.insert(
219		"origin".into(),
220		CanonicalJsonValue::String(
221			self.services
222				.globals
223				.server_name()
224				.as_str()
225				.to_owned(),
226		),
227	);
228	knock_event_stub.insert(
229		"origin_server_ts".into(),
230		CanonicalJsonValue::Integer(
231			utils::millis_since_unix_epoch()
232				.try_into()
233				.expect("Timestamp is valid js_int value"),
234		),
235	);
236	knock_event_stub.insert(
237		"content".into(),
238		to_canonical_value(RoomMemberEventContent {
239			displayname: self
240				.services
241				.users
242				.displayname(sender_user)
243				.await
244				.ok(),
245			avatar_url: self
246				.services
247				.users
248				.avatar_url(sender_user)
249				.await
250				.ok(),
251			blurhash: self
252				.services
253				.users
254				.blurhash(sender_user)
255				.await
256				.ok(),
257			reason,
258			..RoomMemberEventContent::new(MembershipState::Knock)
259		})
260		.expect("event is valid, we just created it"),
261	);
262
263	knock_event_stub
264		.insert("room_id".into(), CanonicalJsonValue::String(room_id.as_str().into()));
265
266	knock_event_stub
267		.insert("state_key".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
268
269	knock_event_stub
270		.insert("sender".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
271
272	knock_event_stub.insert("type".into(), CanonicalJsonValue::String("m.room.member".into()));
273
274	// In order to create a compatible ref hash (EventID) the `hashes` field needs
275	// to be present
276	self.services
277		.server_keys
278		.hash_and_sign_event(&mut knock_event_stub, &room_version_id)?;
279
280	// Generate event id
281	let event_id = gen_event_id(&knock_event_stub, &room_version_id)?;
282
283	// Add event_id
284	knock_event_stub
285		.insert("event_id".into(), CanonicalJsonValue::String(event_id.clone().into()));
286
287	// It has enough fields to be called a proper event now
288	let knock_event = knock_event_stub;
289
290	info!("Asking {remote_server} for send_knock in room {room_id}");
291	let send_knock_request = federation::membership::create_knock_event::v1::Request {
292		room_id: room_id.to_owned(),
293		event_id: event_id.clone(),
294		pdu: self
295			.services
296			.federation
297			.format_pdu_into(knock_event.clone(), Some(&room_version_id))
298			.await,
299	};
300
301	let send_knock_response = self
302		.services
303		.federation
304		.execute(&remote_server, send_knock_request)
305		.await?;
306
307	info!("send_knock finished");
308
309	self.services
310		.short
311		.get_or_create_shortroomid(room_id)
312		.await;
313
314	info!("Parsing knock event");
315
316	let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
317		.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
318
319	info!("Updating membership locally to knock state with provided stripped state events");
320	let count = self.services.globals.next_count();
321	self.services
322		.state_cache
323		.update_membership(
324			room_id,
325			sender_user,
326			parsed_knock_pdu
327				.get_content::<RoomMemberEventContent>()
328				.expect("we just created this"),
329			sender_user,
330			Some(
331				send_knock_response
332					.knock_room_state
333					.into_iter()
334					.filter_map(|s| extract_variant!(s, RawStrippedState::Stripped))
335					.collect(),
336			),
337			None,
338			false,
339			PduCount::Normal(*count),
340		)
341		.await?;
342
343	info!("Appending room knock event locally");
344	self.services
345		.timeline
346		.append_pdu(
347			&parsed_knock_pdu,
348			knock_event,
349			once(parsed_knock_pdu.event_id.borrow()),
350			state_lock,
351		)
352		.await?;
353
354	Ok(())
355}
356
357#[implement(Service)]
358#[expect(
359	deprecated,
360	reason = "Matrix 1.16 still permits receiving the legacy stripped variant for backwards \
361	          compatibility."
362)]
363async fn knock_room_helper_remote(
364	&self,
365	sender_user: &UserId,
366	room_id: &RoomId,
367	reason: Option<String>,
368	servers: &[OwnedServerName],
369	state_lock: &RoomMutexGuard,
370) -> Result {
371	info!("Knocking {room_id} over federation.");
372
373	let (make_knock_response, remote_server) = self
374		.make_knock_request(sender_user, room_id, servers)
375		.await?;
376
377	info!("make_knock finished");
378
379	let room_version_id = make_knock_response.room_version;
380
381	if !self
382		.services
383		.config
384		.supported_room_version(&room_version_id)
385	{
386		return Err!(BadServerResponse(
387			"Remote room version {room_version_id} is not supported by tuwunel"
388		));
389	}
390
391	let mut knock_event_stub: CanonicalJsonObject =
392		serde_json::from_str(make_knock_response.event.get()).map_err(|e| {
393			err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
394		})?;
395
396	knock_event_stub.insert(
397		"origin".into(),
398		CanonicalJsonValue::String(
399			self.services
400				.globals
401				.server_name()
402				.as_str()
403				.to_owned(),
404		),
405	);
406	knock_event_stub.insert(
407		"origin_server_ts".into(),
408		CanonicalJsonValue::Integer(
409			utils::millis_since_unix_epoch()
410				.try_into()
411				.expect("Timestamp is valid js_int value"),
412		),
413	);
414	knock_event_stub.insert(
415		"content".into(),
416		to_canonical_value(RoomMemberEventContent {
417			displayname: self
418				.services
419				.users
420				.displayname(sender_user)
421				.await
422				.ok(),
423			avatar_url: self
424				.services
425				.users
426				.avatar_url(sender_user)
427				.await
428				.ok(),
429			blurhash: self
430				.services
431				.users
432				.blurhash(sender_user)
433				.await
434				.ok(),
435			reason,
436			..RoomMemberEventContent::new(MembershipState::Knock)
437		})
438		.expect("event is valid, we just created it"),
439	);
440
441	knock_event_stub
442		.insert("room_id".into(), CanonicalJsonValue::String(room_id.as_str().into()));
443
444	knock_event_stub
445		.insert("state_key".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
446
447	knock_event_stub
448		.insert("sender".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
449
450	knock_event_stub.insert("type".into(), CanonicalJsonValue::String("m.room.member".into()));
451
452	// In order to create a compatible ref hash (EventID) the `hashes` field needs
453	// to be present
454	self.services
455		.server_keys
456		.hash_and_sign_event(&mut knock_event_stub, &room_version_id)?;
457
458	// Generate event id
459	let event_id = gen_event_id(&knock_event_stub, &room_version_id)?;
460
461	// Add event_id
462	knock_event_stub
463		.insert("event_id".into(), CanonicalJsonValue::String(event_id.clone().into()));
464
465	// It has enough fields to be called a proper event now
466	let knock_event = knock_event_stub;
467
468	info!("Asking {remote_server} for send_knock in room {room_id}");
469	let send_knock_request = federation::membership::create_knock_event::v1::Request {
470		room_id: room_id.to_owned(),
471		event_id: event_id.clone(),
472		pdu: self
473			.services
474			.federation
475			.format_pdu_into(knock_event.clone(), Some(&room_version_id))
476			.await,
477	};
478
479	let send_knock_response = self
480		.services
481		.federation
482		.execute(&remote_server, send_knock_request)
483		.await?;
484
485	info!("send_knock finished");
486
487	self.services
488		.short
489		.get_or_create_shortroomid(room_id)
490		.await;
491
492	info!("Parsing knock event");
493	let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
494		.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
495
496	info!("Going through send_knock response knock state events");
497	let state = send_knock_response
498		.knock_room_state
499		.iter()
500		.map(|event| {
501			serde_json::from_str::<CanonicalJsonObject>(
502				extract_variant!(event.clone(), RawStrippedState::Stripped)
503					.expect("Raw<AnyStrippedStateEvent>")
504					.json()
505					.get(),
506			)
507		})
508		.filter_map(Result::ok);
509
510	let mut state_map: HashMap<u64, OwnedEventId> = HashMap::new();
511
512	for event in state {
513		let Some(state_key) = event.get("state_key") else {
514			debug_warn!("send_knock stripped state event missing state_key: {event:?}");
515			continue;
516		};
517		let Some(event_type) = event.get("type") else {
518			debug_warn!("send_knock stripped state event missing event type: {event:?}");
519			continue;
520		};
521
522		let Ok(state_key) = serde_json::from_value::<String>(state_key.clone().into()) else {
523			debug_warn!("send_knock stripped state event has invalid state_key: {event:?}");
524			continue;
525		};
526		let Ok(event_type) = serde_json::from_value::<StateEventType>(event_type.clone().into())
527		else {
528			debug_warn!("send_knock stripped state event has invalid event type: {event:?}");
529			continue;
530		};
531
532		let event_id = gen_event_id(&event, &room_version_id)?;
533		let shortstatekey = self
534			.services
535			.short
536			.get_or_create_shortstatekey(&event_type, &state_key)
537			.await;
538
539		self.services
540			.timeline
541			.add_pdu_outlier(&event_id, &event);
542
543		state_map.insert(shortstatekey, event_id.clone());
544	}
545
546	info!("Compressing state from send_knock");
547	let compressed: CompressedState = self
548		.services
549		.state_compressor
550		.compress_state_events(
551			state_map
552				.iter()
553				.map(|(ssk, eid)| (ssk, eid.borrow())),
554		)
555		.collect()
556		.await;
557
558	debug!("Saving compressed state");
559	let HashSetCompressStateEvent {
560		shortstatehash: statehash_before_knock,
561		added,
562		removed,
563	} = self
564		.services
565		.state_compressor
566		.save_state(room_id, Arc::new(compressed))
567		.await?;
568
569	debug!("Forcing state for new room");
570	self.services
571		.state
572		.force_state(room_id, statehash_before_knock, added, removed, state_lock)
573		.await?;
574
575	let statehash_after_knock = self
576		.services
577		.state
578		.append_to_state(&parsed_knock_pdu)
579		.await?;
580
581	info!("Updating membership locally to knock state with provided stripped state events");
582	let count = self.services.globals.next_count();
583	self.services
584		.state_cache
585		.update_membership(
586			room_id,
587			sender_user,
588			parsed_knock_pdu
589				.get_content::<RoomMemberEventContent>()
590				.expect("we just created this"),
591			sender_user,
592			Some(
593				send_knock_response
594					.knock_room_state
595					.into_iter()
596					.filter_map(|s| extract_variant!(s, RawStrippedState::Stripped))
597					.collect(),
598			),
599			None,
600			false,
601			PduCount::Normal(*count),
602		)
603		.await?;
604
605	info!("Appending room knock event locally");
606	self.services
607		.timeline
608		.append_pdu(
609			&parsed_knock_pdu,
610			knock_event,
611			once(parsed_knock_pdu.event_id.borrow()),
612			state_lock,
613		)
614		.await?;
615
616	info!("Setting final room state for new room");
617	// We set the room state after inserting the pdu, so that we never have a moment
618	// in time where events in the current room state do not exist
619	self.services
620		.state
621		.set_room_state(room_id, statehash_after_knock, state_lock);
622
623	Ok(())
624}
625
626#[implement(Service)]
627async fn make_knock_request(
628	&self,
629	sender_user: &UserId,
630	room_id: &RoomId,
631	servers: &[OwnedServerName],
632) -> Result<(federation::membership::prepare_knock_event::v1::Response, OwnedServerName)> {
633	let mut make_knock_response_and_server =
634		Err!(BadServerResponse("No server available to assist in knocking."));
635
636	let mut make_knock_counter: usize = 0;
637
638	for remote_server in servers {
639		if self
640			.services
641			.globals
642			.server_is_ours(remote_server)
643		{
644			continue;
645		}
646
647		info!("Asking {remote_server} for make_knock ({make_knock_counter})");
648
649		let make_knock_response = self
650			.services
651			.federation
652			.execute(remote_server, federation::membership::prepare_knock_event::v1::Request {
653				room_id: room_id.to_owned(),
654				user_id: sender_user.to_owned(),
655				ver: self
656					.services
657					.config
658					.supported_room_versions()
659					.map(at!(0))
660					.collect(),
661			})
662			.await;
663
664		trace!("make_knock response: {make_knock_response:?}");
665		make_knock_counter = make_knock_counter.saturating_add(1);
666
667		make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone()));
668
669		if make_knock_response_and_server.is_ok() {
670			break;
671		}
672
673		if make_knock_counter > 40 {
674			warn!(
675				"50 servers failed to provide valid make_knock response, assuming no server can \
676				 assist in knocking."
677			);
678			make_knock_response_and_server =
679				Err!(BadServerResponse("No server available to assist in knocking."));
680
681			return make_knock_response_and_server;
682		}
683	}
684
685	make_knock_response_and_server
686}