Skip to content

Commit c5ca079

Browse files
committed
refactor: [torrust#1096] extract BanService
1 parent cde6e26 commit c5ca079

File tree

5 files changed

+170
-38
lines changed

5 files changed

+170
-38
lines changed

src/servers/udp/handlers.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ use aquatic_udp_protocol::{
1111
ResponsePeer, ScrapeRequest, ScrapeResponse, TorrentScrapeStatistics, TransactionId,
1212
};
1313
use bittorrent_primitives::info_hash::InfoHash;
14-
use bloom::{CountingBloomFilter, ASMS};
1514
use tokio::sync::RwLock;
1615
use torrust_tracker_clock::clock::Time as _;
1716
use tracing::{instrument, Level};
1817
use uuid::Uuid;
1918
use zerocopy::network_endian::I32;
2019

2120
use super::connection_cookie::{check, make};
21+
use super::server::banning::BanService;
2222
use super::RawRequest;
2323
use crate::core::{statistics, PeersWanted, Tracker};
2424
use crate::servers::udp::error::Error;
@@ -53,13 +53,13 @@ impl CookieTimeValues {
5353
/// - Delegating the request to the correct handler depending on the request type.
5454
///
5555
/// It will return an `Error` response if the request is invalid.
56-
#[instrument(fields(request_id), skip(udp_request, tracker, cookie_time_values, connection_id_errors_per_ip), ret(level = Level::TRACE))]
56+
#[instrument(fields(request_id), skip(udp_request, tracker, cookie_time_values, ban_service), ret(level = Level::TRACE))]
5757
pub(crate) async fn handle_packet(
5858
udp_request: RawRequest,
5959
tracker: &Tracker,
6060
local_addr: SocketAddr,
6161
cookie_time_values: CookieTimeValues,
62-
connection_id_errors_per_ip: Arc<RwLock<CountingBloomFilter>>,
62+
ban_service: Arc<RwLock<BanService>>,
6363
) -> Response {
6464
tracing::Span::current().record("request_id", Uuid::new_v4().to_string());
6565
tracing::debug!("Handling Packets: {udp_request:?}");
@@ -76,10 +76,8 @@ pub(crate) async fn handle_packet(
7676
| Error::CookieValueExpired { .. }
7777
| Error::CookieValueFromFuture { .. } => {
7878
// code-review: should we include `RequestParseError` and `BadRequest`?
79-
connection_id_errors_per_ip
80-
.write()
81-
.await
82-
.insert(&udp_request.from.ip().to_string());
79+
let mut ban_service = ban_service.write().await;
80+
ban_service.increase_counter(&udp_request.from.ip());
8381
}
8482
_ => {}
8583
}

src/servers/udp/server/banning.rs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
use std::net::IpAddr;
2+
use std::time::Duration;
3+
4+
use bloom::{CountingBloomFilter, ASMS};
5+
use tokio::time::Instant;
6+
use url::Url;
7+
8+
use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
9+
10+
/// The maximum number of connection id errors per ip. Clients will be banned if
11+
/// they exceed this limit.
12+
pub const MAX_CONNECTION_ID_ERRORS_PER_IP: u32 = 10;
13+
pub const RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS: u64 = 3600;
14+
15+
pub struct BanService {
16+
max_connection_id_errors_per_ip: u32,
17+
ban_duration: Duration,
18+
connection_id_errors_per_ip: CountingBloomFilter,
19+
local_addr: Url,
20+
last_connection_id_errors_reset: Instant,
21+
}
22+
23+
impl BanService {
24+
#[must_use]
25+
pub fn new(max_connection_id_errors_per_ip: u32, duration_in_seconds: u64, local_addr: Url) -> Self {
26+
Self {
27+
max_connection_id_errors_per_ip,
28+
ban_duration: Duration::from_secs(duration_in_seconds),
29+
local_addr,
30+
connection_id_errors_per_ip: CountingBloomFilter::with_rate(4, 0.01, 100),
31+
last_connection_id_errors_reset: tokio::time::Instant::now(),
32+
}
33+
}
34+
35+
pub fn increase_counter(&mut self, ip: &IpAddr) {
36+
self.connection_id_errors_per_ip.insert(&ip.to_string());
37+
}
38+
39+
pub fn get_counter(&mut self, ip: &IpAddr) -> u32 {
40+
self.connection_id_errors_per_ip.estimate_count(&ip.to_string())
41+
}
42+
43+
/// Returns true if the given ip address is banned.
44+
#[must_use]
45+
pub fn is_banned(&self, ip: &IpAddr) -> bool {
46+
let connection_id_errors_from_ip = self.connection_id_errors_per_ip.estimate_count(&ip.to_string());
47+
48+
connection_id_errors_from_ip > self.max_connection_id_errors_per_ip
49+
}
50+
51+
pub fn run_bans_cleaner(&mut self) {
52+
if self.last_connection_id_errors_reset.elapsed() >= self.ban_duration {
53+
self.reset_filter();
54+
}
55+
}
56+
57+
/// Resets the filter and updates the reset timestamp.
58+
pub fn reset_filter(&mut self) {
59+
self.connection_id_errors_per_ip.clear();
60+
61+
self.last_connection_id_errors_reset = Instant::now();
62+
63+
let local_addr = self.local_addr.to_string();
64+
tracing::info!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop (connection id errors filter cleared)");
65+
}
66+
}
67+
68+
#[cfg(test)]
69+
mod tests {
70+
use std::net::IpAddr;
71+
use std::time::Duration;
72+
73+
use tokio::time::sleep;
74+
75+
use super::BanService;
76+
77+
/// Sample service with one day ban duration.
78+
fn service_with_one_day_ban(counter_limit: u32) -> BanService {
79+
let one_day_in_seconds = 86400;
80+
let udp_tracker_url = "udp://127.0.0.1".parse().unwrap();
81+
82+
BanService::new(counter_limit, one_day_in_seconds, udp_tracker_url)
83+
}
84+
85+
#[test]
86+
fn it_should_increase_the_ip_counter() {
87+
let mut ban_service = service_with_one_day_ban(1);
88+
89+
let ip: IpAddr = "127.0.0.2".parse().unwrap();
90+
91+
ban_service.increase_counter(&ip);
92+
93+
assert_eq!(ban_service.get_counter(&ip), 1);
94+
}
95+
96+
#[test]
97+
fn it_should_ban_ips_with_counters_exceeding_a_predefined_limit() {
98+
let mut ban_service = service_with_one_day_ban(1);
99+
100+
let ip: IpAddr = "127.0.0.2".parse().unwrap();
101+
102+
ban_service.increase_counter(&ip); // Counter = 1
103+
ban_service.increase_counter(&ip); // Counter = 2
104+
105+
assert!(ban_service.is_banned(&ip));
106+
}
107+
108+
#[test]
109+
fn it_should_not_ban_ips_whose_counters_do_not_exceed_the_predefined_limit() {
110+
let mut ban_service = service_with_one_day_ban(1);
111+
112+
let ip: IpAddr = "127.0.0.2".parse().unwrap();
113+
114+
ban_service.increase_counter(&ip);
115+
116+
assert!(!ban_service.is_banned(&ip));
117+
}
118+
119+
#[test]
120+
fn it_should_allow_resetting_all_the_counters() {
121+
let mut ban_service = service_with_one_day_ban(1);
122+
123+
let ip: IpAddr = "127.0.0.2".parse().unwrap();
124+
125+
ban_service.increase_counter(&ip); // Counter = 1
126+
127+
ban_service.reset_filter();
128+
129+
assert_eq!(ban_service.get_counter(&ip), 0);
130+
}
131+
132+
#[tokio::test]
133+
async fn it_should_allow_run_a_bans_cleaner_to_reset_the_counters_periodically() {
134+
let udp_tracker_url = "udp://127.0.0.1".parse().unwrap();
135+
let ban_duration_in_secs = 1;
136+
137+
let mut ban_service = BanService::new(1, ban_duration_in_secs, udp_tracker_url);
138+
139+
let ip: IpAddr = "127.0.0.2".parse().unwrap();
140+
141+
ban_service.increase_counter(&ip); // Counter = 1
142+
143+
sleep(Duration::from_secs(2)).await;
144+
145+
ban_service.run_bans_cleaner();
146+
147+
assert_eq!(ban_service.get_counter(&ip), 0);
148+
}
149+
}

src/servers/udp/server/launcher.rs

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ use std::sync::Arc;
33
use std::time::Duration;
44

55
use bittorrent_tracker_client::udp::client::check;
6-
use bloom::{CountingBloomFilter, ASMS};
76
use derive_more::Constructor;
87
use futures_util::StreamExt;
98
use tokio::select;
109
use tokio::sync::{oneshot, RwLock};
1110
use tracing::instrument;
1211

12+
use super::banning::{BanService, MAX_CONNECTION_ID_ERRORS_PER_IP, RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS};
1313
use super::request_buffer::ActiveRequests;
1414
use crate::bootstrap::jobs::Started;
1515
use crate::core::{statistics, Tracker};
@@ -21,11 +21,6 @@ use crate::servers::udp::server::processor::Processor;
2121
use crate::servers::udp::server::receiver::Receiver;
2222
use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
2323

24-
/// The maximum number of connection id errors per ip. Clients will be banned if
25-
/// they exceed this limit.
26-
const MAX_CONNECTION_ID_ERRORS_PER_IP: u32 = 10;
27-
const RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS: u64 = 3600;
28-
2924
/// A UDP server instance launcher.
3025
#[derive(Constructor)]
3126
pub struct Launcher;
@@ -120,27 +115,21 @@ impl Launcher {
120115
async fn run_udp_server_main(mut receiver: Receiver, tracker: Arc<Tracker>, cookie_lifetime: Duration) {
121116
let active_requests = &mut ActiveRequests::default();
122117

123-
// Create a counting bloom filter that uses 4 bits per element and has a
124-
// false positive rate of 0.01 when 100 items have been inserted
125-
let connection_id_errors_per_ip = Arc::new(RwLock::new(CountingBloomFilter::with_rate(4, 0.01, 100)));
126-
127-
// Timer to track when to clear the filter
128-
let mut last_connection_id_errors_reset = tokio::time::Instant::now();
129-
130118
let addr = receiver.bound_socket_address();
131119

132120
let local_addr = format!("udp://{addr}");
133121

134122
let cookie_lifetime = cookie_lifetime.as_secs_f64();
135123

124+
let ban_service = Arc::new(RwLock::new(BanService::new(
125+
MAX_CONNECTION_ID_ERRORS_PER_IP,
126+
RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS,
127+
local_addr.parse().unwrap(),
128+
)));
129+
136130
loop {
137-
if last_connection_id_errors_reset.elapsed()
138-
>= Duration::from_secs(RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS)
139-
{
140-
connection_id_errors_per_ip.write().await.clear();
141-
tracing::info!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop (connection id errors filter cleared)");
142-
last_connection_id_errors_reset = tokio::time::Instant::now();
143-
}
131+
// code-review: the ban service could spawn a task to clear the bans.
132+
ban_service.write().await.run_bans_cleaner();
144133

145134
let processor = Processor::new(receiver.socket.clone(), tracker.clone(), cookie_lifetime);
146135

@@ -171,12 +160,7 @@ impl Launcher {
171160
}
172161
}
173162

174-
let connection_id_errors_from_ip = connection_id_errors_per_ip
175-
.read()
176-
.await
177-
.estimate_count(&req.from.ip().to_string());
178-
179-
if connection_id_errors_from_ip > MAX_CONNECTION_ID_ERRORS_PER_IP {
163+
if ban_service.read().await.is_banned(&req.from.ip()) {
180164
tracing::debug!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop continue: (banned ip)");
181165
continue;
182166
}
@@ -193,7 +177,7 @@ impl Launcher {
193177
// chance to finish. However, the buffer is yielding before
194178
// aborting one tasks, giving it the chance to finish.
195179
let abort_handle: tokio::task::AbortHandle =
196-
tokio::task::spawn(processor.process_request(req, connection_id_errors_per_ip.clone())).abort_handle();
180+
tokio::task::spawn(processor.process_request(req, ban_service.clone())).abort_handle();
197181

198182
if abort_handle.is_finished() {
199183
continue;

src/servers/udp/server/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use thiserror::Error;
66

77
use super::RawRequest;
88

9+
pub mod banning;
910
pub mod bound_socket;
1011
pub mod launcher;
1112
pub mod processor;

src/servers/udp/server/processor.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ use std::net::{IpAddr, SocketAddr};
33
use std::sync::Arc;
44

55
use aquatic_udp_protocol::Response;
6-
use bloom::CountingBloomFilter;
76
use tokio::sync::RwLock;
87
use tracing::{instrument, Level};
98

9+
use super::banning::BanService;
1010
use super::bound_socket::BoundSocket;
1111
use crate::core::{statistics, Tracker};
1212
use crate::servers::udp::handlers::CookieTimeValues;
@@ -27,15 +27,15 @@ impl Processor {
2727
}
2828
}
2929

30-
#[instrument(skip(self, request, connection_id_errors_per_ip))]
31-
pub async fn process_request(self, request: RawRequest, connection_id_errors_per_ip: Arc<RwLock<CountingBloomFilter>>) {
30+
#[instrument(skip(self, request, ban_service))]
31+
pub async fn process_request(self, request: RawRequest, ban_service: Arc<RwLock<BanService>>) {
3232
let from = request.from;
3333
let response = handlers::handle_packet(
3434
request,
3535
&self.tracker,
3636
self.socket.address(),
3737
CookieTimeValues::new(self.cookie_lifetime),
38-
connection_id_errors_per_ip,
38+
ban_service,
3939
)
4040
.await;
4141

0 commit comments

Comments
 (0)