@@ -20,6 +20,9 @@ import (
2020//
2121// See: https://en.wikipedia.org/wiki/Leaky_bucket
2222type LeakyBucket interface {
23+ // Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
24+ Inspect (ctx context.Context , bucket * LeakyBucketOptions ) (* InspectLeakyBucketResponse , error )
25+
2326 // Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
2427 // from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
2528 Use (ctx context.Context , bucket * LeakyBucketOptions , takeAmount int ) (* UseLeakyBucketResponse , error )
@@ -50,18 +53,6 @@ type LeakyBucketOptions struct {
5053 WindowSeconds int
5154}
5255
53- // UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
54- type UseLeakyBucketResponse struct {
55- // Success is true when we were successfully able to take tokens from the bucket.
56- Success bool
57-
58- // RemainingTokens defines hwo many tokens are left in the bucket
59- RemainingTokens int
60-
61- // ResetAt is the time at which the bucket will be fully refilled
62- ResetAt time.Time
63- }
64-
6556// LeakyBucketImpl implements a leaky bucket ratelimiter in Redis with Lua. This struct is compatible with the LeakyBucket interface
6657//
6758// See the LeakyBucket interface for more information about leaky bucket ratelimiters.
@@ -70,6 +61,8 @@ type LeakyBucketImpl struct {
7061 Adapter adapters.Adapter
7162
7263 // nowFunc is a private helper used to mock out time changes in unit testing
64+ //
65+ // if this is not defined, it falls back to time.Now()
7366 nowFunc func () time.Time
7467}
7568
@@ -88,6 +81,80 @@ func (r *LeakyBucketImpl) now() time.Time {
8881 return r .nowFunc ()
8982}
9083
84+ // InspectLeakyBucketResponse defines the response parameters for LeakyBucket.Inspect()
85+ type InspectLeakyBucketResponse struct {
86+ // RemainingTokens defines hwo many tokens are left in the bucket
87+ RemainingTokens int
88+
89+ // ResetAt is the time at which the bucket will be fully refilled
90+ ResetAt time.Time
91+ }
92+
93+ // Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
94+ func (r * LeakyBucketImpl ) Inspect (ctx context.Context , bucket * LeakyBucketOptions ) (* InspectLeakyBucketResponse , error ) {
95+ const script = `
96+ local tokensKey = KEYS[1]
97+ local lastFillKey = KEYS[2]
98+ local capacity = tonumber(ARGV[1])
99+ local rate = tonumber(ARGV[2])
100+ local now = tonumber(ARGV[3])
101+
102+ local tokens = tonumber(redis.call("get", tokensKey))
103+ local lastFilled = tonumber(redis.call("get", lastFillKey))
104+
105+ if (tokens == nil) then
106+ tokens = 0 -- default empty buckets to 0
107+ end
108+
109+ if (tokens > capacity) then
110+ tokens = capacity -- shrink buckets if the capacity is reduced
111+ end
112+
113+ if (lastFilled == nil) then
114+ lastFilled = 0
115+ end
116+
117+ if (tokens < capacity) then
118+ local tokensToFill = math.floor((now - lastFilled) * rate)
119+ if (tokensToFill > 0) then
120+ tokens = math.min(capacity, tokens + tokensToFill)
121+ lastFilled = now
122+ end
123+ end
124+
125+ return {tokens, lastFilled}
126+ `
127+ refillRate := getRefillRate (bucket .MaximumCapacity , bucket .WindowSeconds )
128+ now := r .now ().UTC ().Unix ()
129+
130+ resp , err := r .Adapter .Eval (ctx , script , []string {tokensKey (bucket .KeyPrefix ), lastFillKey (bucket .KeyPrefix )}, []interface {}{bucket .MaximumCapacity , refillRate , now })
131+ if err != nil {
132+ return nil , fmt .Errorf ("failed to query redis adapter: %w" , err )
133+ }
134+
135+ output , err := parseInspectLeakyBucketResponse (resp )
136+ if err != nil {
137+ return nil , fmt .Errorf ("parsing redis response: %w" , err )
138+ }
139+
140+ return & InspectLeakyBucketResponse {
141+ RemainingTokens : output .remaining ,
142+ ResetAt : calculateLeakyBucketFillTime (output .lastFilled , output .remaining , bucket .MaximumCapacity , bucket .WindowSeconds ),
143+ }, nil
144+ }
145+
146+ // UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
147+ type UseLeakyBucketResponse struct {
148+ // Success is true when we were successfully able to take tokens from the bucket.
149+ Success bool
150+
151+ // RemainingTokens defines hwo many tokens are left in the bucket
152+ RemainingTokens int
153+
154+ // ResetAt is the time at which the bucket will be fully refilled
155+ ResetAt time.Time
156+ }
157+
91158// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
92159// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
93160func (r * LeakyBucketImpl ) Use (ctx context.Context , bucket * LeakyBucketOptions , takeAmount int ) (* UseLeakyBucketResponse , error ) {
@@ -139,15 +206,14 @@ return {success, tokens, lastFilled}
139206 refillRate := getRefillRate (bucket .MaximumCapacity , bucket .WindowSeconds )
140207 now := r .now ().UTC ().Unix ()
141208
142- tokensKey := bucket .KeyPrefix + "::tokens"
143- lastFillKey := bucket .KeyPrefix + "::last_fill"
144-
145- resp , err := r .Adapter .Eval (ctx , script , []string {tokensKey , lastFillKey }, []interface {}{bucket .MaximumCapacity , refillRate , now , takeAmount , bucket .WindowSeconds })
209+ resp , err := r .Adapter .Eval (ctx , script , []string {tokensKey (bucket .KeyPrefix ), lastFillKey (bucket .KeyPrefix )}, []interface {}{
210+ bucket .MaximumCapacity , refillRate , now , takeAmount , bucket .WindowSeconds ,
211+ })
146212 if err != nil {
147213 return nil , fmt .Errorf ("failed to query redis adapter: %w" , err )
148214 }
149215
150- output , err := parseLeakyBucketResponse (resp )
216+ output , err := parseUseLeakyBucketResponse (resp )
151217 if err != nil {
152218 return nil , fmt .Errorf ("parsing redis response: %w" , err )
153219 }
@@ -159,6 +225,14 @@ return {success, tokens, lastFilled}
159225 }, nil
160226}
161227
228+ func tokensKey (prefix string ) string {
229+ return prefix + "::tokens"
230+ }
231+
232+ func lastFillKey (prefix string ) string {
233+ return prefix + "::last_fill"
234+ }
235+
162236func calculateLeakyBucketFillTime (lastFillUnix , currentTokens , maxCapacity , windowSeconds int ) time.Time {
163237 resetAt := lastFillUnix // if delta is 0 (thus, all tokens are filled), then the bucket is already reset
164238 if delta := maxCapacity - currentTokens ; delta > 0 {
@@ -182,35 +256,46 @@ func getRefillRate(maxCapacity, windowSeconds int) float64 {
182256 return float64 (maxCapacity ) / float64 (windowSeconds )
183257}
184258
185- type leakyBucketOutput struct {
259+ type useLeakyBucketOutput struct {
186260 success bool
187261 remaining int
188262 lastFilled int
189263}
190264
191- func parseLeakyBucketResponse (v interface {}) (* leakyBucketOutput , error ) {
192- args , ok := v .([] interface {} )
193- if ! ok {
194- return nil , fmt . Errorf ( "expected []interface{} but got %T" , v )
265+ func parseUseLeakyBucketResponse (v interface {}) (* useLeakyBucketOutput , error ) {
266+ ints , err := parseRedisInt64Slice ( v )
267+ if err != nil {
268+ return nil , err
195269 }
196270
197- if len (args ) != 3 {
198- return nil , fmt .Errorf ("expected 3 args but got %d" , len (args ))
271+ if len (ints ) != 3 {
272+ return nil , fmt .Errorf ("expected 3 args but got %d" , len (ints ))
199273 }
200274
201- argInts := make ([]int64 , len (args ))
202- for i , argValue := range args {
203- intValue , ok := argValue .(int64 )
204- if ! ok {
205- return nil , fmt .Errorf ("expected int64 in arg[%d] but got %T" , i , argValue )
206- }
275+ return & useLeakyBucketOutput {
276+ success : ints [0 ] == 1 ,
277+ remaining : int (ints [1 ]),
278+ lastFilled : int (ints [2 ]),
279+ }, nil
280+ }
281+
282+ type inspectLeakyBucketOutput struct {
283+ remaining int
284+ lastFilled int
285+ }
286+
287+ func parseInspectLeakyBucketResponse (v interface {}) (* inspectLeakyBucketOutput , error ) {
288+ ints , err := parseRedisInt64Slice (v )
289+ if err != nil {
290+ return nil , err
291+ }
207292
208- argInts [i ] = intValue
293+ if len (ints ) != 2 {
294+ return nil , fmt .Errorf ("expected 2 args but got %d" , len (ints ))
209295 }
210296
211- return & leakyBucketOutput {
212- success : argInts [0 ] == 1 ,
213- remaining : int (argInts [1 ]),
214- lastFilled : int (argInts [2 ]),
297+ return & inspectLeakyBucketOutput {
298+ remaining : int (ints [0 ]),
299+ lastFilled : int (ints [1 ]),
215300 }, nil
216301}
0 commit comments