Skip to content

Commit

Permalink
Rename LimitError to Context
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Oct 31, 2023
1 parent dbd6772 commit d2295db
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
12 changes: 6 additions & 6 deletions error.go → context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"time"
)

var _ error = &LimitError{}
var _ error = &Context{}
var ErrRateLimitExceeded error = errors.New("rate limit exceeded")

// LimitError is the error returned by the middleware.
type LimitError struct {
// Context is the error returned by the middleware.
type Context struct {
StatusCode int
Err error
Limiter Limiter
Expand All @@ -20,8 +20,8 @@ type LimitError struct {
lh *limitHandler
}

func newLimitError(statusCode int, err error, lh *limitHandler) *LimitError {
return &LimitError{
func newContext(statusCode int, err error, lh *limitHandler) *Context {
return &Context{
StatusCode: statusCode,
Err: err,
Limiter: lh.limiter,
Expand All @@ -34,6 +34,6 @@ func newLimitError(statusCode int, err error, lh *limitHandler) *LimitError {
}

// Error returns the error message.
func (e *LimitError) Error() string {
func (e *Context) Error() string {
return e.Err.Error()
}
14 changes: 7 additions & 7 deletions rl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ type Limiter interface {
// Rule returns the key and rate limit rule for the request
Rule(r *http.Request) (rule *Rule, err error)
// ShouldSetXRateLimitHeaders returns true if the X-RateLimit-* headers should be set
ShouldSetXRateLimitHeaders(*LimitError) bool
ShouldSetXRateLimitHeaders(*Context) bool
// OnRequestLimit returns the handler to be called when the rate limit is exceeded
OnRequestLimit(*LimitError) http.HandlerFunc
OnRequestLimit(*Context) http.HandlerFunc

// Get returns the current count for the key and window
Get(key string, window time.Time) (count int, err error) //nostyle:getters
Expand Down Expand Up @@ -115,23 +115,23 @@ func (lm *limitMw) Handler(next http.Handler) http.Handler {
case <-ctx.Done():
// Increment must be called even if the request limit is already exceeded
if err := lh.limiter.Increment(lh.key, currWindow); err != nil {
return newLimitError(http.StatusInternalServerError, err, lh)
return newContext(http.StatusInternalServerError, err, lh)
}
return nil
default:
}
rate, err := lh.status(now, currWindow)
if err != nil {
return newLimitError(http.StatusPreconditionRequired, err, lh)
return newContext(http.StatusPreconditionRequired, err, lh)
}
nrate := int(math.Round(rate))
if nrate >= lh.reqLimit {
return newLimitError(http.StatusTooManyRequests, ErrRateLimitExceeded, lh)
return newContext(http.StatusTooManyRequests, ErrRateLimitExceeded, lh)
}

lh.rateLimitRemaining = lh.reqLimit - nrate
if err := lh.limiter.Increment(lh.key, currWindow); err != nil {
return newLimitError(http.StatusInternalServerError, err, lh)
return newContext(http.StatusInternalServerError, err, lh)
}
return nil
})
Expand All @@ -144,7 +144,7 @@ func (lm *limitMw) Handler(next http.Handler) http.Handler {
// Wait for all limiters to finish
if err := eg.Wait(); err != nil {
// Handle first error
if e, ok := err.(*LimitError); ok {
if e, ok := err.(*Context); ok {
if e.lh.limiter.ShouldSetXRateLimitHeaders(e) {
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", e.lh.reqLimit))
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", e.lh.rateLimitRemaining))
Expand Down
4 changes: 2 additions & 2 deletions testutil/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func (l *Limiter) Rule(r *http.Request) (*rl.Rule, error) {
}, nil
}

func (l *Limiter) ShouldSetXRateLimitHeaders(le *rl.LimitError) bool {
func (l *Limiter) ShouldSetXRateLimitHeaders(le *rl.Context) bool {
return true
}

func (l *Limiter) OnRequestLimit(le *rl.LimitError) http.HandlerFunc {
func (l *Limiter) OnRequestLimit(le *rl.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if l.statusCode != 0 {
w.WriteHeader(l.statusCode)
Expand Down

0 comments on commit d2295db

Please sign in to comment.