Skip to main content

tuwunel_service/membership/
knock.rs

1use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5	CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, RoomId,
6	RoomOrAliasId, RoomVersionId, UserId,
7	api::federation::{self, membership::RawStrippedState},
8	canonical_json::to_canonical_value,
9	events::{
10		StateEventType,
11		room::member::{MembershipState, RoomMemberEventContent},
12	},
13};
14use tuwunel_core::{
15	Err, Event, PduCount, Result, at, debug, debug_info, debug_warn, err, implement, info,
16	matrix::event::gen_event_id,
17	pdu::{PduBuilder, PduEvent},
18	trace, utils, warn,
19};
20
21use super::{
22	Service, StrippedCreateVerdict, enforce_stripped_create, into_client_stripped, v12_room_ids,
23};
24use crate::{
25	membership::join::get_servers_for_room,
26	rooms::{
27		state::RoomMutexGuard,
28		state_compressor::{CompressedState, HashSetCompressStateEvent},
29	},
30};
31
32#[implement(Service)]
33#[tracing::instrument(
34	level = "debug",
35	skip_all,
36	fields(%sender_user, %room_id)
37)]
38pub async fn knock(
39	&self,
40	sender_user: &UserId,
41	room_id: &RoomId,
42	orig_server_name: Option<&RoomOrAliasId>,
43	reason: Option<String>,
44	servers: &[OwnedServerName],
45	state_lock: &RoomMutexGuard,
46) -> Result {
47	let servers =
48		get_servers_for_room(&self.services, sender_user, room_id, orig_server_name, servers)
49			.await?;
50
51	if self
52		.services
53		.state_cache
54		.is_invited(sender_user, room_id)
55		.await
56	{
57		debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock");
58		return Err!(Request(Forbidden(
59			"You cannot knock on a room you are already invited/accepted to."
60		)));
61	}
62
63	if self
64		.services
65		.state_cache
66		.is_joined(sender_user, room_id)
67		.await
68	{
69		debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock");
70		return Err!(Request(Forbidden("You cannot knock on a room you are already joined in.")));
71	}
72
73	let server_in_room = self
74		.services
75		.state_cache
76		.server_in_room(self.services.globals.server_name(), room_id)
77		.await;
78
79	// Trust a local knock; re-drive a remote one in case we missed a kick.
80	if server_in_room
81		&& self
82			.services
83			.state_cache
84			.is_knocked(sender_user, room_id)
85			.await
86	{
87		debug_warn!("{sender_user} is already knocked in {room_id}");
88		return Ok(());
89	}
90
91	if let Ok(membership) = self
92		.services
93		.state_accessor
94		.get_member(room_id, sender_user)
95		.await && membership.membership == MembershipState::Ban
96	{
97		debug_warn!("{sender_user} is banned from {room_id} but attempted to knock");
98		return Err!(Request(Forbidden("You cannot knock on a room you are banned from.")));
99	}
100
101	let local_knock = server_in_room
102		|| servers.is_empty()
103		|| (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]));
104
105	if local_knock {
106		self.knock_room_helper_local(sender_user, room_id, reason, &servers, state_lock)
107			.boxed()
108			.await
109	} else {
110		self.knock_room_helper_remote(sender_user, room_id, reason, &servers, state_lock)
111			.boxed()
112			.await
113	}
114}
115
116#[implement(Service)]
117async fn knock_room_helper_local(
118	&self,
119	sender_user: &UserId,
120	room_id: &RoomId,
121	reason: Option<String>,
122	servers: &[OwnedServerName],
123	state_lock: &RoomMutexGuard,
124) -> Result {
125	debug_info!("We can knock locally");
126
127	let room_version_id = self
128		.services
129		.state
130		.get_room_version(room_id)
131		.await?;
132
133	ensure_room_version_supports_knock(&room_version_id)?;
134
135	let mut content = RoomMemberEventContent {
136		reason: reason.clone(),
137		..RoomMemberEventContent::new(MembershipState::Knock)
138	};
139
140	self.services
141		.profile
142		.fill_profile_data(sender_user, &mut content)
143		.await;
144
145	let Err(error) = self
146		.services
147		.timeline
148		.build_and_append_pdu(
149			PduBuilder::state(sender_user.to_string(), &content),
150			sender_user,
151			room_id,
152			state_lock,
153		)
154		.await
155	else {
156		return Ok(());
157	};
158
159	if servers.is_empty()
160		|| (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]))
161	{
162		return Err(error);
163	}
164
165	warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock");
166
167	self.knock_room_local_federation_fallback(sender_user, room_id, reason, servers, state_lock)
168		.boxed()
169		.await
170}
171
172fn ensure_room_version_supports_knock(room_version_id: &RoomVersionId) -> Result {
173	if matches!(
174		room_version_id,
175		RoomVersionId::V1
176			| RoomVersionId::V2
177			| RoomVersionId::V3
178			| RoomVersionId::V4
179			| RoomVersionId::V5
180			| RoomVersionId::V6
181	) {
182		return Err!(Request(Forbidden("This room does not support knocking.")));
183	}
184
185	Ok(())
186}
187
188#[implement(Service)]
189async fn knock_room_local_federation_fallback(
190	&self,
191	sender_user: &UserId,
192	room_id: &RoomId,
193	reason: Option<String>,
194	servers: &[OwnedServerName],
195	state_lock: &RoomMutexGuard,
196) -> Result {
197	let (make_knock_response, remote_server) = self
198		.make_knock_request(sender_user, room_id, servers)
199		.await?;
200
201	info!("make_knock finished");
202
203	let room_version_id = make_knock_response.room_version.clone();
204
205	if !self
206		.services
207		.config
208		.supported_room_version(&room_version_id)
209	{
210		return Err!(BadServerResponse(
211			"Remote room version {room_version_id} is not supported by tuwunel"
212		));
213	}
214
215	let (knock_event, event_id) = self
216		.build_knock_event(sender_user, room_id, reason, &make_knock_response, &room_version_id)
217		.await?;
218
219	let send_knock_response = self
220		.execute_send_knock(&remote_server, room_id, &event_id, &knock_event, &room_version_id)
221		.await?;
222
223	self.services
224		.short
225		.get_or_create_shortroomid(room_id)
226		.await;
227
228	self.finalize_knock_membership(
229		room_id,
230		sender_user,
231		&event_id,
232		knock_event,
233		send_knock_response,
234		state_lock,
235	)
236	.await
237}
238
239#[implement(Service)]
240async fn finalize_knock_membership(
241	&self,
242	room_id: &RoomId,
243	sender_user: &UserId,
244	event_id: &OwnedEventId,
245	knock_event: CanonicalJsonObject,
246	send_knock_response: federation::membership::create_knock_event::v1::Response,
247	state_lock: &RoomMutexGuard,
248) -> Result {
249	info!("Parsing knock event");
250	let parsed_knock_pdu = PduEvent::from_object_and_eventid(event_id, knock_event.clone())
251		.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
252
253	info!("Updating membership locally to knock state with provided stripped state events");
254	let count = self.services.globals.next_count();
255	self.services
256		.state_cache
257		.update_membership(
258			room_id,
259			sender_user,
260			parsed_knock_pdu
261				.get_content::<RoomMemberEventContent>()
262				.expect("we just created this"),
263			sender_user,
264			Some(
265				send_knock_response
266					.knock_room_state
267					.into_iter()
268					.filter_map(|state| into_client_stripped(room_id, state))
269					.collect(),
270			),
271			None,
272			false,
273			PduCount::Normal(*count),
274		)
275		.await?;
276
277	info!("Appending room knock event locally");
278	self.services
279		.timeline
280		.append_pdu(
281			&parsed_knock_pdu,
282			knock_event,
283			once(parsed_knock_pdu.event_id.borrow()),
284			state_lock,
285		)
286		.await?;
287
288	Ok(())
289}
290
291#[implement(Service)]
292async fn knock_room_helper_remote(
293	&self,
294	sender_user: &UserId,
295	room_id: &RoomId,
296	reason: Option<String>,
297	servers: &[OwnedServerName],
298	state_lock: &RoomMutexGuard,
299) -> Result {
300	info!("Knocking {room_id} over federation.");
301
302	let (make_knock_response, remote_server) = self
303		.make_knock_request(sender_user, room_id, servers)
304		.await?;
305
306	info!("make_knock finished");
307
308	let room_version_id = make_knock_response.room_version.clone();
309
310	if !self
311		.services
312		.config
313		.supported_room_version(&room_version_id)
314	{
315		return Err!(BadServerResponse(
316			"Remote room version {room_version_id} is not supported by tuwunel"
317		));
318	}
319
320	let (knock_event, event_id) = self
321		.build_knock_event(sender_user, room_id, reason, &make_knock_response, &room_version_id)
322		.await?;
323
324	let send_knock_response = self
325		.execute_send_knock(&remote_server, room_id, &event_id, &knock_event, &room_version_id)
326		.await?;
327
328	self.services
329		.short
330		.get_or_create_shortroomid(room_id)
331		.await;
332
333	info!("Parsing knock event");
334	let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
335		.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
336
337	let state_map = self
338		.ingest_send_knock_state(room_id, &send_knock_response, &room_version_id)
339		.await?;
340
341	self.apply_send_knock_state(room_id, &state_map, state_lock)
342		.await?;
343
344	let statehash_after_knock = self
345		.services
346		.state
347		.append_to_state(&parsed_knock_pdu)
348		.await?;
349
350	info!("Updating membership locally to knock state with provided stripped state events");
351	let count = self.services.globals.next_count();
352	self.services
353		.state_cache
354		.update_membership(
355			room_id,
356			sender_user,
357			parsed_knock_pdu
358				.get_content::<RoomMemberEventContent>()
359				.expect("we just created this"),
360			sender_user,
361			Some(
362				send_knock_response
363					.knock_room_state
364					.into_iter()
365					.filter_map(|state| into_client_stripped(room_id, state))
366					.collect(),
367			),
368			None,
369			false,
370			PduCount::Normal(*count),
371		)
372		.await?;
373
374	info!("Appending room knock event locally");
375	self.services
376		.timeline
377		.append_pdu(
378			&parsed_knock_pdu,
379			knock_event,
380			once(parsed_knock_pdu.event_id.borrow()),
381			state_lock,
382		)
383		.await?;
384
385	info!("Setting final room state for new room");
386	// We set the room state after inserting the pdu, so that we never have a moment
387	// in time where events in the current room state do not exist
388	self.services
389		.state
390		.set_room_state(room_id, statehash_after_knock, state_lock);
391
392	Ok(())
393}
394
395#[implement(Service)]
396async fn build_knock_event(
397	&self,
398	sender_user: &UserId,
399	room_id: &RoomId,
400	reason: Option<String>,
401	make_knock_response: &federation::membership::prepare_knock_event::v1::Response,
402	room_version_id: &RoomVersionId,
403) -> Result<(CanonicalJsonObject, OwnedEventId)> {
404	let mut knock_event_stub: CanonicalJsonObject =
405		serde_json::from_str(make_knock_response.event.get()).map_err(|e| {
406			err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
407		})?;
408
409	let mut content = RoomMemberEventContent {
410		reason,
411		..RoomMemberEventContent::new(MembershipState::Knock)
412	};
413
414	self.services
415		.profile
416		.fill_profile_data(sender_user, &mut content)
417		.await;
418
419	knock_event_stub.insert(
420		"origin".into(),
421		CanonicalJsonValue::String(
422			self.services
423				.globals
424				.server_name()
425				.as_str()
426				.to_owned(),
427		),
428	);
429	knock_event_stub.insert(
430		"origin_server_ts".into(),
431		CanonicalJsonValue::Integer(
432			utils::millis_since_unix_epoch()
433				.try_into()
434				.expect("Timestamp is valid js_int value"),
435		),
436	);
437	knock_event_stub.insert(
438		"content".into(),
439		to_canonical_value(content).expect("event is valid, we just created it"),
440	);
441
442	knock_event_stub
443		.insert("room_id".into(), CanonicalJsonValue::String(room_id.as_str().into()));
444
445	knock_event_stub
446		.insert("state_key".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
447
448	knock_event_stub
449		.insert("sender".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
450
451	knock_event_stub.insert("type".into(), CanonicalJsonValue::String("m.room.member".into()));
452
453	// In order to create a compatible ref hash (EventID) the `hashes` field needs
454	// to be present
455	self.services
456		.server_keys
457		.hash_and_sign_event(&mut knock_event_stub, room_version_id)?;
458
459	let event_id = gen_event_id(&knock_event_stub, room_version_id)?;
460
461	knock_event_stub
462		.insert("event_id".into(), CanonicalJsonValue::String(event_id.clone().into()));
463
464	Ok((knock_event_stub, event_id))
465}
466
467#[implement(Service)]
468async fn execute_send_knock(
469	&self,
470	remote_server: &OwnedServerName,
471	room_id: &RoomId,
472	event_id: &OwnedEventId,
473	knock_event: &CanonicalJsonObject,
474	room_version_id: &RoomVersionId,
475) -> Result<federation::membership::create_knock_event::v1::Response> {
476	info!("Asking {remote_server} for send_knock in room {room_id}");
477	let send_knock_request = federation::membership::create_knock_event::v1::Request {
478		room_id: room_id.to_owned(),
479		event_id: event_id.clone(),
480		pdu: self
481			.services
482			.federation
483			.format_pdu_into(knock_event.clone(), Some(room_version_id))
484			.await,
485	};
486
487	let response = self
488		.services
489		.federation
490		.execute(remote_server, send_knock_request)
491		.await?;
492
493	info!("send_knock finished");
494	Ok(response)
495}
496
497#[implement(Service)]
498#[expect(
499	deprecated,
500	reason = "Matrix 1.16 still permits receiving the legacy stripped variant for backwards \
501	          compatibility."
502)]
503async fn ingest_send_knock_state(
504	&self,
505	room_id: &RoomId,
506	send_knock_response: &federation::membership::create_knock_event::v1::Response,
507	room_version_id: &RoomVersionId,
508) -> Result<HashMap<u64, OwnedEventId>> {
509	info!("Going through send_knock response knock state events");
510
511	let verdict = self
512		.validate_stripped_create(&send_knock_response.knock_room_state, room_id, room_version_id)
513		.await?;
514
515	let enforce = self
516		.services
517		.config
518		.enforce_stripped_state_pdu_validation;
519
520	let drop_create = enforce_stripped_create(verdict, v12_room_ids(room_version_id), enforce);
521
522	if verdict != StrippedCreateVerdict::Valid {
523		debug_warn!(?verdict, %room_id, drop_create, "MSC4311 knock create-event validation failed");
524	}
525
526	let state = send_knock_response
527		.knock_room_state
528		.iter()
529		.filter_map(|event| match event {
530			| RawStrippedState::Pdu(raw) =>
531				serde_json::from_str::<CanonicalJsonObject>(raw.get()).ok(),
532			| RawStrippedState::Stripped(raw) =>
533				serde_json::from_str::<CanonicalJsonObject>(raw.json().get()).ok(),
534		});
535
536	let mut state_map: HashMap<u64, OwnedEventId> = HashMap::new();
537
538	for event in state {
539		let Some(state_key) = event.get("state_key") else {
540			debug_warn!("send_knock stripped state event missing state_key: {event:?}");
541			continue;
542		};
543		let Some(event_type) = event.get("type") else {
544			debug_warn!("send_knock stripped state event missing event type: {event:?}");
545			continue;
546		};
547
548		let Ok(state_key) = serde_json::from_value::<String>(state_key.clone().into()) else {
549			debug_warn!("send_knock stripped state event has invalid state_key: {event:?}");
550			continue;
551		};
552		let Ok(event_type) = serde_json::from_value::<StateEventType>(event_type.clone().into())
553		else {
554			debug_warn!("send_knock stripped state event has invalid event type: {event:?}");
555			continue;
556		};
557
558		// MSC4311: drop a create event that failed validation when policy enforces.
559		if drop_create && event_type == StateEventType::RoomCreate && state_key.is_empty() {
560			debug_warn!(%room_id, "dropping unvalidated create event from knock state");
561			continue;
562		}
563
564		let event_id = gen_event_id(&event, room_version_id)?;
565		let shortstatekey = self
566			.services
567			.short
568			.get_or_create_shortstatekey(&event_type, &state_key)
569			.await;
570
571		self.services
572			.timeline
573			.add_pdu_outlier(&event_id, &event);
574
575		state_map.insert(shortstatekey, event_id.clone());
576	}
577
578	Ok(state_map)
579}
580
581#[implement(Service)]
582async fn apply_send_knock_state(
583	&self,
584	room_id: &RoomId,
585	state_map: &HashMap<u64, OwnedEventId>,
586	state_lock: &RoomMutexGuard,
587) -> Result {
588	info!("Compressing state from send_knock");
589	let compressed: CompressedState = self
590		.services
591		.state_compressor
592		.compress_state_events(
593			state_map
594				.iter()
595				.map(|(ssk, eid)| (ssk, eid.borrow())),
596		)
597		.collect()
598		.await;
599
600	debug!("Saving compressed state");
601	let HashSetCompressStateEvent {
602		shortstatehash: statehash_before_knock,
603		added,
604		removed,
605	} = self
606		.services
607		.state_compressor
608		.save_state(room_id, Arc::new(compressed))
609		.await?;
610
611	debug!("Forcing state for new room");
612	self.services
613		.state
614		.force_state(room_id, statehash_before_knock, added, removed, state_lock)
615		.await?;
616
617	Ok(())
618}
619
620#[implement(Service)]
621async fn make_knock_request(
622	&self,
623	sender_user: &UserId,
624	room_id: &RoomId,
625	servers: &[OwnedServerName],
626) -> Result<(federation::membership::prepare_knock_event::v1::Response, OwnedServerName)> {
627	let mut make_knock_response_and_server =
628		Err!(BadServerResponse("No server available to assist in knocking."));
629
630	let mut make_knock_counter: usize = 0;
631
632	for remote_server in servers {
633		if self
634			.services
635			.globals
636			.server_is_ours(remote_server)
637		{
638			continue;
639		}
640
641		info!("Asking {remote_server} for make_knock ({make_knock_counter})");
642
643		let make_knock_response = self
644			.services
645			.federation
646			.execute(remote_server, federation::membership::prepare_knock_event::v1::Request {
647				room_id: room_id.to_owned(),
648				user_id: sender_user.to_owned(),
649				ver: self
650					.services
651					.config
652					.supported_room_versions()
653					.map(at!(0))
654					.collect(),
655			})
656			.await;
657
658		trace!("make_knock response: {make_knock_response:?}");
659		make_knock_counter = make_knock_counter.saturating_add(1);
660
661		make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone()));
662
663		if make_knock_response_and_server.is_ok() {
664			break;
665		}
666
667		if make_knock_counter > 40 {
668			warn!(
669				"50 servers failed to provide valid make_knock response, assuming no server can \
670				 assist in knocking."
671			);
672			make_knock_response_and_server =
673				Err!(BadServerResponse("No server available to assist in knocking."));
674
675			return make_knock_response_and_server;
676		}
677	}
678
679	make_knock_response_and_server
680}