Skip to content

Commit 5608ee7

Browse files
committed
feature: httpclient supporting retry
Signed-off-by: Sandor Szücs <[email protected]>
1 parent 994004b commit 5608ee7

File tree

2 files changed

+250
-8
lines changed

2 files changed

+250
-8
lines changed

net/httpclient.go

+81-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package net
22

33
import (
4+
"bytes"
45
"crypto/tls"
56
"fmt"
67
"io"
@@ -23,15 +24,54 @@ const (
2324
defaultRefreshInterval = 5 * time.Minute
2425
)
2526

27+
type mybuf struct{ *bytes.Buffer }
28+
29+
func (buf *mybuf) Close() error {
30+
return nil
31+
}
32+
33+
type copyBodyStream struct {
34+
left int
35+
buf *mybuf
36+
input io.ReadCloser
37+
}
38+
39+
func newCopyBodyStream(left int, buf *bytes.Buffer, rc io.ReadCloser) *copyBodyStream {
40+
return &copyBodyStream{
41+
left: left,
42+
buf: &mybuf{Buffer: buf},
43+
input: rc,
44+
}
45+
}
46+
47+
func (cb *copyBodyStream) Read(p []byte) (n int, err error) {
48+
n, err = cb.input.Read(p)
49+
if cb.left > 0 && n > 0 {
50+
m := min(n, cb.left)
51+
cb.buf.Write(p[:m])
52+
cb.left -= m
53+
}
54+
return n, err
55+
}
56+
57+
func (cb *copyBodyStream) Close() error {
58+
return cb.input.Close()
59+
}
60+
61+
func (cb *copyBodyStream) GetBody() io.ReadCloser {
62+
return cb.buf
63+
}
64+
2665
// Client adds additional features like Bearer token injection, and
2766
// opentracing to the wrapped http.Client with the same interface as
2867
// http.Client from the stdlib.
2968
type Client struct {
30-
once sync.Once
31-
client http.Client
32-
tr *Transport
33-
log logging.Logger
34-
sr secrets.SecretsReader
69+
once sync.Once
70+
client http.Client
71+
tr *Transport
72+
log logging.Logger
73+
sr secrets.SecretsReader
74+
retryBuffers *sync.Map
3575
}
3676

3777
// NewClient creates a wrapped http.Client and uses Transport to
@@ -67,9 +107,10 @@ func NewClient(o Options) *Client {
67107
Transport: tr,
68108
CheckRedirect: o.CheckRedirect,
69109
},
70-
tr: tr,
71-
log: o.Log,
72-
sr: sr,
110+
tr: tr,
111+
log: o.Log,
112+
sr: sr,
113+
retryBuffers: &sync.Map{},
73114
}
74115

75116
return c
@@ -125,9 +166,41 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
125166
req.Header.Set("Authorization", "Bearer "+string(b))
126167
}
127168
}
169+
if req.Body != nil && req.Body != http.NoBody {
170+
retryBuffer := newCopyBodyStream(int(req.ContentLength), &bytes.Buffer{}, req.Body)
171+
c.retryBuffers.Store(req, retryBuffer)
172+
req.Body = retryBuffer
173+
}
128174
return c.client.Do(req)
129175
}
130176

177+
func (c *Client) Retry(req *http.Request) (*http.Response, error) {
178+
if req.Body == nil || req.Body == http.NoBody {
179+
return c.Do(req)
180+
}
181+
182+
if rc, err := req.GetBody(); err == nil {
183+
println("req.GetBody() case")
184+
c.retryBuffers.Delete(req)
185+
req.Body = rc
186+
return c.Do(req)
187+
}
188+
189+
println("our own retry buffer impl")
190+
buf, ok := c.retryBuffers.Load(req)
191+
if !ok {
192+
return nil, fmt.Errorf("no retry possible, request not found: %s %s", req.Method, req.URL)
193+
}
194+
195+
retryBuffer, ok := buf.(*copyBodyStream)
196+
if !ok {
197+
return nil, fmt.Errorf("no retry possible, no retry buffer for request: %s %s", req.Method, req.URL)
198+
}
199+
req.Body = retryBuffer.GetBody()
200+
201+
return c.Do(req)
202+
}
203+
131204
// CloseIdleConnections delegates the call to the underlying
132205
// http.Client.
133206
func (c *Client) CloseIdleConnections() {

net/httpclient_test.go

+169
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package net
22

33
import (
4+
"bytes"
5+
"io"
46
"net/http"
57
"net/http/httptest"
68
"net/url"
@@ -324,3 +326,170 @@ func TestClientClosesIdleConnections(t *testing.T) {
324326
}
325327
rsp.Body.Close()
326328
}
329+
330+
func TestTestClientRetry(t *testing.T) {
331+
for _, tt := range []struct {
332+
name string
333+
method string
334+
body string
335+
}{
336+
{
337+
name: "test GET",
338+
method: "GET",
339+
},
340+
{
341+
name: "test POST",
342+
method: "POST",
343+
body: "hello POST",
344+
},
345+
{
346+
name: "test PATCH",
347+
method: "PATCH",
348+
body: "hello PATCH",
349+
},
350+
{
351+
name: "test PUT",
352+
method: "PUT",
353+
body: "hello PUT",
354+
}} {
355+
t.Run(tt.name, func(t *testing.T) {
356+
i := 0
357+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
358+
if i == 0 {
359+
i++
360+
w.WriteHeader(http.StatusBadGateway)
361+
}
362+
363+
got, err := io.ReadAll(r.Body)
364+
if err != nil {
365+
t.Fatalf("got no data")
366+
}
367+
s := string(got)
368+
if tt.body != s {
369+
t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s)
370+
}
371+
372+
w.WriteHeader(http.StatusOK)
373+
}))
374+
defer backend.Close()
375+
376+
noleak.Check(t)
377+
378+
cli := NewClient(Options{})
379+
defer cli.Close()
380+
381+
buf := bytes.NewBufferString(tt.body)
382+
req, err := http.NewRequest(tt.method, backend.URL, buf)
383+
if err != nil {
384+
t.Fatal(err)
385+
}
386+
rsp, err := cli.Do(req)
387+
if err != nil {
388+
t.Fatal(err)
389+
}
390+
if rsp.StatusCode != http.StatusBadGateway {
391+
t.Fatalf("unexpected status code: %s", rsp.Status)
392+
}
393+
394+
rsp, err = cli.Retry(req)
395+
if rsp.StatusCode != http.StatusOK {
396+
t.Fatalf("unexpected status code: %s", rsp.Status)
397+
}
398+
rsp.Body.Close()
399+
})
400+
}
401+
}
402+
403+
func TestTestClientRetryConcurrentRequests(t *testing.T) {
404+
for _, tt := range []struct {
405+
name string
406+
method string
407+
body string
408+
}{
409+
{
410+
name: "test GET",
411+
method: "GET",
412+
},
413+
{
414+
name: "test POST",
415+
method: "POST",
416+
body: "hello POST",
417+
},
418+
{
419+
name: "test PATCH",
420+
method: "PATCH",
421+
body: "hello PATCH",
422+
},
423+
{
424+
name: "test PUT",
425+
method: "PUT",
426+
body: "hello PUT",
427+
}} {
428+
t.Run(tt.name, func(t *testing.T) {
429+
i := 0
430+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
431+
if r.URL.Path == "/ignore" {
432+
w.WriteHeader(http.StatusOK)
433+
return
434+
}
435+
436+
if i == 0 {
437+
i++
438+
io.ReadAll(r.Body)
439+
w.WriteHeader(http.StatusBadGateway)
440+
return
441+
}
442+
443+
got, err := io.ReadAll(r.Body)
444+
if err != nil {
445+
t.Fatalf("got no data")
446+
}
447+
s := string(got)
448+
if tt.body != s {
449+
t.Fatalf("Failed to get the right data want: %q, got: %q", tt.body, s)
450+
}
451+
452+
w.WriteHeader(http.StatusOK)
453+
}))
454+
defer backend.Close()
455+
456+
noleak.Check(t)
457+
458+
cli := NewClient(Options{})
459+
defer cli.Close()
460+
461+
quit := make(chan struct{})
462+
go func() {
463+
for {
464+
select {
465+
case <-quit:
466+
return
467+
default:
468+
}
469+
cli.Get(backend.URL + "/ignore")
470+
}
471+
}()
472+
473+
buf := bytes.NewBufferString(tt.body)
474+
req, err := http.NewRequest(tt.method, backend.URL, buf)
475+
if err != nil {
476+
t.Fatal(err)
477+
}
478+
rsp, err := cli.Do(req)
479+
if err != nil {
480+
t.Fatal(err)
481+
}
482+
if rsp.StatusCode != http.StatusBadGateway {
483+
t.Fatalf("unexpected status code: %s", rsp.Status)
484+
}
485+
486+
rsp, err = cli.Retry(req)
487+
if rsp.StatusCode != http.StatusOK {
488+
t.Fatalf("unexpected status code: %s", rsp.Status)
489+
}
490+
rsp.Body.Close()
491+
492+
close(quit)
493+
})
494+
}
495+
}

0 commit comments

Comments
 (0)