Skip to content

Commit ee62451

Browse files
author
Your Name
committed
implement total claimable limit
1 parent 28182f6 commit ee62451

File tree

5 files changed

+143
-64
lines changed

5 files changed

+143
-64
lines changed

contribs/gnofaucet/cooldown.go

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"time"
78

@@ -12,52 +13,71 @@ import (
1213
// CooldownLimiter limits a specific user to one claim per cooldown period
1314
// this limiter keeps track of which keys are on cooldown using a badger database (written to a local file)
1415
type CooldownLimiter struct {
15-
redis *redis.Client
16-
cooldownTime time.Duration
16+
redis *redis.Client
17+
cooldownTime time.Duration
18+
maxlifeTimeAmount int64
1719
}
1820

1921
// NewCooldownLimiter initializes a Cooldown Limiter with a given duration
20-
func NewCooldownLimiter(cooldown time.Duration, redis *redis.Client) *CooldownLimiter {
22+
func NewCooldownLimiter(cooldown time.Duration, redis *redis.Client, maxlifeTimeAmount int64) *CooldownLimiter {
2123
return &CooldownLimiter{
22-
redis: redis,
23-
cooldownTime: cooldown,
24+
redis: redis,
25+
cooldownTime: cooldown,
26+
maxlifeTimeAmount: maxlifeTimeAmount,
2427
}
2528
}
2629

2730
// CheckCooldown checks if a key can make a claim or if it is still within the cooldown period
31+
// also checks that the user will not exceed the max lifetime allowed amount
2832
// Returns true if the key is not on cooldown, and marks the key as on cooldown
2933
// Returns false if the key is on cooldown or if an error occurs
30-
func (rl *CooldownLimiter) CheckCooldown(ctx context.Context, key string) (bool, error) {
31-
isOnCooldown, err := rl.isOnCooldown(ctx, key)
34+
func (rl *CooldownLimiter) CheckCooldown(ctx context.Context, key string, amountClaimed int64) (bool, error) {
35+
claimData, err := rl.getClaimsData(ctx, key)
3236
if err != nil {
3337
return false, fmt.Errorf("unable to check if key is on cooldown, %w", err)
3438
}
35-
if isOnCooldown {
36-
return false, nil // Deny claim if within cooldown period
39+
// Deny claim if within cooldown period
40+
if claimData.LastClaimed.Add(rl.cooldownTime).After(time.Now()) {
41+
return false, nil
42+
}
43+
// check that user will not exceed max lifetime allowed amount
44+
if claimData.TotalClaimed+amountClaimed > rl.maxlifeTimeAmount {
45+
return false, nil
3746
}
3847

39-
return true, rl.markOnCooldown(ctx, key)
48+
return true, rl.declareClaimedValue(ctx, key, amountClaimed, claimData)
4049
}
4150

42-
func (rl *CooldownLimiter) isOnCooldown(ctx context.Context, key string) (bool, error) {
43-
_, err := rl.redis.Get(ctx, key).Result()
51+
func (rl *CooldownLimiter) getClaimsData(ctx context.Context, key string) (*claimData, error) {
52+
storedData, err := rl.redis.Get(ctx, key).Result()
4453
if err != nil {
45-
// Since we use redis's TTL feature to manage cooldown periods,
46-
// an error redis.Nil simply indicates that the key is not on cooldown.
54+
// Here we return an empty claimData because is the first time the user is making a claim
55+
// the total amount claimed is 0 and the lastClaimed is the default time value
4756
if errors.Is(err, redis.Nil) {
48-
return false, nil
57+
return &claimData{}, nil
4958
}
5059
// Any other unexpected error is returned.
51-
return false, err
60+
return nil, err
5261
}
5362

54-
// Key found: it is on cooldown
55-
return true, nil
63+
claimData := &claimData{}
64+
err = json.Unmarshal([]byte(storedData), claimData)
65+
return claimData, err
5666
}
5767

58-
func (rl *CooldownLimiter) markOnCooldown(ctx context.Context, key string) error {
59-
// The value set here does not matter, as we only rely on
60-
// redis's TTL feature to check if a key is still on cooldown
61-
return rl.redis.Set(ctx, key, "claimed", rl.cooldownTime).Err()
68+
func (rl *CooldownLimiter) declareClaimedValue(ctx context.Context, key string, amountClaimed int64, currentData *claimData) error {
69+
currentData.LastClaimed = time.Now()
70+
currentData.TotalClaimed += amountClaimed
71+
72+
data, err := json.Marshal(currentData)
73+
if err != nil {
74+
return fmt.Errorf("unable to marshal claim data, %w", err)
75+
}
76+
77+
return rl.redis.Set(ctx, key, data, 0).Err()
78+
}
6279

80+
type claimData struct {
81+
LastClaimed time.Time
82+
TotalClaimed int64
6383
}

contribs/gnofaucet/cooldown_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"math"
56
"testing"
67
"time"
78

@@ -11,33 +12,34 @@ import (
1112
)
1213

1314
func TestCooldownLimiter(t *testing.T) {
15+
var tenGnots int64 = 10_000_000
1416
redisServer := miniredis.RunT(t)
1517
rdb := redis.NewClient(&redis.Options{
1618
Addr: redisServer.Addr(),
1719
})
1820

1921
cooldownDuration := time.Second
20-
limiter := NewCooldownLimiter(cooldownDuration, rdb)
22+
limiter := NewCooldownLimiter(cooldownDuration, rdb, math.MaxInt64)
2123
ctx := context.Background()
2224
user := "testUser"
2325

2426
// First check should be allowed
25-
allowed, err := limiter.CheckCooldown(ctx, user)
27+
allowed, err := limiter.CheckCooldown(ctx, user, tenGnots)
2628
require.NoError(t, err)
2729

2830
if !allowed {
2931
t.Errorf("Expected first CheckCooldown to return true, but got false")
3032
}
3133

32-
allowed, err = limiter.CheckCooldown(ctx, user)
34+
allowed, err = limiter.CheckCooldown(ctx, user, tenGnots)
3335
require.NoError(t, err)
3436
// Second check immediately should be denied
3537
if allowed {
3638
t.Errorf("Expected second CheckCooldown to return false, but got true")
3739
}
3840

3941
require.Eventually(t, func() bool {
40-
allowed, err := limiter.CheckCooldown(ctx, user)
42+
allowed, err := limiter.CheckCooldown(ctx, user, tenGnots)
4143
return err == nil && !allowed
4244
}, 2*cooldownDuration, 10*time.Millisecond, "Expected CheckCooldown to return true after cooldown period")
4345
}

contribs/gnofaucet/gh.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package main
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"fmt"
8+
"io"
79
"net/http"
810
"strings"
911

@@ -25,14 +27,6 @@ func getGithubMiddleware(clientID, secret string, coolDownLimiter *CooldownLimit
2527
return func(next http.Handler) http.Handler {
2628
return http.HandlerFunc(
2729
func(w http.ResponseWriter, r *http.Request) {
28-
// github Oauth flow is enabled
29-
if secret == "" || clientID == "" {
30-
// Continue with serving the faucet request
31-
next.ServeHTTP(w, r)
32-
33-
return
34-
}
35-
3630
// Extracts the authorization code returned by the GitHub OAuth flow.
3731
//
3832
// When a user successfully authenticates via GitHub OAuth, GitHub redirects them
@@ -52,8 +46,14 @@ func getGithubMiddleware(clientID, secret string, coolDownLimiter *CooldownLimit
5246
return
5347
}
5448

49+
claimAmount, err := getClaimAmount(r)
50+
if err != nil {
51+
http.Error(w, err.Error(), http.StatusBadRequest)
52+
return
53+
}
54+
5555
// Just check if given account have asked for faucet before the cooldown period
56-
allowedToClaim, err := coolDownLimiter.CheckCooldown(r.Context(), user.GetLogin())
56+
allowedToClaim, err := coolDownLimiter.CheckCooldown(r.Context(), user.GetLogin(), claimAmount)
5757
if err != nil {
5858
http.Error(w, err.Error(), http.StatusInternalServerError)
5959
return
@@ -71,6 +71,25 @@ func getGithubMiddleware(clientID, secret string, coolDownLimiter *CooldownLimit
7171
}
7272
}
7373

74+
type request struct {
75+
Amount int64 `json:"amount"`
76+
}
77+
78+
func getClaimAmount(r *http.Request) (int64, error) {
79+
body, err := io.ReadAll(r.Body)
80+
if err != nil {
81+
return 0, err
82+
}
83+
84+
var data request
85+
err = json.Unmarshal(body, &data)
86+
if err != nil {
87+
return 0, err
88+
}
89+
r.Body = io.NopCloser(bytes.NewBuffer(body))
90+
return data.Amount, nil
91+
}
92+
7493
type gitHubTokenResponse struct {
7594
AccessToken string `json:"access_token"`
7695
}

contribs/gnofaucet/gh_test.go

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
package main
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
7+
"fmt"
8+
"math"
69
"net/http"
710
"net/http/httptest"
811
"testing"
@@ -25,10 +28,11 @@ func mockExchangeCodeForToken(ctx context.Context, secret, clientID, code string
2528
func TestGitHubMiddleware(t *testing.T) {
2629
cooldown := 2 * time.Minute
2730
exchangeCodeForUser = mockExchangeCodeForToken
28-
t.Run("Midleware without credentials", func(t *testing.T) {
29-
middleware := getGithubMiddleware("", "", getCooldownLimiter(t, cooldown))
30-
// Test missing clientID and secret, middleware does nothing
31-
req := httptest.NewRequest("GET", "http://localhost", nil)
31+
var tenGnots int64 = 10000000
32+
claimBody := fmt.Sprintf(`{"amount": %d}`, tenGnots)
33+
t.Run("request without code", func(t *testing.T) {
34+
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown, math.MaxInt64))
35+
req := httptest.NewRequest("GET", "http://localhost?code=", bytes.NewBufferString(claimBody))
3236
rec := httptest.NewRecorder()
3337

3438
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -37,13 +41,14 @@ func TestGitHubMiddleware(t *testing.T) {
3741

3842
handler.ServeHTTP(rec, req)
3943

40-
if rec.Code != http.StatusOK {
41-
t.Errorf("Expected status OK, got %d", rec.Code)
44+
if rec.Code != http.StatusBadRequest {
45+
t.Errorf("Expected status BadRequest, got %d", rec.Code)
4246
}
4347
})
44-
t.Run("request without code", func(t *testing.T) {
45-
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown))
46-
req := httptest.NewRequest("GET", "http://localhost?code=", nil)
48+
49+
t.Run("request invalid code", func(t *testing.T) {
50+
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown, math.MaxInt64))
51+
req := httptest.NewRequest("GET", "http://localhost?code=invalid", bytes.NewBufferString(claimBody))
4752
rec := httptest.NewRecorder()
4853

4954
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -57,9 +62,9 @@ func TestGitHubMiddleware(t *testing.T) {
5762
}
5863
})
5964

60-
t.Run("request invalid code", func(t *testing.T) {
61-
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown))
62-
req := httptest.NewRequest("GET", "http://localhost?code=invalid", nil)
65+
t.Run("OK", func(t *testing.T) {
66+
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown, math.MaxInt64))
67+
req := httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
6368
rec := httptest.NewRecorder()
6469

6570
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -68,14 +73,14 @@ func TestGitHubMiddleware(t *testing.T) {
6873

6974
handler.ServeHTTP(rec, req)
7075

71-
if rec.Code != http.StatusBadRequest {
72-
t.Errorf("Expected status BadRequest, got %d", rec.Code)
76+
if rec.Code != http.StatusOK {
77+
t.Errorf("Expected status OK, got %d", rec.Code)
7378
}
7479
})
7580

76-
t.Run("OK", func(t *testing.T) {
77-
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown))
78-
req := httptest.NewRequest("GET", "http://localhost?code=valid", nil)
81+
t.Run("Cooldown active", func(t *testing.T) {
82+
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown, math.MaxInt64))
83+
req := httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
7984
rec := httptest.NewRecorder()
8085

8186
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -87,41 +92,66 @@ func TestGitHubMiddleware(t *testing.T) {
8792
if rec.Code != http.StatusOK {
8893
t.Errorf("Expected status OK, got %d", rec.Code)
8994
}
95+
96+
req = httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
97+
rec = httptest.NewRecorder()
98+
99+
handler.ServeHTTP(rec, req)
100+
if rec.Code != http.StatusTooManyRequests {
101+
t.Errorf("Expected status TooManyRequest, got %d", rec.Code)
102+
}
90103
})
91104

92-
t.Run("Cooldown active", func(t *testing.T) {
93-
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown))
94-
req := httptest.NewRequest("GET", "http://localhost?code=valid", nil)
105+
t.Run("User exceeded lifetime limit", func(t *testing.T) {
106+
cooldown = time.Millisecond
107+
// Max lifetime amount is 20 Gnots so we should be able to make 2 claims
108+
middleware := getGithubMiddleware("mockClientID", "mockSecret", getCooldownLimiter(t, cooldown, tenGnots*2))
109+
req := httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
95110
rec := httptest.NewRecorder()
96111

97112
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98113
w.WriteHeader(http.StatusOK)
99114
}))
100115

101116
handler.ServeHTTP(rec, req)
117+
// First claim ok
118+
if rec.Code != http.StatusOK {
119+
t.Errorf("Expected status OK, got %d", rec.Code)
120+
}
121+
// Wait 2 times the cooldown
122+
time.Sleep(2 * cooldown)
123+
124+
req = httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
125+
rec = httptest.NewRecorder()
102126

127+
handler.ServeHTTP(rec, req)
128+
//Second claim should also be ok
103129
if rec.Code != http.StatusOK {
104130
t.Errorf("Expected status OK, got %d", rec.Code)
105131
}
106132

107-
req = httptest.NewRequest("GET", "http://localhost?code=valid", nil)
133+
// Third one should fail
134+
time.Sleep(2 * cooldown)
135+
136+
req = httptest.NewRequest("GET", "http://localhost?code=valid", bytes.NewBufferString(claimBody))
108137
rec = httptest.NewRecorder()
109138

110139
handler.ServeHTTP(rec, req)
140+
//Second claim should also be ok
111141
if rec.Code != http.StatusTooManyRequests {
112-
t.Errorf("Expected status TooManyRequest, got %d", rec.Code)
142+
t.Errorf("Expected status OK, got %d", rec.Code)
113143
}
114144
})
115145
}
116146

117-
func getCooldownLimiter(t *testing.T, duration time.Duration) *CooldownLimiter {
147+
func getCooldownLimiter(t *testing.T, duration time.Duration, maxlifeTimeAmount int64) *CooldownLimiter {
118148
t.Helper()
119149
redisServer := miniredis.RunT(t)
120150
rdb := redis.NewClient(&redis.Options{
121151
Addr: redisServer.Addr(),
122152
})
123153

124-
limiter := NewCooldownLimiter(duration, rdb)
154+
limiter := NewCooldownLimiter(duration, rdb, maxlifeTimeAmount)
125155

126156
return limiter
127157
}

0 commit comments

Comments
 (0)