Skip to content
This repository was archived by the owner on May 27, 2025. It is now read-only.

Commit 1a8b4a1

Browse files
committed
Refactor rate limiting
1 parent 0ee30a2 commit 1a8b4a1

File tree

9 files changed

+437
-133
lines changed

9 files changed

+437
-133
lines changed

.github/workflows/ci.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ jobs:
4040
- name: Check for common mistakes
4141
run: cargo check
4242

43-
- name: Check documentation
44-
run: cargo doc --no-deps --document-private-items
45-
4643
security-audit:
4744
name: Security Audit
4845
runs-on: ubuntu-latest

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ tokio-util = "0.7.12"
1919
reqwest = "0.12.12"
2020
metrics = "0.24.1"
2121
metrics-derive = "0.1"
22+
thiserror = "2.0.11"
23+
serde_json = "1.0.138"
2224

2325
[features]
2426
integration = []

src/client.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use crate::rate_limit::Ticket;
2+
use axum::extract::ws::WebSocket;
3+
use axum::Error;
4+
use std::net::IpAddr;
5+
6+
pub struct ClientConnection {
7+
client_addr: IpAddr,
8+
_ticket: Ticket,
9+
websocket: WebSocket,
10+
}
11+
12+
impl ClientConnection {
13+
pub fn new(client_addr: IpAddr, ticket: Ticket, websocket: WebSocket) -> Self {
14+
Self {
15+
client_addr,
16+
_ticket: ticket,
17+
websocket,
18+
}
19+
}
20+
21+
pub async fn send(&mut self, data: String) -> Result<(), Error> {
22+
self.websocket.send(data.into_bytes().into()).await
23+
}
24+
25+
pub fn id(&self) -> String {
26+
self.client_addr.to_string()
27+
}
28+
}

src/integration.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod test {
22
use crate::metrics::Metrics;
3+
use crate::rate_limit::InMemoryRateLimit;
34
use crate::registry::Registry;
45
use crate::server::Server;
56
use futures::StreamExt;
@@ -36,14 +37,22 @@ mod test {
3637
}
3738
fn new(addr: SocketAddr) -> TestHarness {
3839
let (sender, _) = broadcast::channel(5);
39-
let registry = Registry::new(sender.clone(), 3, Arc::new(Metrics::default()));
40+
let metrics = Arc::new(Metrics::default());
41+
let registry = Registry::new(sender.clone(), metrics.clone());
42+
let rate_limited = Arc::new(InMemoryRateLimit::new(3, 10));
4043

4144
Self {
4245
received_messages: Arc::new(Mutex::new(HashMap::new())),
4346
clients_failed_to_connect: Arc::new(Mutex::new(HashMap::new())),
4447
current_client_id: 0,
4548
cancel_token: CancellationToken::new(),
46-
server: Server::new(addr.into(), registry),
49+
server: Server::new(
50+
addr.into(),
51+
registry,
52+
metrics,
53+
rate_limited,
54+
"header".to_string(),
55+
),
4756
server_addr: addr,
4857
client_id_to_handle: HashMap::new(),
4958
sender,

src/main.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
mod client;
12
#[cfg(all(feature = "integration", test))]
23
mod integration;
34
mod metrics;
5+
mod rate_limit;
46
mod registry;
57
mod server;
68
mod subscriber;
79

810
use crate::metrics::Metrics;
11+
use crate::rate_limit::InMemoryRateLimit;
912
use crate::registry::Registry;
1013
use crate::server::Server;
1114
use crate::subscriber::WebsocketSubscriber;
@@ -49,7 +52,23 @@ struct Args {
4952
default_value = "100",
5053
help = "Maximum number of concurrently connected clients"
5154
)]
52-
maximum_concurrent_connections: usize,
55+
global_connections_limit: usize,
56+
57+
#[arg(
58+
long,
59+
env,
60+
default_value = "10",
61+
help = "Maximum number of concurrently connected clients"
62+
)]
63+
per_ip_connections_limit: usize,
64+
65+
#[arg(
66+
long,
67+
env,
68+
default_value = "X-Forwarded-For",
69+
help = "Header to use to determine the clients origin IP"
70+
)]
71+
ip_addr_http_header: String,
5372

5473
#[arg(long, env, default_value = "info")]
5574
log_level: Level,
@@ -134,9 +153,20 @@ async fn main() {
134153
);
135154
let subscriber_task = subscriber.run(token.clone());
136155

137-
let registry = Registry::new(sender, args.maximum_concurrent_connections, metrics.clone());
156+
let registry = Registry::new(sender, metrics.clone());
157+
158+
let rate_limiter = Arc::new(InMemoryRateLimit::new(
159+
args.global_connections_limit,
160+
args.per_ip_connections_limit,
161+
));
138162

139-
let server = Server::new(args.listen_addr, registry.clone());
163+
let server = Server::new(
164+
args.listen_addr,
165+
registry.clone(),
166+
metrics,
167+
rate_limiter,
168+
args.ip_addr_http_header,
169+
);
140170
let server_task = server.listen(token.clone());
141171

142172
let mut interrupt = signal(SignalKind::interrupt()).unwrap();

0 commit comments

Comments
 (0)