Skip to content

Commit

Permalink
Apply rate limiting to countries other than the target country
Browse files Browse the repository at this point in the history
  • Loading branch information
pyama86 committed Dec 19, 2023
1 parent 5f3a38b commit 42effd1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 16 deletions.
27 changes: 23 additions & 4 deletions country_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import (
)

type CountryLimiter struct {
db *maxminddb.Reader
countries []string
db *maxminddb.Reader
limitRateForOtherCountries bool
countries []string
BaseLimiter
}

Expand All @@ -26,15 +27,17 @@ func NewCountryLimiter(
reqLimit int,
windowLen time.Duration,
targetExtensions []string,
limitRateForOtherCountries bool, // 他の国々にレートリミットを適用する
onRequestLimit func(*rl.Context, string) http.HandlerFunc,
) (*CountryLimiter, error) {
db, err := maxminddb.Open(dbPath)
if err != nil {
return nil, err
}
return &CountryLimiter{
db: db,
countries: countries,
db: db,
countries: countries,
limitRateForOtherCountries: limitRateForOtherCountries,
BaseLimiter: NewBaseLimiter(
reqLimit,
windowLen,
Expand All @@ -58,15 +61,31 @@ func (l *CountryLimiter) Rule(r *http.Request) (*rl.Rule, error) {
if err != nil {
return nil, err
}

if country == "" {
return &rl.Rule{ReqLimit: -1}, nil
}

for _, c := range l.countries {
if country == c {
if l.limitRateForOtherCountries {
return &rl.Rule{ReqLimit: -1}, nil
}
return &rl.Rule{
Key: remoteAddr,
ReqLimit: l.reqLimit,
WindowLen: l.windowLen,
}, nil
}
}

if l.limitRateForOtherCountries {
return &rl.Rule{
Key: remoteAddr,
ReqLimit: l.reqLimit,
WindowLen: l.windowLen,
}, nil
}
return &rl.Rule{ReqLimit: -1}, nil
}

Expand Down
50 changes: 38 additions & 12 deletions country_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@ func testHTTPRequest(remoteAddr string) *http.Request {
func TestCountryLimiter(t *testing.T) {
abspath, _ := filepath.Abs("./testdata/GeoIP2-Country-Test.mmdb")
reqLimit := 10
cl, err := NewCountryLimiter(abspath, []string{"US"}, reqLimit, 1*time.Hour, nil, nil)
if err != nil {
t.Fatal(err)
}

// Define your test cases
testCases := []struct {
name string
request *http.Request
countries []string
expectedCountry string
expectedError bool
allowed bool
name string
request *http.Request
countries []string
limitRateForOtherCountries bool
expectedCountry string
expectedError bool
allowed bool
}{
{
name: "Valid IP from United States With Port",
Expand All @@ -47,20 +44,49 @@ func TestCountryLimiter(t *testing.T) {
allowed: true,
expectedError: false,
},

{
name: "Invalid IP format",
request: testHTTPRequest("invalid-ip"),
expectedCountry: "",
allowed: false,
expectedError: true,
},
{
name: "Valid IP from United States With Port and limitRateForOtherCountries,empty country",
request: testHTTPRequest("1.1.1.1"),
expectedCountry: "",
countries: []string{"US"},
limitRateForOtherCountries: true,
allowed: false,
expectedError: false,
},

{
name: "Valid IP from United States With Port and limitRateForOtherCountries,Franch",
request: testHTTPRequest("67.43.156.0"),
expectedCountry: "BT",
countries: []string{"US"},
limitRateForOtherCountries: true,
allowed: true,
expectedError: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Run the country function to get the ISO country code

cl, err := NewCountryLimiter(
abspath,
[]string{"US"},
reqLimit,
1*time.Hour,
nil,
tc.limitRateForOtherCountries,
nil,
)
if err != nil {
t.Fatal(err)
}
remoteAddr := strings.Split(tc.request.RemoteAddr, ":")[0]
country, err := cl.country(remoteAddr)
if tc.expectedError {
Expand Down

0 comments on commit 42effd1

Please sign in to comment.