Skip to content

Commit 42effd1

Browse files
author
pyama
committed
Apply rate limiting to countries other than the target country
1 parent 5f3a38b commit 42effd1

File tree

2 files changed

+61
-16
lines changed

2 files changed

+61
-16
lines changed

country_limiter.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ import (
1313
)
1414

1515
type CountryLimiter struct {
16-
db *maxminddb.Reader
17-
countries []string
16+
db *maxminddb.Reader
17+
limitRateForOtherCountries bool
18+
countries []string
1819
BaseLimiter
1920
}
2021

@@ -26,15 +27,17 @@ func NewCountryLimiter(
2627
reqLimit int,
2728
windowLen time.Duration,
2829
targetExtensions []string,
30+
limitRateForOtherCountries bool, // 他の国々にレートリミットを適用する
2931
onRequestLimit func(*rl.Context, string) http.HandlerFunc,
3032
) (*CountryLimiter, error) {
3133
db, err := maxminddb.Open(dbPath)
3234
if err != nil {
3335
return nil, err
3436
}
3537
return &CountryLimiter{
36-
db: db,
37-
countries: countries,
38+
db: db,
39+
countries: countries,
40+
limitRateForOtherCountries: limitRateForOtherCountries,
3841
BaseLimiter: NewBaseLimiter(
3942
reqLimit,
4043
windowLen,
@@ -58,15 +61,31 @@ func (l *CountryLimiter) Rule(r *http.Request) (*rl.Rule, error) {
5861
if err != nil {
5962
return nil, err
6063
}
64+
65+
if country == "" {
66+
return &rl.Rule{ReqLimit: -1}, nil
67+
}
68+
6169
for _, c := range l.countries {
6270
if country == c {
71+
if l.limitRateForOtherCountries {
72+
return &rl.Rule{ReqLimit: -1}, nil
73+
}
6374
return &rl.Rule{
6475
Key: remoteAddr,
6576
ReqLimit: l.reqLimit,
6677
WindowLen: l.windowLen,
6778
}, nil
6879
}
6980
}
81+
82+
if l.limitRateForOtherCountries {
83+
return &rl.Rule{
84+
Key: remoteAddr,
85+
ReqLimit: l.reqLimit,
86+
WindowLen: l.windowLen,
87+
}, nil
88+
}
7089
return &rl.Rule{ReqLimit: -1}, nil
7190
}
7291

country_limiter_test.go

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,16 @@ func testHTTPRequest(remoteAddr string) *http.Request {
1818
func TestCountryLimiter(t *testing.T) {
1919
abspath, _ := filepath.Abs("./testdata/GeoIP2-Country-Test.mmdb")
2020
reqLimit := 10
21-
cl, err := NewCountryLimiter(abspath, []string{"US"}, reqLimit, 1*time.Hour, nil, nil)
22-
if err != nil {
23-
t.Fatal(err)
24-
}
2521

2622
// Define your test cases
2723
testCases := []struct {
28-
name string
29-
request *http.Request
30-
countries []string
31-
expectedCountry string
32-
expectedError bool
33-
allowed bool
24+
name string
25+
request *http.Request
26+
countries []string
27+
limitRateForOtherCountries bool
28+
expectedCountry string
29+
expectedError bool
30+
allowed bool
3431
}{
3532
{
3633
name: "Valid IP from United States With Port",
@@ -47,20 +44,49 @@ func TestCountryLimiter(t *testing.T) {
4744
allowed: true,
4845
expectedError: false,
4946
},
50-
5147
{
5248
name: "Invalid IP format",
5349
request: testHTTPRequest("invalid-ip"),
5450
expectedCountry: "",
5551
allowed: false,
5652
expectedError: true,
5753
},
54+
{
55+
name: "Valid IP from United States With Port and limitRateForOtherCountries,empty country",
56+
request: testHTTPRequest("1.1.1.1"),
57+
expectedCountry: "",
58+
countries: []string{"US"},
59+
limitRateForOtherCountries: true,
60+
allowed: false,
61+
expectedError: false,
62+
},
63+
64+
{
65+
name: "Valid IP from United States With Port and limitRateForOtherCountries,Franch",
66+
request: testHTTPRequest("67.43.156.0"),
67+
expectedCountry: "BT",
68+
countries: []string{"US"},
69+
limitRateForOtherCountries: true,
70+
allowed: true,
71+
expectedError: false,
72+
},
5873
}
5974

6075
for _, tc := range testCases {
6176
t.Run(tc.name, func(t *testing.T) {
6277
// Run the country function to get the ISO country code
63-
78+
cl, err := NewCountryLimiter(
79+
abspath,
80+
[]string{"US"},
81+
reqLimit,
82+
1*time.Hour,
83+
nil,
84+
tc.limitRateForOtherCountries,
85+
nil,
86+
)
87+
if err != nil {
88+
t.Fatal(err)
89+
}
6490
remoteAddr := strings.Split(tc.request.RemoteAddr, ":")[0]
6591
country, err := cl.country(remoteAddr)
6692
if tc.expectedError {

0 commit comments

Comments
 (0)