Skip to content

Commit f4e4696

Browse files
committed
Add test
1 parent 9df0b99 commit f4e4696

File tree

7 files changed

+154
-2
lines changed

7 files changed

+154
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
coverage.out

.octocov.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
# generated by octocov init
21
coverage:
32
if: true
43
codeToTestRatio:
54
code:
65
- '**/*.go'
6+
- '!testutil/**/*.go'
77
- '!**/*_test.go'
88
test:
9+
- 'testutil/**/*.go'
910
- '**/*_test.go'
1011
testExecutionTime:
1112
if: true

Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ default: test
55
ci: depsdev test
66

77
test:
8-
go test ./... -coverprofile=coverage.out -covermode=count
8+
cp go.mod testdata/go_test.mod
9+
go mod tidy -modfile=testdata/go_test.mod
10+
go test ./... -modfile=testdata/go_test.mod -coverprofile=coverage.out -covermode=count
911

1012
lint:
1113
golangci-lint run ./...

rl_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package rl_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/go-chi/httprate"
9+
"github.com/k1LoW/rl"
10+
"github.com/k1LoW/rl/testutil"
11+
)
12+
13+
func TestRL(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
keyFunc httprate.KeyFunc
17+
reqLimit int
18+
hosts []string
19+
wantReqCount int
20+
}{
21+
{"key by ip", httprate.KeyByIP, 10, []string{"a.example.com", "b.example.com"}, 10},
22+
{"key by host", testutil.KeyByHost, 10, []string{"a.example.com", "b.example.com"}, 20},
23+
}
24+
for _, tt := range tests {
25+
t.Run(tt.name, func(t *testing.T) {
26+
r := http.NewServeMux()
27+
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
28+
w.Write([]byte("Hello, world"))
29+
})
30+
l := testutil.NewLimiter(tt.reqLimit, tt.keyFunc)
31+
m := rl.New(l)
32+
ts := httptest.NewServer(m(r))
33+
t.Cleanup(func() {
34+
ts.Close()
35+
})
36+
got := 0
37+
L:
38+
for {
39+
for _, host := range tt.hosts {
40+
req, err := http.NewRequest("GET", ts.URL, nil)
41+
if err != nil {
42+
t.Fatal(err)
43+
}
44+
req.Host = host
45+
res, err := http.DefaultClient.Do(req)
46+
if err != nil {
47+
t.Fatal(err)
48+
}
49+
if res.StatusCode == http.StatusTooManyRequests {
50+
break L
51+
}
52+
got++
53+
}
54+
}
55+
if got != tt.wantReqCount {
56+
t.Errorf("got %v want %v", got, tt.wantReqCount)
57+
}
58+
})
59+
}
60+
}

testdata/go_test.mod

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module github.com/k1LoW/rl
2+
3+
go 1.21.0
4+
5+
require (
6+
github.com/go-chi/httprate v0.7.4
7+
golang.org/x/sync v0.3.0
8+
)
9+
10+
require github.com/cespare/xxhash/v2 v2.1.2 // indirect

testdata/go_test.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
2+
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
3+
github.com/go-chi/httprate v0.7.4 h1:a2GIjv8he9LRf3712zxxnRdckQCm7I8y8yQhkJ84V6M=
4+
github.com/go-chi/httprate v0.7.4/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
5+
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
6+
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=

testutil/limiter.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package testutil
2+
3+
import (
4+
"net/http"
5+
"time"
6+
7+
"github.com/go-chi/httprate"
8+
)
9+
10+
type Limiter struct {
11+
counts map[string]map[time.Time]int
12+
reqLimit int
13+
keyFunc httprate.KeyFunc
14+
}
15+
16+
func KeyByHost(r *http.Request) (string, error) {
17+
return r.Host, nil
18+
}
19+
20+
func NewLimiter(reqLimit int, keyFunc httprate.KeyFunc) *Limiter {
21+
return &Limiter{
22+
counts: map[string]map[time.Time]int{},
23+
reqLimit: reqLimit,
24+
keyFunc: keyFunc,
25+
}
26+
}
27+
28+
func (l *Limiter) Name() string {
29+
return "testutil.Limiter"
30+
}
31+
32+
func (l *Limiter) KeyAndRateLimit(r *http.Request) (string, int, time.Duration, error) {
33+
key, err := l.keyFunc(r)
34+
if err != nil {
35+
return "", 0, 0, err
36+
}
37+
return key, l.reqLimit, time.Second, nil
38+
}
39+
40+
func (l *Limiter) SetXRateLimitHeaders(err error) bool {
41+
return true
42+
}
43+
44+
func (l *Limiter) OnRequestLimit() http.HandlerFunc {
45+
return func(w http.ResponseWriter, r *http.Request) {
46+
w.WriteHeader(http.StatusTooManyRequests)
47+
w.Write([]byte("Too many requests"))
48+
}
49+
}
50+
51+
func (l *Limiter) Get(key string, window time.Time) (count int, err error) {
52+
if _, ok := l.counts[key]; !ok {
53+
l.counts[key] = map[time.Time]int{}
54+
return 0, nil
55+
}
56+
if _, ok := l.counts[key][window]; !ok {
57+
l.counts[key][window] = 0
58+
return 0, nil
59+
}
60+
return l.counts[key][window], nil
61+
}
62+
63+
func (l *Limiter) Increment(key string, currentWindow time.Time) error {
64+
if _, ok := l.counts[key]; !ok {
65+
l.counts[key] = map[time.Time]int{}
66+
}
67+
if _, ok := l.counts[key][currentWindow]; !ok {
68+
l.counts[key][currentWindow] = 0
69+
}
70+
l.counts[key][currentWindow]++
71+
return nil
72+
}

0 commit comments

Comments
 (0)