diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d83068 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +coverage.out diff --git a/.octocov.yml b/.octocov.yml index 94276ec..3eb5749 100644 --- a/.octocov.yml +++ b/.octocov.yml @@ -1,11 +1,12 @@ -# generated by octocov init coverage: if: true codeToTestRatio: code: - '**/*.go' + - '!testutil/**/*.go' - '!**/*_test.go' test: + - 'testutil/**/*.go' - '**/*_test.go' testExecutionTime: if: true diff --git a/Makefile b/Makefile index 42dcbc3..ca6c15e 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,9 @@ default: test ci: depsdev test test: - go test ./... -coverprofile=coverage.out -covermode=count + cp go.mod testdata/go_test.mod + go mod tidy -modfile=testdata/go_test.mod + go test ./... -modfile=testdata/go_test.mod -coverprofile=coverage.out -covermode=count lint: golangci-lint run ./... diff --git a/rl_test.go b/rl_test.go new file mode 100644 index 0000000..a74bed9 --- /dev/null +++ b/rl_test.go @@ -0,0 +1,60 @@ +package rl_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/httprate" + "github.com/k1LoW/rl" + "github.com/k1LoW/rl/testutil" +) + +func TestRL(t *testing.T) { + tests := []struct { + name string + keyFunc httprate.KeyFunc + reqLimit int + hosts []string + wantReqCount int + }{ + {"key by ip", httprate.KeyByIP, 10, []string{"a.example.com", "b.example.com"}, 10}, + {"key by host", testutil.KeyByHost, 10, []string{"a.example.com", "b.example.com"}, 20}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := http.NewServeMux() + r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, world")) + }) + l := testutil.NewLimiter(tt.reqLimit, tt.keyFunc) + m := rl.New(l) + ts := httptest.NewServer(m(r)) + t.Cleanup(func() { + ts.Close() + }) + got := 0 + L: + for { + for _, host := range tt.hosts { + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Host = host + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode == http.StatusTooManyRequests { + break L + } + got++ + } + } + if got != tt.wantReqCount { + t.Errorf("got %v want %v", got, tt.wantReqCount) + } + }) + } +} diff --git a/testdata/go_test.mod b/testdata/go_test.mod new file mode 100644 index 0000000..1db5631 --- /dev/null +++ b/testdata/go_test.mod @@ -0,0 +1,10 @@ +module github.com/k1LoW/rl + +go 1.21.0 + +require ( + github.com/go-chi/httprate v0.7.4 + golang.org/x/sync v0.3.0 +) + +require github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/testdata/go_test.sum b/testdata/go_test.sum new file mode 100644 index 0000000..9231e1e --- /dev/null +++ b/testdata/go_test.sum @@ -0,0 +1,6 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/go-chi/httprate v0.7.4 h1:a2GIjv8he9LRf3712zxxnRdckQCm7I8y8yQhkJ84V6M= +github.com/go-chi/httprate v0.7.4/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= diff --git a/testutil/limiter.go b/testutil/limiter.go new file mode 100644 index 0000000..b2fb30b --- /dev/null +++ b/testutil/limiter.go @@ -0,0 +1,72 @@ +package testutil + +import ( + "net/http" + "time" + + "github.com/go-chi/httprate" +) + +type Limiter struct { + counts map[string]map[time.Time]int + reqLimit int + keyFunc httprate.KeyFunc +} + +func KeyByHost(r *http.Request) (string, error) { + return r.Host, nil +} + +func NewLimiter(reqLimit int, keyFunc httprate.KeyFunc) *Limiter { + return &Limiter{ + counts: map[string]map[time.Time]int{}, + reqLimit: reqLimit, + keyFunc: keyFunc, + } +} + +func (l *Limiter) Name() string { + return "testutil.Limiter" +} + +func (l *Limiter) KeyAndRateLimit(r *http.Request) (string, int, time.Duration, error) { + key, err := l.keyFunc(r) + if err != nil { + return "", 0, 0, err + } + return key, l.reqLimit, time.Second, nil +} + +func (l *Limiter) SetXRateLimitHeaders(err error) bool { + return true +} + +func (l *Limiter) OnRequestLimit() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("Too many requests")) + } +} + +func (l *Limiter) Get(key string, window time.Time) (count int, err error) { + if _, ok := l.counts[key]; !ok { + l.counts[key] = map[time.Time]int{} + return 0, nil + } + if _, ok := l.counts[key][window]; !ok { + l.counts[key][window] = 0 + return 0, nil + } + return l.counts[key][window], nil +} + +func (l *Limiter) Increment(key string, currentWindow time.Time) error { + if _, ok := l.counts[key]; !ok { + l.counts[key] = map[time.Time]int{} + } + if _, ok := l.counts[key][currentWindow]; !ok { + l.counts[key][currentWindow] = 0 + } + l.counts[key][currentWindow]++ + return nil +}