Skip to main content

tuwunel_service/threepid/
ratelimit.rs

1use std::{hash::Hash, net::IpAddr, time::Instant};
2
3use http::StatusCode;
4use ruma::api::error::{ErrorKind, LimitExceededErrorData};
5use tuwunel_core::{Error, Result, implement};
6
7use super::Ratelimiter;
8
/// Tokens restored per second to each requestToken bucket, i.e. one token
/// every five seconds; this bounds the sustained request rate per key.
const RC_PER_SECOND: f64 = 1.0 / 5.0;
/// Capacity of each requestToken bucket: the number of back-to-back requests
/// a fresh (or fully refilled) key may make, generous enough to absorb a real
/// client's retries.
const RC_BURST: f64 = 5.0;
13
/// Entry count at which a bucket table starts pruning: once reached, buckets
/// that have fully refilled are evicted before a new key is admitted, so a
/// spray of distinct keys cannot keep growing the table. The cap is soft —
/// if no bucket is prunable, the insert still proceeds.
const RATELIMIT_MAP_CAP: usize = 65_536;
17
18/// Per-caller-IP requestToken throttle, the axis bounding one source spraying
19/// many addresses.
20#[implement(super::Service)]
21pub fn check_ip_rate_limit(&self, client: IpAddr) -> Result {
22	check_bucket(&self.ip_ratelimiter, client, RC_PER_SECOND, RC_BURST)
23}
24
25/// Per-target-address requestToken throttle, the axis bounding many sources
26/// spraying one address.
27#[implement(super::Service)]
28pub fn check_address_rate_limit(&self, address: &str) -> Result {
29	check_bucket(&self.address_ratelimiter, address.into(), RC_PER_SECOND, RC_BURST)
30}
31
32fn check_bucket<K>(table: &Ratelimiter<K>, key: K, rate: f64, burst: f64) -> Result
33where
34	K: Eq + Hash,
35{
36	let now = Instant::now();
37	let mut buckets = table.lock()?;
38
39	if buckets.len() >= RATELIMIT_MAP_CAP {
40		buckets.retain(|_, bucket| {
41			let (last, toks) = *bucket;
42			now.duration_since(last)
43				.as_secs_f64()
44				.mul_add(rate, toks)
45				< burst
46		});
47	}
48
49	let (last_time, tokens) = buckets.entry(key).or_insert_with(|| (now, burst));
50
51	let new_tokens = now
52		.duration_since(*last_time)
53		.as_secs_f64()
54		.mul_add(rate, *tokens)
55		.min(burst);
56
57	if new_tokens < 1.0 {
58		return Err(Error::Request(
59			ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after: None }),
60			"Too many verification requests.".into(),
61			StatusCode::TOO_MANY_REQUESTS,
62		));
63	}
64
65	*last_time = now;
66	*tokens = new_tokens - 1.0;
67
68	Ok(())
69}