1use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc};
2
3use futures::{FutureExt, StreamExt};
4use ruma::{
5 CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, RoomId,
6 RoomOrAliasId, RoomVersionId, UserId,
7 api::federation::{self, membership::RawStrippedState},
8 canonical_json::to_canonical_value,
9 events::{
10 StateEventType,
11 room::member::{MembershipState, RoomMemberEventContent},
12 },
13};
14use tuwunel_core::{
15 Err, Event, PduCount, Result, at, debug, debug_info, debug_warn, err, implement, info,
16 matrix::event::gen_event_id,
17 pdu::{PduBuilder, PduEvent},
18 trace, utils, warn,
19};
20
21use super::{
22 Service, StrippedCreateVerdict, enforce_stripped_create, into_client_stripped, v12_room_ids,
23};
24use crate::{
25 membership::join::get_servers_for_room,
26 rooms::{
27 state::RoomMutexGuard,
28 state_compressor::{CompressedState, HashSetCompressStateEvent},
29 },
30};
31
32#[implement(Service)]
33#[tracing::instrument(
34 level = "debug",
35 skip_all,
36 fields(%sender_user, %room_id)
37)]
38pub async fn knock(
39 &self,
40 sender_user: &UserId,
41 room_id: &RoomId,
42 orig_server_name: Option<&RoomOrAliasId>,
43 reason: Option<String>,
44 servers: &[OwnedServerName],
45 state_lock: &RoomMutexGuard,
46) -> Result {
47 let servers =
48 get_servers_for_room(&self.services, sender_user, room_id, orig_server_name, servers)
49 .await?;
50
51 if self
52 .services
53 .state_cache
54 .is_invited(sender_user, room_id)
55 .await
56 {
57 debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock");
58 return Err!(Request(Forbidden(
59 "You cannot knock on a room you are already invited/accepted to."
60 )));
61 }
62
63 if self
64 .services
65 .state_cache
66 .is_joined(sender_user, room_id)
67 .await
68 {
69 debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock");
70 return Err!(Request(Forbidden("You cannot knock on a room you are already joined in.")));
71 }
72
73 let server_in_room = self
74 .services
75 .state_cache
76 .server_in_room(self.services.globals.server_name(), room_id)
77 .await;
78
79 if server_in_room
81 && self
82 .services
83 .state_cache
84 .is_knocked(sender_user, room_id)
85 .await
86 {
87 debug_warn!("{sender_user} is already knocked in {room_id}");
88 return Ok(());
89 }
90
91 if let Ok(membership) = self
92 .services
93 .state_accessor
94 .get_member(room_id, sender_user)
95 .await && membership.membership == MembershipState::Ban
96 {
97 debug_warn!("{sender_user} is banned from {room_id} but attempted to knock");
98 return Err!(Request(Forbidden("You cannot knock on a room you are banned from.")));
99 }
100
101 let local_knock = server_in_room
102 || servers.is_empty()
103 || (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]));
104
105 if local_knock {
106 self.knock_room_helper_local(sender_user, room_id, reason, &servers, state_lock)
107 .boxed()
108 .await
109 } else {
110 self.knock_room_helper_remote(sender_user, room_id, reason, &servers, state_lock)
111 .boxed()
112 .await
113 }
114}
115
116#[implement(Service)]
117async fn knock_room_helper_local(
118 &self,
119 sender_user: &UserId,
120 room_id: &RoomId,
121 reason: Option<String>,
122 servers: &[OwnedServerName],
123 state_lock: &RoomMutexGuard,
124) -> Result {
125 debug_info!("We can knock locally");
126
127 let room_version_id = self
128 .services
129 .state
130 .get_room_version(room_id)
131 .await?;
132
133 ensure_room_version_supports_knock(&room_version_id)?;
134
135 let mut content = RoomMemberEventContent {
136 reason: reason.clone(),
137 ..RoomMemberEventContent::new(MembershipState::Knock)
138 };
139
140 self.services
141 .profile
142 .fill_profile_data(sender_user, &mut content)
143 .await;
144
145 let Err(error) = self
146 .services
147 .timeline
148 .build_and_append_pdu(
149 PduBuilder::state(sender_user.to_string(), &content),
150 sender_user,
151 room_id,
152 state_lock,
153 )
154 .await
155 else {
156 return Ok(());
157 };
158
159 if servers.is_empty()
160 || (servers.len() == 1 && self.services.globals.server_is_ours(&servers[0]))
161 {
162 return Err(error);
163 }
164
165 warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock");
166
167 self.knock_room_local_federation_fallback(sender_user, room_id, reason, servers, state_lock)
168 .boxed()
169 .await
170}
171
172fn ensure_room_version_supports_knock(room_version_id: &RoomVersionId) -> Result {
173 if matches!(
174 room_version_id,
175 RoomVersionId::V1
176 | RoomVersionId::V2
177 | RoomVersionId::V3
178 | RoomVersionId::V4
179 | RoomVersionId::V5
180 | RoomVersionId::V6
181 ) {
182 return Err!(Request(Forbidden("This room does not support knocking.")));
183 }
184
185 Ok(())
186}
187
188#[implement(Service)]
189async fn knock_room_local_federation_fallback(
190 &self,
191 sender_user: &UserId,
192 room_id: &RoomId,
193 reason: Option<String>,
194 servers: &[OwnedServerName],
195 state_lock: &RoomMutexGuard,
196) -> Result {
197 let (make_knock_response, remote_server) = self
198 .make_knock_request(sender_user, room_id, servers)
199 .await?;
200
201 info!("make_knock finished");
202
203 let room_version_id = make_knock_response.room_version.clone();
204
205 if !self
206 .services
207 .config
208 .supported_room_version(&room_version_id)
209 {
210 return Err!(BadServerResponse(
211 "Remote room version {room_version_id} is not supported by tuwunel"
212 ));
213 }
214
215 let (knock_event, event_id) = self
216 .build_knock_event(sender_user, room_id, reason, &make_knock_response, &room_version_id)
217 .await?;
218
219 let send_knock_response = self
220 .execute_send_knock(&remote_server, room_id, &event_id, &knock_event, &room_version_id)
221 .await?;
222
223 self.services
224 .short
225 .get_or_create_shortroomid(room_id)
226 .await;
227
228 self.finalize_knock_membership(
229 room_id,
230 sender_user,
231 &event_id,
232 knock_event,
233 send_knock_response,
234 state_lock,
235 )
236 .await
237}
238
239#[implement(Service)]
240async fn finalize_knock_membership(
241 &self,
242 room_id: &RoomId,
243 sender_user: &UserId,
244 event_id: &OwnedEventId,
245 knock_event: CanonicalJsonObject,
246 send_knock_response: federation::membership::create_knock_event::v1::Response,
247 state_lock: &RoomMutexGuard,
248) -> Result {
249 info!("Parsing knock event");
250 let parsed_knock_pdu = PduEvent::from_object_and_eventid(event_id, knock_event.clone())
251 .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
252
253 info!("Updating membership locally to knock state with provided stripped state events");
254 let count = self.services.globals.next_count();
255 self.services
256 .state_cache
257 .update_membership(
258 room_id,
259 sender_user,
260 parsed_knock_pdu
261 .get_content::<RoomMemberEventContent>()
262 .expect("we just created this"),
263 sender_user,
264 Some(
265 send_knock_response
266 .knock_room_state
267 .into_iter()
268 .filter_map(|state| into_client_stripped(room_id, state))
269 .collect(),
270 ),
271 None,
272 false,
273 PduCount::Normal(*count),
274 )
275 .await?;
276
277 info!("Appending room knock event locally");
278 self.services
279 .timeline
280 .append_pdu(
281 &parsed_knock_pdu,
282 knock_event,
283 once(parsed_knock_pdu.event_id.borrow()),
284 state_lock,
285 )
286 .await?;
287
288 Ok(())
289}
290
291#[implement(Service)]
292async fn knock_room_helper_remote(
293 &self,
294 sender_user: &UserId,
295 room_id: &RoomId,
296 reason: Option<String>,
297 servers: &[OwnedServerName],
298 state_lock: &RoomMutexGuard,
299) -> Result {
300 info!("Knocking {room_id} over federation.");
301
302 let (make_knock_response, remote_server) = self
303 .make_knock_request(sender_user, room_id, servers)
304 .await?;
305
306 info!("make_knock finished");
307
308 let room_version_id = make_knock_response.room_version.clone();
309
310 if !self
311 .services
312 .config
313 .supported_room_version(&room_version_id)
314 {
315 return Err!(BadServerResponse(
316 "Remote room version {room_version_id} is not supported by tuwunel"
317 ));
318 }
319
320 let (knock_event, event_id) = self
321 .build_knock_event(sender_user, room_id, reason, &make_knock_response, &room_version_id)
322 .await?;
323
324 let send_knock_response = self
325 .execute_send_knock(&remote_server, room_id, &event_id, &knock_event, &room_version_id)
326 .await?;
327
328 self.services
329 .short
330 .get_or_create_shortroomid(room_id)
331 .await;
332
333 info!("Parsing knock event");
334 let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
335 .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
336
337 let state_map = self
338 .ingest_send_knock_state(room_id, &send_knock_response, &room_version_id)
339 .await?;
340
341 self.apply_send_knock_state(room_id, &state_map, state_lock)
342 .await?;
343
344 let statehash_after_knock = self
345 .services
346 .state
347 .append_to_state(&parsed_knock_pdu)
348 .await?;
349
350 info!("Updating membership locally to knock state with provided stripped state events");
351 let count = self.services.globals.next_count();
352 self.services
353 .state_cache
354 .update_membership(
355 room_id,
356 sender_user,
357 parsed_knock_pdu
358 .get_content::<RoomMemberEventContent>()
359 .expect("we just created this"),
360 sender_user,
361 Some(
362 send_knock_response
363 .knock_room_state
364 .into_iter()
365 .filter_map(|state| into_client_stripped(room_id, state))
366 .collect(),
367 ),
368 None,
369 false,
370 PduCount::Normal(*count),
371 )
372 .await?;
373
374 info!("Appending room knock event locally");
375 self.services
376 .timeline
377 .append_pdu(
378 &parsed_knock_pdu,
379 knock_event,
380 once(parsed_knock_pdu.event_id.borrow()),
381 state_lock,
382 )
383 .await?;
384
385 info!("Setting final room state for new room");
386 self.services
389 .state
390 .set_room_state(room_id, statehash_after_knock, state_lock);
391
392 Ok(())
393}
394
395#[implement(Service)]
396async fn build_knock_event(
397 &self,
398 sender_user: &UserId,
399 room_id: &RoomId,
400 reason: Option<String>,
401 make_knock_response: &federation::membership::prepare_knock_event::v1::Response,
402 room_version_id: &RoomVersionId,
403) -> Result<(CanonicalJsonObject, OwnedEventId)> {
404 let mut knock_event_stub: CanonicalJsonObject =
405 serde_json::from_str(make_knock_response.event.get()).map_err(|e| {
406 err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}"))
407 })?;
408
409 let mut content = RoomMemberEventContent {
410 reason,
411 ..RoomMemberEventContent::new(MembershipState::Knock)
412 };
413
414 self.services
415 .profile
416 .fill_profile_data(sender_user, &mut content)
417 .await;
418
419 knock_event_stub.insert(
420 "origin".into(),
421 CanonicalJsonValue::String(
422 self.services
423 .globals
424 .server_name()
425 .as_str()
426 .to_owned(),
427 ),
428 );
429 knock_event_stub.insert(
430 "origin_server_ts".into(),
431 CanonicalJsonValue::Integer(
432 utils::millis_since_unix_epoch()
433 .try_into()
434 .expect("Timestamp is valid js_int value"),
435 ),
436 );
437 knock_event_stub.insert(
438 "content".into(),
439 to_canonical_value(content).expect("event is valid, we just created it"),
440 );
441
442 knock_event_stub
443 .insert("room_id".into(), CanonicalJsonValue::String(room_id.as_str().into()));
444
445 knock_event_stub
446 .insert("state_key".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
447
448 knock_event_stub
449 .insert("sender".into(), CanonicalJsonValue::String(sender_user.as_str().into()));
450
451 knock_event_stub.insert("type".into(), CanonicalJsonValue::String("m.room.member".into()));
452
453 self.services
456 .server_keys
457 .hash_and_sign_event(&mut knock_event_stub, room_version_id)?;
458
459 let event_id = gen_event_id(&knock_event_stub, room_version_id)?;
460
461 knock_event_stub
462 .insert("event_id".into(), CanonicalJsonValue::String(event_id.clone().into()));
463
464 Ok((knock_event_stub, event_id))
465}
466
467#[implement(Service)]
468async fn execute_send_knock(
469 &self,
470 remote_server: &OwnedServerName,
471 room_id: &RoomId,
472 event_id: &OwnedEventId,
473 knock_event: &CanonicalJsonObject,
474 room_version_id: &RoomVersionId,
475) -> Result<federation::membership::create_knock_event::v1::Response> {
476 info!("Asking {remote_server} for send_knock in room {room_id}");
477 let send_knock_request = federation::membership::create_knock_event::v1::Request {
478 room_id: room_id.to_owned(),
479 event_id: event_id.clone(),
480 pdu: self
481 .services
482 .federation
483 .format_pdu_into(knock_event.clone(), Some(room_version_id))
484 .await,
485 };
486
487 let response = self
488 .services
489 .federation
490 .execute(remote_server, send_knock_request)
491 .await?;
492
493 info!("send_knock finished");
494 Ok(response)
495}
496
497#[implement(Service)]
498#[expect(
499 deprecated,
500 reason = "Matrix 1.16 still permits receiving the legacy stripped variant for backwards \
501 compatibility."
502)]
503async fn ingest_send_knock_state(
504 &self,
505 room_id: &RoomId,
506 send_knock_response: &federation::membership::create_knock_event::v1::Response,
507 room_version_id: &RoomVersionId,
508) -> Result<HashMap<u64, OwnedEventId>> {
509 info!("Going through send_knock response knock state events");
510
511 let verdict = self
512 .validate_stripped_create(&send_knock_response.knock_room_state, room_id, room_version_id)
513 .await?;
514
515 let enforce = self
516 .services
517 .config
518 .enforce_stripped_state_pdu_validation;
519
520 let drop_create = enforce_stripped_create(verdict, v12_room_ids(room_version_id), enforce);
521
522 if verdict != StrippedCreateVerdict::Valid {
523 debug_warn!(?verdict, %room_id, drop_create, "MSC4311 knock create-event validation failed");
524 }
525
526 let state = send_knock_response
527 .knock_room_state
528 .iter()
529 .filter_map(|event| match event {
530 | RawStrippedState::Pdu(raw) =>
531 serde_json::from_str::<CanonicalJsonObject>(raw.get()).ok(),
532 | RawStrippedState::Stripped(raw) =>
533 serde_json::from_str::<CanonicalJsonObject>(raw.json().get()).ok(),
534 });
535
536 let mut state_map: HashMap<u64, OwnedEventId> = HashMap::new();
537
538 for event in state {
539 let Some(state_key) = event.get("state_key") else {
540 debug_warn!("send_knock stripped state event missing state_key: {event:?}");
541 continue;
542 };
543 let Some(event_type) = event.get("type") else {
544 debug_warn!("send_knock stripped state event missing event type: {event:?}");
545 continue;
546 };
547
548 let Ok(state_key) = serde_json::from_value::<String>(state_key.clone().into()) else {
549 debug_warn!("send_knock stripped state event has invalid state_key: {event:?}");
550 continue;
551 };
552 let Ok(event_type) = serde_json::from_value::<StateEventType>(event_type.clone().into())
553 else {
554 debug_warn!("send_knock stripped state event has invalid event type: {event:?}");
555 continue;
556 };
557
558 if drop_create && event_type == StateEventType::RoomCreate && state_key.is_empty() {
560 debug_warn!(%room_id, "dropping unvalidated create event from knock state");
561 continue;
562 }
563
564 let event_id = gen_event_id(&event, room_version_id)?;
565 let shortstatekey = self
566 .services
567 .short
568 .get_or_create_shortstatekey(&event_type, &state_key)
569 .await;
570
571 self.services
572 .timeline
573 .add_pdu_outlier(&event_id, &event);
574
575 state_map.insert(shortstatekey, event_id.clone());
576 }
577
578 Ok(state_map)
579}
580
581#[implement(Service)]
582async fn apply_send_knock_state(
583 &self,
584 room_id: &RoomId,
585 state_map: &HashMap<u64, OwnedEventId>,
586 state_lock: &RoomMutexGuard,
587) -> Result {
588 info!("Compressing state from send_knock");
589 let compressed: CompressedState = self
590 .services
591 .state_compressor
592 .compress_state_events(
593 state_map
594 .iter()
595 .map(|(ssk, eid)| (ssk, eid.borrow())),
596 )
597 .collect()
598 .await;
599
600 debug!("Saving compressed state");
601 let HashSetCompressStateEvent {
602 shortstatehash: statehash_before_knock,
603 added,
604 removed,
605 } = self
606 .services
607 .state_compressor
608 .save_state(room_id, Arc::new(compressed))
609 .await?;
610
611 debug!("Forcing state for new room");
612 self.services
613 .state
614 .force_state(room_id, statehash_before_knock, added, removed, state_lock)
615 .await?;
616
617 Ok(())
618}
619
620#[implement(Service)]
621async fn make_knock_request(
622 &self,
623 sender_user: &UserId,
624 room_id: &RoomId,
625 servers: &[OwnedServerName],
626) -> Result<(federation::membership::prepare_knock_event::v1::Response, OwnedServerName)> {
627 let mut make_knock_response_and_server =
628 Err!(BadServerResponse("No server available to assist in knocking."));
629
630 let mut make_knock_counter: usize = 0;
631
632 for remote_server in servers {
633 if self
634 .services
635 .globals
636 .server_is_ours(remote_server)
637 {
638 continue;
639 }
640
641 info!("Asking {remote_server} for make_knock ({make_knock_counter})");
642
643 let make_knock_response = self
644 .services
645 .federation
646 .execute(remote_server, federation::membership::prepare_knock_event::v1::Request {
647 room_id: room_id.to_owned(),
648 user_id: sender_user.to_owned(),
649 ver: self
650 .services
651 .config
652 .supported_room_versions()
653 .map(at!(0))
654 .collect(),
655 })
656 .await;
657
658 trace!("make_knock response: {make_knock_response:?}");
659 make_knock_counter = make_knock_counter.saturating_add(1);
660
661 make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone()));
662
663 if make_knock_response_and_server.is_ok() {
664 break;
665 }
666
667 if make_knock_counter > 40 {
668 warn!(
669 "50 servers failed to provide valid make_knock response, assuming no server can \
670 assist in knocking."
671 );
672 make_knock_response_and_server =
673 Err!(BadServerResponse("No server available to assist in knocking."));
674
675 return make_knock_response_and_server;
676 }
677 }
678
679 make_knock_response_and_server
680}