Skip to content

Commit 736baa7

Browse files
authored
Add ability to inspect redis ratelimiters (#7)
1 parent 8c9e41a commit 736baa7

File tree

6 files changed

+414
-65
lines changed

6 files changed

+414
-65
lines changed

redis/leaky_bucket.go

Lines changed: 120 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import (
2020
//
2121
// See: https://en.wikipedia.org/wiki/Leaky_bucket
2222
type LeakyBucket interface {
23+
// Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
24+
Inspect(ctx context.Context, bucket *LeakyBucketOptions) (*InspectLeakyBucketResponse, error)
25+
2326
// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
2427
// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
2528
Use(ctx context.Context, bucket *LeakyBucketOptions, takeAmount int) (*UseLeakyBucketResponse, error)
@@ -50,18 +53,6 @@ type LeakyBucketOptions struct {
5053
WindowSeconds int
5154
}
5255

53-
// UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
54-
type UseLeakyBucketResponse struct {
55-
// Success is true when we were successfully able to take tokens from the bucket.
56-
Success bool
57-
58-
// RemainingTokens defines hwo many tokens are left in the bucket
59-
RemainingTokens int
60-
61-
// ResetAt is the time at which the bucket will be fully refilled
62-
ResetAt time.Time
63-
}
64-
6556
// LeakyBucketImpl implements a leaky bucket ratelimiter in Redis with Lua. This struct is compatible with the LeakyBucket interface
6657
//
6758
// See the LeakyBucket interface for more information about leaky bucket ratelimiters.
@@ -70,6 +61,8 @@ type LeakyBucketImpl struct {
7061
Adapter adapters.Adapter
7162

7263
// nowFunc is a private helper used to mock out time changes in unit testing
64+
//
65+
// if this is not defined, it falls back to time.Now()
7366
nowFunc func() time.Time
7467
}
7568

@@ -88,6 +81,80 @@ func (r *LeakyBucketImpl) now() time.Time {
8881
return r.nowFunc()
8982
}
9083

84+
// InspectLeakyBucketResponse defines the response parameters for LeakyBucket.Inspect()
85+
type InspectLeakyBucketResponse struct {
86+
// RemainingTokens defines hwo many tokens are left in the bucket
87+
RemainingTokens int
88+
89+
// ResetAt is the time at which the bucket will be fully refilled
90+
ResetAt time.Time
91+
}
92+
93+
// Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
94+
func (r *LeakyBucketImpl) Inspect(ctx context.Context, bucket *LeakyBucketOptions) (*InspectLeakyBucketResponse, error) {
95+
const script = `
96+
local tokensKey = KEYS[1]
97+
local lastFillKey = KEYS[2]
98+
local capacity = tonumber(ARGV[1])
99+
local rate = tonumber(ARGV[2])
100+
local now = tonumber(ARGV[3])
101+
102+
local tokens = tonumber(redis.call("get", tokensKey))
103+
local lastFilled = tonumber(redis.call("get", lastFillKey))
104+
105+
if (tokens == nil) then
106+
tokens = 0 -- default empty buckets to 0
107+
end
108+
109+
if (tokens > capacity) then
110+
tokens = capacity -- shrink buckets if the capacity is reduced
111+
end
112+
113+
if (lastFilled == nil) then
114+
lastFilled = 0
115+
end
116+
117+
if (tokens < capacity) then
118+
local tokensToFill = math.floor((now - lastFilled) * rate)
119+
if (tokensToFill > 0) then
120+
tokens = math.min(capacity, tokens + tokensToFill)
121+
lastFilled = now
122+
end
123+
end
124+
125+
return {tokens, lastFilled}
126+
`
127+
refillRate := getRefillRate(bucket.MaximumCapacity, bucket.WindowSeconds)
128+
now := r.now().UTC().Unix()
129+
130+
resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey(bucket.KeyPrefix), lastFillKey(bucket.KeyPrefix)}, []interface{}{bucket.MaximumCapacity, refillRate, now})
131+
if err != nil {
132+
return nil, fmt.Errorf("failed to query redis adapter: %w", err)
133+
}
134+
135+
output, err := parseInspectLeakyBucketResponse(resp)
136+
if err != nil {
137+
return nil, fmt.Errorf("parsing redis response: %w", err)
138+
}
139+
140+
return &InspectLeakyBucketResponse{
141+
RemainingTokens: output.remaining,
142+
ResetAt: calculateLeakyBucketFillTime(output.lastFilled, output.remaining, bucket.MaximumCapacity, bucket.WindowSeconds),
143+
}, nil
144+
}
145+
146+
// UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
147+
type UseLeakyBucketResponse struct {
148+
// Success is true when we were successfully able to take tokens from the bucket.
149+
Success bool
150+
151+
// RemainingTokens defines hwo many tokens are left in the bucket
152+
RemainingTokens int
153+
154+
// ResetAt is the time at which the bucket will be fully refilled
155+
ResetAt time.Time
156+
}
157+
91158
// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
92159
// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
93160
func (r *LeakyBucketImpl) Use(ctx context.Context, bucket *LeakyBucketOptions, takeAmount int) (*UseLeakyBucketResponse, error) {
@@ -139,15 +206,14 @@ return {success, tokens, lastFilled}
139206
refillRate := getRefillRate(bucket.MaximumCapacity, bucket.WindowSeconds)
140207
now := r.now().UTC().Unix()
141208

142-
tokensKey := bucket.KeyPrefix + "::tokens"
143-
lastFillKey := bucket.KeyPrefix + "::last_fill"
144-
145-
resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey, lastFillKey}, []interface{}{bucket.MaximumCapacity, refillRate, now, takeAmount, bucket.WindowSeconds})
209+
resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey(bucket.KeyPrefix), lastFillKey(bucket.KeyPrefix)}, []interface{}{
210+
bucket.MaximumCapacity, refillRate, now, takeAmount, bucket.WindowSeconds,
211+
})
146212
if err != nil {
147213
return nil, fmt.Errorf("failed to query redis adapter: %w", err)
148214
}
149215

150-
output, err := parseLeakyBucketResponse(resp)
216+
output, err := parseUseLeakyBucketResponse(resp)
151217
if err != nil {
152218
return nil, fmt.Errorf("parsing redis response: %w", err)
153219
}
@@ -159,6 +225,14 @@ return {success, tokens, lastFilled}
159225
}, nil
160226
}
161227

228+
func tokensKey(prefix string) string {
229+
return prefix + "::tokens"
230+
}
231+
232+
func lastFillKey(prefix string) string {
233+
return prefix + "::last_fill"
234+
}
235+
162236
func calculateLeakyBucketFillTime(lastFillUnix, currentTokens, maxCapacity, windowSeconds int) time.Time {
163237
resetAt := lastFillUnix // if delta is 0 (thus, all tokens are filled), then the bucket is already reset
164238
if delta := maxCapacity - currentTokens; delta > 0 {
@@ -182,35 +256,46 @@ func getRefillRate(maxCapacity, windowSeconds int) float64 {
182256
return float64(maxCapacity) / float64(windowSeconds)
183257
}
184258

185-
type leakyBucketOutput struct {
259+
type useLeakyBucketOutput struct {
186260
success bool
187261
remaining int
188262
lastFilled int
189263
}
190264

191-
func parseLeakyBucketResponse(v interface{}) (*leakyBucketOutput, error) {
192-
args, ok := v.([]interface{})
193-
if !ok {
194-
return nil, fmt.Errorf("expected []interface{} but got %T", v)
265+
func parseUseLeakyBucketResponse(v interface{}) (*useLeakyBucketOutput, error) {
266+
ints, err := parseRedisInt64Slice(v)
267+
if err != nil {
268+
return nil, err
195269
}
196270

197-
if len(args) != 3 {
198-
return nil, fmt.Errorf("expected 3 args but got %d", len(args))
271+
if len(ints) != 3 {
272+
return nil, fmt.Errorf("expected 3 args but got %d", len(ints))
199273
}
200274

201-
argInts := make([]int64, len(args))
202-
for i, argValue := range args {
203-
intValue, ok := argValue.(int64)
204-
if !ok {
205-
return nil, fmt.Errorf("expected int64 in arg[%d] but got %T", i, argValue)
206-
}
275+
return &useLeakyBucketOutput{
276+
success: ints[0] == 1,
277+
remaining: int(ints[1]),
278+
lastFilled: int(ints[2]),
279+
}, nil
280+
}
281+
282+
type inspectLeakyBucketOutput struct {
283+
remaining int
284+
lastFilled int
285+
}
286+
287+
func parseInspectLeakyBucketResponse(v interface{}) (*inspectLeakyBucketOutput, error) {
288+
ints, err := parseRedisInt64Slice(v)
289+
if err != nil {
290+
return nil, err
291+
}
207292

208-
argInts[i] = intValue
293+
if len(ints) != 2 {
294+
return nil, fmt.Errorf("expected 2 args but got %d", len(ints))
209295
}
210296

211-
return &leakyBucketOutput{
212-
success: argInts[0] == 1,
213-
remaining: int(argInts[1]),
214-
lastFilled: int(argInts[2]),
297+
return &inspectLeakyBucketOutput{
298+
remaining: int(ints[0]),
299+
lastFilled: int(ints[1]),
215300
}, nil
216301
}

redis/leaky_bucket_test.go

Lines changed: 105 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,84 @@ import (
1414
"github.com/stretchr/testify/assert"
1515
)
1616

17-
func TestUseLeakyBucket(t *testing.T) {
18-
t.Parallel()
17+
func TestInspectLeakyBucket(t *testing.T) {
18+
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
19+
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
20+
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
21+
},
22+
"redigo": func(t *miniredis.Miniredis) adapters.Adapter {
23+
conn, err := redigo.Dial("tcp", t.Addr())
24+
if err != nil {
25+
panic(err)
26+
}
27+
return redigoadapter.NewAdapter(conn)
28+
},
29+
}
30+
31+
for name, testCase := range testCases {
32+
testCase := testCase
33+
34+
t.Run(name, func(t *testing.T) {
35+
ctx := context.Background()
36+
now := time.Now().UTC()
37+
limiter := NewLeakyBucket(testCase(miniredis.RunT(t)))
38+
limiter.nowFunc = func() time.Time { return now }
39+
40+
{
41+
resp, err := limiter.Inspect(ctx, leakyBucketOptions())
42+
assert.NoError(t, err)
43+
assert.Equal(t, leakyBucketOptions().MaximumCapacity, resp.RemainingTokens)
44+
assert.Equal(t, now.Unix(), resp.ResetAt.Unix())
45+
}
46+
47+
{
48+
resp, err := useLeakyBucket(ctx, limiter)
49+
assert.NoError(t, err)
50+
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingTokens)
51+
assert.Equal(t, now.Add(time.Second*1).Unix(), resp.ResetAt.Unix())
52+
}
53+
54+
{
55+
resp, err := limiter.Inspect(ctx, leakyBucketOptions())
56+
assert.NoError(t, err)
57+
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingTokens)
58+
assert.Equal(t, now.Add(time.Second*1).Unix(), resp.ResetAt.Unix())
59+
}
60+
})
61+
}
62+
}
63+
64+
func TestInspectLeakyBucket_Errors(t *testing.T) {
65+
testCases := map[string]struct {
66+
errorMessage string
67+
mockAdapter adapters.Adapter
68+
}{
69+
"redis error": {
70+
errorMessage: "failed to query redis adapter: " + assert.AnError.Error(),
71+
mockAdapter: &mockAdapter{
72+
returnError: assert.AnError,
73+
},
74+
},
75+
"parsing error": {
76+
errorMessage: "parsing redis response: expected []interface{} but got string",
77+
mockAdapter: &mockAdapter{
78+
returnValue: "foo",
79+
},
80+
},
81+
}
82+
83+
for name, testCase := range testCases {
84+
testCase := testCase
85+
86+
t.Run(name, func(t *testing.T) {
87+
out, err := NewLeakyBucket(testCase.mockAdapter).Inspect(context.Background(), leakyBucketOptions())
88+
assert.Nil(t, out)
89+
assert.EqualError(t, err, testCase.errorMessage)
90+
})
91+
}
92+
}
1993

94+
func TestUseLeakyBucket(t *testing.T) {
2095
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
2196
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
2297
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
@@ -108,7 +183,7 @@ func TestRefillRate(t *testing.T) {
108183
assert.EqualValues(t, 5, getRefillRate(300, 60))
109184
}
110185

111-
func TestParseLeakyBucketResponse_Errors(t *testing.T) {
186+
func TestParseUseLeakyBucketResponse_Errors(t *testing.T) {
112187
testCases := map[string]struct {
113188
errorMessage string
114189
in interface{}
@@ -119,19 +194,41 @@ func TestParseLeakyBucketResponse_Errors(t *testing.T) {
119194
},
120195
"invalid length": {
121196
errorMessage: "expected 3 args but got 2",
122-
in: []interface{}{1, 2},
197+
in: []interface{}{int64(1), int64(2)},
198+
},
199+
}
200+
201+
for name, testCase := range testCases {
202+
testCase := testCase
203+
204+
t.Run(name, func(t *testing.T) {
205+
out, err := parseUseLeakyBucketResponse(testCase.in)
206+
assert.Nil(t, out)
207+
assert.EqualError(t, err, testCase.errorMessage)
208+
})
209+
}
210+
}
211+
212+
func TestParseInspectLeakyBucketResponse_Errors(t *testing.T) {
213+
testCases := map[string]struct {
214+
errorMessage string
215+
in interface{}
216+
}{
217+
"invalid type": {
218+
errorMessage: "expected []interface{} but got string",
219+
in: "foo",
123220
},
124-
"invalid item type": {
125-
errorMessage: "expected int64 in arg[1] but got float64",
126-
in: []interface{}{int64(1), float64(2), "three"},
221+
"invalid length": {
222+
errorMessage: "expected 2 args but got 3",
223+
in: []interface{}{int64(1), int64(2), int64(3)},
127224
},
128225
}
129226

130227
for name, testCase := range testCases {
131228
testCase := testCase
132229

133230
t.Run(name, func(t *testing.T) {
134-
out, err := parseLeakyBucketResponse(testCase.in)
231+
out, err := parseInspectLeakyBucketResponse(testCase.in)
135232
assert.Nil(t, out)
136233
assert.EqualError(t, err, testCase.errorMessage)
137234
})

redis/redis.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package redis
2+
3+
import "fmt"
4+
5+
func parseRedisInt64Slice(v interface{}) ([]int64, error) {
6+
args, ok := v.([]interface{})
7+
if !ok {
8+
return nil, fmt.Errorf("expected []interface{} but got %T", v)
9+
}
10+
11+
out := make([]int64, len(args))
12+
for i, arg := range args {
13+
value, ok := arg.(int64)
14+
if !ok {
15+
return nil, fmt.Errorf("expected int64 in args[%d] but got %T", i, arg)
16+
}
17+
18+
out[i] = value
19+
}
20+
21+
return out, nil
22+
}

0 commit comments

Comments
 (0)