1use std::{
2 collections::BTreeMap,
3 time::{Duration, SystemTime, UNIX_EPOCH},
4};
5
6use ruma::{
7 CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedServerName, RoomId, RoomVersionId,
8 ServerName, SigningKeyAlgorithm,
9 api::{
10 error::{ErrorKind, RetryAfter},
11 federation::policy::sign_event::v1 as sign_event,
12 },
13 events::{StateEventType, room::policy::RoomPolicyEventContent},
14 serde::Base64,
15 signatures::{to_canonical_json_string_for_signing, verify_canonical_json_bytes},
16};
17use serde::{Deserialize, Serialize};
18use serde_json::value::to_raw_value;
19use tuwunel_core::{
20 Err, Result, at, debug, implement,
21 matrix::{Event, pdu::into_outgoing_federation, room_version},
22 trace,
23 utils::time::now_secs,
24 warn,
25};
26use tuwunel_database::{Cbor, Deserialized};
27
/// Unstable-prefixed state event type for a room's policy server
/// configuration. It is consulted as a fallback when the stable policy state
/// event yields no usable content, and is also recognised when deciding
/// whether a PDU is itself a policy state event.
const UNSTABLE_POLICY_TYPE: &str = "org.matrix.msc4284.policy";
33
/// Verdict of checking a PDU against the room's policy server.
#[derive(Debug)]
pub enum PolicyCheck {
    /// No policy check applies: policy servers are disabled, the PDU is itself
    /// a policy state event, the room has no usable policy server, or the room
    /// version / its rules / the policy's Ed25519 key could not be resolved.
    NotApplicable,

    /// The policy server signature verified, or the fetch path obtained a
    /// signature or deliberately failed open.
    Pass,

    /// The PDU carries no signature from the policy server.
    Missing,

    /// A policy signature is present but could not be decoded or verified, or
    /// the policy server refused to sign the PDU.
    Invalid,
}
53
/// Result of asking the policy server to sign a PDU.
#[derive(Debug)]
enum FetchOutcome {
    /// The policy server returned a signature; carries the signature string.
    Signed(String),

    /// The request could not be serialized, timed out, or failed with an error
    /// that is neither a refusal nor a rate limit; callers treat this as
    /// "let the event through".
    FailOpen,

    /// The policy server answered with a forbidden error, or its response did
    /// not contain a signature for the policy server.
    Refused,

    /// The policy server rate-limited the request.
    RateLimited {
        /// Seconds since the Unix epoch until which the rate limit applies.
        until_secs: u64,
    },
}
74
/// Per-event policy signing state, stored CBOR-encoded in the
/// `eventid_policysigstate` map keyed by event ID.
#[derive(Debug, Serialize, Deserialize)]
enum PolicySigState {
    /// The policy server refused to sign this event.
    Refused,

    /// The policy server rate-limited the signing request for this event.
    BackoffUntil {
        /// Seconds since the Unix epoch; readers compare this against
        /// `now_secs()` and ignore the entry once it has passed.
        until_secs: u64,
    },
}
88
/// Lenient shape of a policy state event's content, accepting both a
/// `public_keys` map and a singular `public_key` field. It is used to read
/// both the stable and the unstable policy state event types.
#[derive(Deserialize)]
struct UnstablePolicyContent {
    /// Server name of the policy server.
    via: OwnedServerName,

    /// Public keys by signing algorithm; empty when the field is absent.
    #[serde(default)]
    public_keys: BTreeMap<SigningKeyAlgorithm, Base64>,

    /// Optional single key; treated as the Ed25519 key when `public_keys` has
    /// no Ed25519 entry.
    #[serde(default)]
    public_key: Option<Base64>,
}
102
103#[implement(UnstablePolicyContent)]
104fn into_stable(
105 Self { via, mut public_keys, public_key }: Self,
106) -> Option<RoomPolicyEventContent> {
107 if let Some(key) = public_key {
108 public_keys
109 .entry(SigningKeyAlgorithm::Ed25519)
110 .or_insert(key);
111 }
112
113 let ed25519 = public_keys.remove(&SigningKeyAlgorithm::Ed25519)?;
114
115 Some(RoomPolicyEventContent::new(via, ed25519))
116}
117
118#[implement(super::Service)]
119fn cache_policy_refused(&self, event_id: &EventId) {
120 self.db
121 .eventid_policysigstate
122 .raw_put(event_id.as_str(), Cbor(&PolicySigState::Refused));
123}
124
125#[implement(super::Service)]
126fn cache_policy_backoff(&self, event_id: &EventId, until_secs: u64) {
127 self.db
128 .eventid_policysigstate
129 .raw_put(event_id.as_str(), Cbor(&PolicySigState::BackoffUntil { until_secs }));
130}
131
132#[implement(super::Service)]
133async fn cached_policy_state(&self, event_id: &EventId) -> Option<PolicySigState> {
134 self.db
135 .eventid_policysigstate
136 .get(event_id.as_str())
137 .await
138 .deserialized::<Cbor<_>>()
139 .map(at!(0))
140 .ok()
141}
142
143#[implement(super::Service)]
151pub async fn lookup_policy_server(&self, room_id: &RoomId) -> Option<RoomPolicyEventContent> {
152 let read = async |event_type: &StateEventType| {
153 self.services
154 .state_accessor
155 .room_state_get_content::<UnstablePolicyContent>(room_id, event_type, "")
156 .await
157 .ok()
158 .and_then(UnstablePolicyContent::into_stable)
159 };
160
161 let content = match read(&StateEventType::RoomPolicy).await {
162 | Some(content) => content,
163 | None => read(&StateEventType::from(UNSTABLE_POLICY_TYPE.to_owned())).await?,
164 };
165
166 self.services
167 .state_cache
168 .server_in_room(&content.via, room_id)
169 .await
170 .then_some(content)
171}
172
173#[implement(super::Service)]
179#[tracing::instrument(name = "policy_sign", level = "debug", skip_all)]
180pub async fn sign_outgoing_pdu<E>(&self, pdu_json: &mut CanonicalJsonObject, pdu: &E) -> Result
181where
182 E: Event,
183{
184 if !self.services.server.config.enable_policy_servers {
185 return Ok(());
186 }
187
188 if is_policy_state_event(pdu) {
189 return Ok(());
190 }
191
192 let Ok(room_version) = self
193 .services
194 .state
195 .get_room_version(pdu.room_id())
196 .await
197 else {
198 return Ok(());
199 };
200
201 let Some(policy) = self.lookup_policy_server(pdu.room_id()).await else {
202 trace!(room_id = %pdu.room_id(), "no policy server configured");
203 return Ok(());
204 };
205
206 let event_id = pdu.event_id();
207 match self.cached_policy_state(event_id).await {
208 | Some(PolicySigState::Refused) =>
209 return Err!(Request(Forbidden("Event was rejected by the room's policy server."))),
210
211 | Some(PolicySigState::BackoffUntil { until_secs }) if until_secs > now_secs() => {
212 debug!(via = %policy.via, until_secs, "skipping outbound /sign during policy backoff");
213 return Ok(());
214 },
215 | _ => {},
216 }
217
218 match self
219 .fetch_policy_signature(&policy, pdu_json, &room_version)
220 .await
221 {
222 | FetchOutcome::Signed(signature) => {
223 insert_policy_signature(pdu_json, &policy.via, &signature);
224 debug!(via = %policy.via, event_id = %event_id, "folded policy server signature");
225 },
226 | FetchOutcome::Refused => {
227 self.cache_policy_refused(event_id);
228 return Err!(Request(Forbidden("Event was rejected by the room's policy server.")));
229 },
230 | FetchOutcome::RateLimited { until_secs } => {
231 self.cache_policy_backoff(event_id, until_secs);
232 },
233 | FetchOutcome::FailOpen => {},
234 }
235
236 Ok(())
237}
238
239#[implement(super::Service)]
243#[tracing::instrument(
244 name = "policy_fetch",
245 level = "debug",
246 skip_all,
247 fields(via = %policy.via)
248)]
249async fn fetch_policy_signature(
250 &self,
251 policy: &RoomPolicyEventContent,
252 pdu_json: &CanonicalJsonObject,
253 room_version: &RoomVersionId,
254) -> FetchOutcome {
255 let outgoing = into_outgoing_federation(pdu_json.clone(), room_version);
256 let Ok(raw) = to_raw_value(&outgoing) else {
257 warn!(via = %policy.via, "failed to serialize PDU for policy /sign; failing open");
258 return FetchOutcome::FailOpen;
259 };
260
261 let timeout = Duration::from_secs(
262 self.services
263 .server
264 .config
265 .policy_server_request_timeout,
266 );
267
268 let response = match tokio::time::timeout(
269 timeout,
270 self.services
271 .federation
272 .execute(&policy.via, sign_event::Request::new(raw)),
273 )
274 .await
275 {
276 | Ok(Ok(response)) => response,
277 | Ok(Err(error)) if error.kind() == ErrorKind::Forbidden => return FetchOutcome::Refused,
278 | Ok(Err(error)) => {
279 if let Some(until_secs) = parse_rate_limit(&error) {
280 warn!(via = %policy.via, until_secs, "policy server /sign rate-limited");
281 return FetchOutcome::RateLimited { until_secs };
282 }
283 warn!(via = %policy.via, %error, "policy server /sign failed; failing open");
284 return FetchOutcome::FailOpen;
285 },
286 | Err(elapsed) => {
287 warn!(via = %policy.via, %elapsed, "policy server /sign timed out; failing open");
288 return FetchOutcome::FailOpen;
289 },
290 };
291
292 response
294 .ed25519_signature(&policy.via)
295 .map(ToOwned::to_owned)
296 .map_or(FetchOutcome::Refused, FetchOutcome::Signed)
297}
298
299fn parse_rate_limit(error: &tuwunel_core::Error) -> Option<u64> {
300 let ErrorKind::LimitExceeded(data) = error.kind() else {
301 return None;
302 };
303
304 let until = match data.retry_after.as_ref()? {
305 | RetryAfter::Delay(d) => SystemTime::now().checked_add(*d)?,
306 | RetryAfter::DateTime(t) => *t,
307 };
308
309 until
310 .duration_since(UNIX_EPOCH)
311 .ok()
312 .map(|d| d.as_secs())
313}
314
315#[implement(super::Service)]
321#[tracing::instrument(name = "policy_verify_or_fetch", level = "debug", skip_all)]
322pub async fn verify_or_fetch_inbound_policy_signature<E>(
323 &self,
324 pdu_json: &mut CanonicalJsonObject,
325 pdu: &E,
326) -> PolicyCheck
327where
328 E: Event,
329{
330 match self
331 .check_inbound_policy_signature(pdu_json, pdu)
332 .await
333 {
334 | PolicyCheck::Missing =>
335 self.fetch_inbound_policy_signature(pdu_json, pdu)
336 .await,
337 | other => other,
338 }
339}
340
341#[implement(super::Service)]
350#[tracing::instrument(name = "policy_fetch_inbound", level = "debug", skip_all)]
351async fn fetch_inbound_policy_signature<E>(
352 &self,
353 pdu_json: &mut CanonicalJsonObject,
354 pdu: &E,
355) -> PolicyCheck
356where
357 E: Event,
358{
359 let Some(policy) = self.lookup_policy_server(pdu.room_id()).await else {
360 return PolicyCheck::NotApplicable;
361 };
362
363 let Ok(room_version) = self
364 .services
365 .state
366 .get_room_version(pdu.room_id())
367 .await
368 else {
369 return PolicyCheck::NotApplicable;
370 };
371
372 let event_id = pdu.event_id();
373 match self.cached_policy_state(event_id).await {
374 | Some(PolicySigState::Refused) => return PolicyCheck::Invalid,
375 | Some(PolicySigState::BackoffUntil { until_secs }) if until_secs > now_secs() => {
376 debug!(
377 until_secs,
378 via = %policy.via,
379 "policy server in backoff; failing open"
380 );
381
382 return PolicyCheck::Pass;
383 },
384 | _ => {},
385 }
386
387 match self
388 .fetch_policy_signature(&policy, pdu_json, &room_version)
389 .await
390 {
391 | FetchOutcome::Signed(signature) => {
392 debug!(
393 via = %policy.via,
394 event_id = %event_id,
395 "folded inbound policy server signature"
396 );
397
398 insert_policy_signature(pdu_json, &policy.via, &signature);
399 PolicyCheck::Pass
400 },
401 | FetchOutcome::Refused => {
402 debug!(
403 via = %policy.via,
404 event_id = %event_id,
405 "policy server refused to sign inbound PDU; soft-failing"
406 );
407
408 self.cache_policy_refused(event_id);
409 PolicyCheck::Invalid
410 },
411 | FetchOutcome::RateLimited { until_secs } => {
412 self.cache_policy_backoff(event_id, until_secs);
413 PolicyCheck::Pass
414 },
415 | FetchOutcome::FailOpen => PolicyCheck::Pass,
416 }
417}
418
419#[implement(super::Service)]
425#[tracing::instrument(name = "policy_verify", level = "debug", skip_all)]
426pub async fn check_inbound_policy_signature<E>(
427 &self,
428 pdu_json: &CanonicalJsonObject,
429 pdu: &E,
430) -> PolicyCheck
431where
432 E: Event,
433{
434 if !self.services.server.config.enable_policy_servers {
435 return PolicyCheck::NotApplicable;
436 }
437
438 if is_policy_state_event(pdu) {
439 return PolicyCheck::NotApplicable;
440 }
441
442 let Some(policy) = self.lookup_policy_server(pdu.room_id()).await else {
443 return PolicyCheck::NotApplicable;
444 };
445
446 let Ok(room_version) = self
447 .services
448 .state
449 .get_room_version(pdu.room_id())
450 .await
451 else {
452 return PolicyCheck::NotApplicable;
453 };
454
455 let Ok(rules) = room_version::rules(&room_version) else {
456 return PolicyCheck::NotApplicable;
457 };
458
459 let Some(public_key) = policy
461 .public_keys
462 .get(&SigningKeyAlgorithm::Ed25519)
463 else {
464 return PolicyCheck::NotApplicable;
465 };
466
467 let Some(signature_b64) = extract_policy_signature(pdu_json, &policy.via) else {
468 return PolicyCheck::Missing;
469 };
470
471 let Ok(signature) = Base64::<ruma::serde::base64::Standard>::parse(signature_b64) else {
472 return PolicyCheck::Invalid;
473 };
474
475 let Ok(redacted) = ruma::canonical_json::redact(pdu_json.clone(), &rules.redaction, None)
476 else {
477 return PolicyCheck::Invalid;
478 };
479
480 let Ok(canonical) = to_canonical_json_string_for_signing(&redacted) else {
481 return PolicyCheck::Invalid;
482 };
483
484 verify_canonical_json_bytes(
485 &SigningKeyAlgorithm::Ed25519,
486 public_key.as_bytes(),
487 signature.as_bytes(),
488 canonical.as_bytes(),
489 )
490 .map(|()| PolicyCheck::Pass)
491 .unwrap_or_else(|error| {
492 debug!(via = %policy.via, %error, "policy server signature failed verification");
493 PolicyCheck::Invalid
494 })
495}
496
497fn is_policy_state_event<E: Event>(pdu: &E) -> bool {
498 if pdu.state_key() != Some("") {
499 return false;
500 }
501
502 let kind = pdu.kind().to_cow_str();
503
504 kind == "m.room.policy" || kind == UNSTABLE_POLICY_TYPE
505}
506
507fn extract_policy_signature<'a>(
508 pdu_json: &'a CanonicalJsonObject,
509 via: &ServerName,
510) -> Option<&'a str> {
511 let CanonicalJsonValue::Object(server_map) = pdu_json.get("signatures")? else {
512 return None;
513 };
514
515 let CanonicalJsonValue::Object(key_map) = server_map.get(via.as_str())? else {
516 return None;
517 };
518
519 let CanonicalJsonValue::String(signature) =
520 key_map.get(RoomPolicyEventContent::POLICY_SERVER_ED25519_SIGNING_KEY_ID)?
521 else {
522 return None;
523 };
524
525 Some(signature.as_str())
526}
527
528fn insert_policy_signature(
529 pdu_json: &mut CanonicalJsonObject,
530 via: &ServerName,
531 signature: &str,
532) {
533 let signatures = pdu_json
534 .entry("signatures".into())
535 .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::new()));
536
537 let CanonicalJsonValue::Object(server_map) = signatures else {
538 return;
539 };
540
541 let entry = server_map
542 .entry(via.as_str().into())
543 .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::new()));
544
545 if let CanonicalJsonValue::Object(key_map) = entry {
546 key_map.insert(
547 RoomPolicyEventContent::POLICY_SERVER_ED25519_SIGNING_KEY_ID.into(),
548 CanonicalJsonValue::String(signature.to_owned()),
549 );
550 }
551}