|
| 1 | +package local |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "math" |
| 6 | + "sync" |
| 7 | + "time" |
| 8 | +) |
| 9 | + |
| 10 | +// LeakyBucket is a ratelimiter that fills a given bucket at a constant rate you define (calculated based on your window duration, and the max tokens) |
| 11 | +// that may exist in the window at any given time. |
| 12 | +// |
| 13 | +// Leaky buckets have the advantage of being able to burst up to the max tokens you define, and then slowly leak out tokens at a constant rate. This makes |
| 14 | +// it a good fit for situations where you want caller buckets to slowly fill if they decide to burst your service, whereas a sliding window ratelimiter will |
| 15 | +// free all tokens at once. |
| 16 | +// |
| 17 | +// Leaky buckets slowly fill your window over time, and will not fill above the size of the window. For example, if you allow 10 tokens per a window of 1 second, |
| 18 | +// your bucket fills at a fixed rate of 100ms. |
| 19 | +// |
| 20 | +// See: https://en.wikipedia.org/wiki/Leaky_bucket |
| 21 | +type LeakyBucket interface { |
| 22 | + // Wait will block the goroutine til a ratelimit token is available. You can use context to cancel the ratelimiter. |
| 23 | + Wait(ctx context.Context) |
| 24 | + |
| 25 | + // WaitFunc is equivalent to Wait except it calls a callback when it's able to accquire a token. Iif you cancel the context, cb is not called. This |
| 26 | + // function does spawn a goroutine per invocation. If you want something more efficient, consider writing your own implementation using TryTakeWithDuration() |
| 27 | + WaitFunc(ctx context.Context, cb func()) |
| 28 | + |
| 29 | + // Size will return how many tokens are currently available |
| 30 | + Size() int |
| 31 | + |
| 32 | + // Take will attempt to accquire a token, it will return a boolean indicating whether it was able to accquire a token or not. |
| 33 | + TryTake() bool |
| 34 | + |
| 35 | + // Take will attempt to accquire a token, it will return a boolean indicating whether it was able to accquire a token or not, |
| 36 | + // and a duration for when you should next try. |
| 37 | + TryTakeWithDuration() (bool, time.Duration) |
| 38 | +} |
| 39 | + |
| 40 | +type leakyBucket struct { |
| 41 | + max int |
| 42 | + tokens int |
| 43 | + rate time.Duration |
| 44 | + lastFill time.Time |
| 45 | + m sync.Mutex |
| 46 | +} |
| 47 | + |
| 48 | +// NewLeakyBucket creates a new leaky bucket ratelimiter. See the LeakyBucket interface for more info about what this ratelimiter does. |
| 49 | +func NewLeakyBucket(tokensPerWindow int, window time.Duration) LeakyBucket { |
| 50 | + tokenRate := window / time.Duration(tokensPerWindow) |
| 51 | + |
| 52 | + return &leakyBucket{ |
| 53 | + tokens: tokensPerWindow, |
| 54 | + lastFill: time.Now().UTC(), |
| 55 | + max: tokensPerWindow, |
| 56 | + rate: tokenRate, |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +// TryTakeWithDuration will attempt to accquire a ratelimit window, it will return a boolean indicating whether it was able to accquire a token or not, |
| 61 | +// and a duration for when you should next try. |
| 62 | +func (r *leakyBucket) TryTakeWithDuration() (bool, time.Duration) { |
| 63 | + r.m.Lock() |
| 64 | + defer r.m.Unlock() |
| 65 | + |
| 66 | + r.unsafeFill() |
| 67 | + |
| 68 | + if r.tokens < 1 { |
| 69 | + // there isn't at least 1 oken, so nothing is available |
| 70 | + return false, time.Until(r.lastFill.Add(r.rate)) |
| 71 | + } |
| 72 | + |
| 73 | + // take a token if there is one available |
| 74 | + r.tokens-- |
| 75 | + |
| 76 | + return true, 0 |
| 77 | +} |
| 78 | + |
| 79 | +// Take will attempt to accquire a ratelimit window, it will return a boolean indicating whether it was able to accquire a token or not. |
| 80 | +func (r *leakyBucket) TryTake() bool { |
| 81 | + resp, _ := r.TryTakeWithDuration() |
| 82 | + return resp |
| 83 | +} |
| 84 | + |
| 85 | +// Wait will block the goroutine til a ratelimit token is available. You can use context to cancel the ratelimiter. |
| 86 | +func (r *leakyBucket) Wait(ctx context.Context) { |
| 87 | + _ = r.wait(ctx) |
| 88 | +} |
| 89 | + |
| 90 | +// wait keeps trying to take a token, while also sleeping the goroutine while it waits for the next attempt. The wait functions just call this |
| 91 | +// under the hood. |
| 92 | +func (r *leakyBucket) wait(ctx context.Context) bool { |
| 93 | + for { |
| 94 | + available, duration := r.TryTakeWithDuration() |
| 95 | + if available { |
| 96 | + return true |
| 97 | + } |
| 98 | + if !r.awaitNextToken(ctx, duration) { |
| 99 | + return false |
| 100 | + } |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +// Size will return how many tokens are currently available |
| 105 | +func (r *leakyBucket) Size() int { |
| 106 | + r.m.Lock() |
| 107 | + defer r.m.Unlock() |
| 108 | + r.unsafeFill() |
| 109 | + return r.tokens |
| 110 | +} |
| 111 | + |
| 112 | +// WaitFunc is equivalent to Wait except it calls a callback when it's able to accquire a token. Iif you cancel the context, cb is not called. This |
| 113 | +// function does spawn a goroutine per invocation. If you want something more efficient, consider writing your own implementation using TryTakeWithDuration() |
| 114 | +func (r *leakyBucket) WaitFunc(ctx context.Context, cb func()) { |
| 115 | + go func(ctx context.Context, cb func()) { |
| 116 | + if r.wait(ctx) { |
| 117 | + cb() |
| 118 | + } |
| 119 | + }(ctx, cb) |
| 120 | +} |
| 121 | + |
| 122 | +func (r *leakyBucket) awaitNextToken(ctx context.Context, duration time.Duration) bool { |
| 123 | + timer := time.NewTimer(duration) |
| 124 | + defer timer.Stop() |
| 125 | + |
| 126 | + select { |
| 127 | + case <-ctx.Done(): |
| 128 | + return false |
| 129 | + case <-timer.C: |
| 130 | + return true |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +// unsafeFill attempts to fill the leaky bucket with tokens, but is not thread safe. |
| 135 | +// |
| 136 | +// Ensure you have locked the mutex outside of this function before calling it. |
| 137 | +func (r *leakyBucket) unsafeFill() { |
| 138 | + if r.tokens >= r.max { |
| 139 | + // bucket is already full |
| 140 | + return |
| 141 | + } |
| 142 | + |
| 143 | + tokensToFill := int(time.Since(r.lastFill) / r.rate) |
| 144 | + r.tokens = int(math.Min(float64(r.tokens+tokensToFill), float64(r.max))) |
| 145 | + r.lastFill = time.Now().UTC() |
| 146 | +} |
0 commit comments