Skip to content

Commit 5ea34d2

Browse files
committed
Close #3: Add support for rate limiting
1 parent d37ad63 commit 5ea34d2

File tree

5 files changed

+178
-4
lines changed

5 files changed

+178
-4
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,9 @@ have the `backup` permission:
166166
```go
167167
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
168168
```
169+
170+
## Rate limiting
171+
To add a rate limit of 100 requests per second:
172+
```
173+
gate := g8.NewGate(nil).WithRateLimit(100)
174+
```

gate.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,28 @@ const (
99
// AuthorizationHeader is the header in which g8 looks for the authorization bearer token
1010
AuthorizationHeader = "Authorization"
1111

12-
// DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or
13-
// invalid token
12+
// DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or invalid token
1413
DefaultUnauthorizedResponseBody = "Authorization Bearer token is missing or invalid"
14+
15+
// DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit
16+
DefaultTooManyRequestsResponseBody = "Too Many Requests"
1517
)
1618

1719
// Gate is lock to the front door of your API, letting only those you allow through.
1820
type Gate struct {
1921
authorizationService *AuthorizationService
2022
unauthorizedResponseBody []byte
23+
24+
rateLimiter *RateLimiter
25+
tooManyRequestsResponseBody []byte
2126
}
2227

2328
// NewGate creates a new Gate.
2429
func NewGate(authorizationService *AuthorizationService) *Gate {
2530
return &Gate{
26-
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
27-
authorizationService: authorizationService,
31+
authorizationService: authorizationService,
32+
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
33+
tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody),
2834
}
2935
}
3036

@@ -34,6 +40,12 @@ func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []
3440
return gate
3541
}
3642

43+
// WithRateLimit adds rate limiting to the Gate
44+
func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate {
45+
gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond)
46+
return gate
47+
}
48+
3749
// Protect secures a handler, requiring requests going through to have a valid Authorization Bearer token.
3850
// Unlike ProtectWithPermissions, Protect will allow access to any registered tokens, regardless of their permissions
3951
// or lack thereof.
@@ -106,6 +118,13 @@ func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc {
106118
//
107119
func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc {
108120
return func(writer http.ResponseWriter, request *http.Request) {
121+
if gate.rateLimiter != nil {
122+
if !gate.rateLimiter.Try() {
123+
writer.WriteHeader(http.StatusTooManyRequests)
124+
_, _ = writer.Write(gate.tooManyRequestsResponseBody)
125+
return
126+
}
127+
}
109128
if gate.authorizationService != nil {
110129
token := extractTokenFromRequest(request)
111130
if !gate.authorizationService.IsAuthorized(token, permissions) {

gate_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,37 @@ func TestGate_ProtectWithNilAuthorizationService(t *testing.T) {
399399
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
400400
}
401401
}
402+
403+
func TestGate_ProtectWithNilAuthorizationServiceAndRateLimit(t *testing.T) {
404+
gate := NewGate(nil).WithRateLimit(2)
405+
request, _ := http.NewRequest("GET", "/handle", nil)
406+
router := http.NewServeMux()
407+
router.Handle("/handle", gate.Protect(&testHandler{}))
408+
409+
responseRecorder := httptest.NewRecorder()
410+
router.ServeHTTP(responseRecorder, request)
411+
if responseRecorder.Code != http.StatusOK {
412+
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
413+
}
414+
415+
responseRecorder = httptest.NewRecorder()
416+
router.ServeHTTP(responseRecorder, request)
417+
if responseRecorder.Code != http.StatusOK {
418+
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
419+
}
420+
421+
responseRecorder = httptest.NewRecorder()
422+
router.ServeHTTP(responseRecorder, request)
423+
if responseRecorder.Code != http.StatusTooManyRequests {
424+
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusTooManyRequests, responseRecorder.Code)
425+
}
426+
427+
// Wait for rate limit time window to pass
428+
time.Sleep(time.Second)
429+
430+
responseRecorder = httptest.NewRecorder()
431+
router.ServeHTTP(responseRecorder, request)
432+
if responseRecorder.Code != http.StatusOK {
433+
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
434+
}
435+
}

ratelimiter.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package g8
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
// RateLimiter is a fixed rate limiter
9+
type RateLimiter struct {
10+
maximumExecutionsPerSecond int
11+
executionsLeftInWindow int
12+
windowStartTime time.Time
13+
mutex sync.Mutex
14+
}
15+
16+
// NewRateLimiter creates a RateLimiter
17+
func NewRateLimiter(maximumExecutionsPerSecond int) *RateLimiter {
18+
return &RateLimiter{
19+
windowStartTime: time.Now(),
20+
executionsLeftInWindow: maximumExecutionsPerSecond,
21+
maximumExecutionsPerSecond: maximumExecutionsPerSecond,
22+
}
23+
}
24+
25+
// Try updates the number of executions if the rate limit quota hasn't been reached and returns whether the
26+
// attempt was successful or not.
27+
//
28+
// Returns false if the execution was not successful (rate limit quota has been reached)
29+
// Returns true if the execution was successful (rate limit quota has not been reached)
30+
func (r *RateLimiter) Try() bool {
31+
r.mutex.Lock()
32+
defer r.mutex.Unlock()
33+
if time.Now().Add(-time.Second).After(r.windowStartTime) {
34+
r.windowStartTime = time.Now()
35+
r.executionsLeftInWindow = r.maximumExecutionsPerSecond
36+
}
37+
if r.executionsLeftInWindow == 0 {
38+
return false
39+
}
40+
r.executionsLeftInWindow--
41+
return true
42+
}

ratelimiter_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package g8
2+
3+
import (
4+
"testing"
5+
"time"
6+
)
7+
8+
func TestNewRateLimiter(t *testing.T) {
9+
rl := NewRateLimiter(2)
10+
if rl.maximumExecutionsPerSecond != 2 {
11+
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
12+
}
13+
if rl.executionsLeftInWindow != 2 {
14+
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 2, rl.executionsLeftInWindow)
15+
}
16+
// First execution: should not be rate limited
17+
if notRateLimited := rl.Try(); !notRateLimited {
18+
t.Error("expected Try to return true")
19+
}
20+
if rl.maximumExecutionsPerSecond != 2 {
21+
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
22+
}
23+
if rl.executionsLeftInWindow != 1 {
24+
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 1, rl.executionsLeftInWindow)
25+
}
26+
// Second execution: should not be rate limited
27+
if notRateLimited := rl.Try(); !notRateLimited {
28+
t.Error("expected Try to return true")
29+
}
30+
if rl.maximumExecutionsPerSecond != 2 {
31+
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
32+
}
33+
if rl.executionsLeftInWindow != 0 {
34+
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
35+
}
36+
// Third execution: should be rate limited
37+
if notRateLimited := rl.Try(); notRateLimited {
38+
t.Error("expected Try to return false")
39+
}
40+
if rl.maximumExecutionsPerSecond != 2 {
41+
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
42+
}
43+
if rl.executionsLeftInWindow != 0 {
44+
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
45+
}
46+
}
47+
48+
func TestRateLimiter_Try(t *testing.T) {
49+
rl := NewRateLimiter(5)
50+
for i := 0; i < 20; i++ {
51+
notRateLimited := rl.Try()
52+
if i < 5 {
53+
if !notRateLimited {
54+
t.Fatal("expected to not be rate limited")
55+
}
56+
} else {
57+
if notRateLimited {
58+
t.Fatal("expected to be rate limited")
59+
}
60+
}
61+
}
62+
}
63+
64+
func TestRateLimiter_TryAlwaysUnderRateLimit(t *testing.T) {
65+
rl := NewRateLimiter(20)
66+
for i := 0; i < 45; i++ {
67+
notRateLimited := rl.Try()
68+
if !notRateLimited {
69+
t.Fatal("expected to not be rate limited")
70+
}
71+
time.Sleep(51 * time.Millisecond)
72+
}
73+
}

0 commit comments

Comments
 (0)