Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Aug 27, 2023
1 parent 9df0b99 commit f4e4696
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
coverage.out
3 changes: 2 additions & 1 deletion .octocov.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# generated by octocov init
coverage:
if: true
codeToTestRatio:
code:
- '**/*.go'
- '!testutil/**/*.go'
- '!**/*_test.go'
test:
- 'testutil/**/*.go'
- '**/*_test.go'
testExecutionTime:
if: true
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ default: test
ci: depsdev test

test:
go test ./... -coverprofile=coverage.out -covermode=count
cp go.mod testdata/go_test.mod
go mod tidy -modfile=testdata/go_test.mod
go test ./... -modfile=testdata/go_test.mod -coverprofile=coverage.out -covermode=count

lint:
golangci-lint run ./...
Expand Down
60 changes: 60 additions & 0 deletions rl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package rl_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/httprate"
"github.com/k1LoW/rl"
"github.com/k1LoW/rl/testutil"
)

func TestRL(t *testing.T) {
tests := []struct {
name string
keyFunc httprate.KeyFunc
reqLimit int
hosts []string
wantReqCount int
}{
{"key by ip", httprate.KeyByIP, 10, []string{"a.example.com", "b.example.com"}, 10},
{"key by host", testutil.KeyByHost, 10, []string{"a.example.com", "b.example.com"}, 20},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := http.NewServeMux()
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, world"))
})
l := testutil.NewLimiter(tt.reqLimit, tt.keyFunc)
m := rl.New(l)
ts := httptest.NewServer(m(r))
t.Cleanup(func() {
ts.Close()
})
got := 0
L:
for {
for _, host := range tt.hosts {
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Host = host
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode == http.StatusTooManyRequests {
break L
}
got++
}
}
if got != tt.wantReqCount {
t.Errorf("got %v want %v", got, tt.wantReqCount)
}
})
}
}
10 changes: 10 additions & 0 deletions testdata/go_test.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module github.com/k1LoW/rl

go 1.21.0

require (
github.com/go-chi/httprate v0.7.4
golang.org/x/sync v0.3.0
)

require github.com/cespare/xxhash/v2 v2.1.2 // indirect
6 changes: 6 additions & 0 deletions testdata/go_test.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/go-chi/httprate v0.7.4 h1:a2GIjv8he9LRf3712zxxnRdckQCm7I8y8yQhkJ84V6M=
github.com/go-chi/httprate v0.7.4/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
72 changes: 72 additions & 0 deletions testutil/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package testutil

import (
"net/http"
"time"

"github.com/go-chi/httprate"
)

type Limiter struct {
counts map[string]map[time.Time]int
reqLimit int
keyFunc httprate.KeyFunc
}

func KeyByHost(r *http.Request) (string, error) {
return r.Host, nil
}

func NewLimiter(reqLimit int, keyFunc httprate.KeyFunc) *Limiter {
return &Limiter{
counts: map[string]map[time.Time]int{},
reqLimit: reqLimit,
keyFunc: keyFunc,
}
}

func (l *Limiter) Name() string {
return "testutil.Limiter"
}

func (l *Limiter) KeyAndRateLimit(r *http.Request) (string, int, time.Duration, error) {
key, err := l.keyFunc(r)
if err != nil {
return "", 0, 0, err
}
return key, l.reqLimit, time.Second, nil
}

func (l *Limiter) SetXRateLimitHeaders(err error) bool {
return true
}

func (l *Limiter) OnRequestLimit() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("Too many requests"))
}
}

func (l *Limiter) Get(key string, window time.Time) (count int, err error) {
if _, ok := l.counts[key]; !ok {
l.counts[key] = map[time.Time]int{}
return 0, nil
}
if _, ok := l.counts[key][window]; !ok {
l.counts[key][window] = 0
return 0, nil
}
return l.counts[key][window], nil
}

func (l *Limiter) Increment(key string, currentWindow time.Time) error {
if _, ok := l.counts[key]; !ok {
l.counts[key] = map[time.Time]int{}
}
if _, ok := l.counts[key][currentWindow]; !ok {
l.counts[key][currentWindow] = 0
}
l.counts[key][currentWindow]++
return nil
}

0 comments on commit f4e4696

Please sign in to comment.