11package main
22
33import (
4+ "bytes"
45 "context"
56 "errors"
7+ "fmt"
8+ "math"
69 "net/http"
710 "net/http/httptest"
811 "testing"
@@ -25,10 +28,11 @@ func mockExchangeCodeForToken(ctx context.Context, secret, clientID, code string
2528func TestGitHubMiddleware (t * testing.T ) {
2629 cooldown := 2 * time .Minute
2730 exchangeCodeForUser = mockExchangeCodeForToken
28- t .Run ("Midleware without credentials" , func (t * testing.T ) {
29- middleware := getGithubMiddleware ("" , "" , getCooldownLimiter (t , cooldown ))
30- // Test missing clientID and secret, middleware does nothing
31- req := httptest .NewRequest ("GET" , "http://localhost" , nil )
31+ var tenGnots int64 = 10000000
32+ claimBody := fmt .Sprintf (`{"amount": %d}` , tenGnots )
33+ t .Run ("request without code" , func (t * testing.T ) {
34+ middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown , math .MaxInt64 ))
35+ req := httptest .NewRequest ("GET" , "http://localhost?code=" , bytes .NewBufferString (claimBody ))
3236 rec := httptest .NewRecorder ()
3337
3438 handler := middleware (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -37,13 +41,14 @@ func TestGitHubMiddleware(t *testing.T) {
3741
3842 handler .ServeHTTP (rec , req )
3943
40- if rec .Code != http .StatusOK {
41- t .Errorf ("Expected status OK , got %d" , rec .Code )
44+ if rec .Code != http .StatusBadRequest {
45+ t .Errorf ("Expected status BadRequest , got %d" , rec .Code )
4246 }
4347 })
44- t .Run ("request without code" , func (t * testing.T ) {
45- middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown ))
46- req := httptest .NewRequest ("GET" , "http://localhost?code=" , nil )
48+
49+ t .Run ("request invalid code" , func (t * testing.T ) {
50+ middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown , math .MaxInt64 ))
51+ req := httptest .NewRequest ("GET" , "http://localhost?code=invalid" , bytes .NewBufferString (claimBody ))
4752 rec := httptest .NewRecorder ()
4853
4954 handler := middleware (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -57,9 +62,9 @@ func TestGitHubMiddleware(t *testing.T) {
5762 }
5863 })
5964
60- t .Run ("request invalid code " , func (t * testing.T ) {
61- middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown ))
62- req := httptest .NewRequest ("GET" , "http://localhost?code=invalid " , nil )
65+ t .Run ("OK " , func (t * testing.T ) {
66+ middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown , math . MaxInt64 ))
67+ req := httptest .NewRequest ("GET" , "http://localhost?code=valid " , bytes . NewBufferString ( claimBody ) )
6368 rec := httptest .NewRecorder ()
6469
6570 handler := middleware (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -68,14 +73,14 @@ func TestGitHubMiddleware(t *testing.T) {
6873
6974 handler .ServeHTTP (rec , req )
7075
71- if rec .Code != http .StatusBadRequest {
72- t .Errorf ("Expected status BadRequest , got %d" , rec .Code )
76+ if rec .Code != http .StatusOK {
77+ t .Errorf ("Expected status OK , got %d" , rec .Code )
7378 }
7479 })
7580
76- t .Run ("OK " , func (t * testing.T ) {
77- middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown ))
78- req := httptest .NewRequest ("GET" , "http://localhost?code=valid" , nil )
81+ t .Run ("Cooldown active " , func (t * testing.T ) {
82+ middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown , math . MaxInt64 ))
83+ req := httptest .NewRequest ("GET" , "http://localhost?code=valid" , bytes . NewBufferString ( claimBody ) )
7984 rec := httptest .NewRecorder ()
8085
8186 handler := middleware (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -87,41 +92,66 @@ func TestGitHubMiddleware(t *testing.T) {
8792 if rec .Code != http .StatusOK {
8893 t .Errorf ("Expected status OK, got %d" , rec .Code )
8994 }
95+
96+ req = httptest .NewRequest ("GET" , "http://localhost?code=valid" , bytes .NewBufferString (claimBody ))
97+ rec = httptest .NewRecorder ()
98+
99+ handler .ServeHTTP (rec , req )
100+ if rec .Code != http .StatusTooManyRequests {
101+ t .Errorf ("Expected status TooManyRequest, got %d" , rec .Code )
102+ }
90103 })
91104
92- t .Run ("Cooldown active" , func (t * testing.T ) {
93- middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown ))
94- req := httptest .NewRequest ("GET" , "http://localhost?code=valid" , nil )
105+ t .Run ("User exceeded lifetime limit" , func (t * testing.T ) {
106+ cooldown = time .Millisecond
107+ // Max lifetime amount is 20 Gnots so we should be able to make 2 claims
108+ middleware := getGithubMiddleware ("mockClientID" , "mockSecret" , getCooldownLimiter (t , cooldown , tenGnots * 2 ))
109+ req := httptest .NewRequest ("GET" , "http://localhost?code=valid" , bytes .NewBufferString (claimBody ))
95110 rec := httptest .NewRecorder ()
96111
97112 handler := middleware (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
98113 w .WriteHeader (http .StatusOK )
99114 }))
100115
101116 handler .ServeHTTP (rec , req )
117+ // First claim ok
118+ if rec .Code != http .StatusOK {
119+ t .Errorf ("Expected status OK, got %d" , rec .Code )
120+ }
121+ // Wait 2 times the cooldown
122+ time .Sleep (2 * cooldown )
123+
124+ req = httptest .NewRequest ("GET" , "http://localhost?code=valid" , bytes .NewBufferString (claimBody ))
125+ rec = httptest .NewRecorder ()
102126
127+ handler .ServeHTTP (rec , req )
128+ //Second claim should also be ok
103129 if rec .Code != http .StatusOK {
104130 t .Errorf ("Expected status OK, got %d" , rec .Code )
105131 }
106132
107- req = httptest .NewRequest ("GET" , "http://localhost?code=valid" , nil )
133+ // Third one should fail
134+ time .Sleep (2 * cooldown )
135+
136+ req = httptest .NewRequest ("GET" , "http://localhost?code=valid" , bytes .NewBufferString (claimBody ))
108137 rec = httptest .NewRecorder ()
109138
110139 handler .ServeHTTP (rec , req )
140+ //Second claim should also be ok
111141 if rec .Code != http .StatusTooManyRequests {
112- t .Errorf ("Expected status TooManyRequest , got %d" , rec .Code )
142+ t .Errorf ("Expected status OK , got %d" , rec .Code )
113143 }
114144 })
115145}
116146
117- func getCooldownLimiter (t * testing.T , duration time.Duration ) * CooldownLimiter {
147+ func getCooldownLimiter (t * testing.T , duration time.Duration , maxlifeTimeAmount int64 ) * CooldownLimiter {
118148 t .Helper ()
119149 redisServer := miniredis .RunT (t )
120150 rdb := redis .NewClient (& redis.Options {
121151 Addr : redisServer .Addr (),
122152 })
123153
124- limiter := NewCooldownLimiter (duration , rdb )
154+ limiter := NewCooldownLimiter (duration , rdb , maxlifeTimeAmount )
125155
126156 return limiter
127157}
0 commit comments