Skip to content

Commit 5250697

Browse files
liuzhaohuirueianzhaohuiliu
authored
feat: respect request ctx for making tcp connections and add DialCtxFn option (#803)
* feat: respect request ctx for making tcp connections and add DialCtxFn option Signed-off-by: Rueian <[email protected]> * add context to the dial function and respect the context deadline * feat: Deprecated DialFn Signed-off-by: Rueian <[email protected]> * add testcase for conn ctx * chore: remove unnecessary changes Signed-off-by: Rueian <[email protected]> --------- Signed-off-by: Rueian <[email protected]> Co-authored-by: Rueian <[email protected]> Co-authored-by: zhaohuiliu <[email protected]>
1 parent 83e70d9 commit 5250697

16 files changed

+352
-153
lines changed

client.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,15 @@ retry:
172172
}
173173

174174
func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) {
175-
wire := c.conn.Acquire()
175+
wire := c.conn.Acquire(context.Background())
176176
dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler}
177177
err = fn(dsc)
178178
dsc.release()
179179
return err
180180
}
181181

182182
func (c *singleClient) Dedicate() (DedicatedClient, func()) {
183-
wire := c.conn.Acquire()
183+
wire := c.conn.Acquire(context.Background())
184184
dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler}
185185
return dsc, dsc.release
186186
}

client_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (m *mockConn) Dial() error {
5151
return nil
5252
}
5353

54-
func (m *mockConn) Acquire() wire {
54+
func (m *mockConn) Acquire(ctx context.Context) wire {
5555
if m.AcquireFn != nil {
5656
return m.AcquireFn()
5757
}

cluster.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,7 @@ func (c *dedicatedClusterClient) acquire(ctx context.Context, slot uint16) (wire
12741274
}
12751275
return nil, err
12761276
}
1277-
c.wire = c.conn.Acquire()
1277+
c.wire = c.conn.Acquire(ctx)
12781278
if p := c.pshks; p != nil {
12791279
c.pshks = nil
12801280
ch := c.wire.SetPubSubHooks(p.hooks)

mux.go

+28-28
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ import (
1414
)
1515

1616
type connFn func(dst string, opt *ClientOption) conn
17-
type dialFn func(dst string, opt *ClientOption) (net.Conn, error)
18-
type wireFn func() wire
17+
type dialFn func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error)
18+
type wireFn func(ctx context.Context) wire
1919

2020
type singleconnect struct {
2121
w wire
@@ -38,7 +38,7 @@ type conn interface {
3838
Close()
3939
Dial() error
4040
Override(conn)
41-
Acquire() wire
41+
Acquire(ctx context.Context) wire
4242
Store(w wire)
4343
Addr() string
4444
SetOnCloseHook(func(error))
@@ -67,12 +67,12 @@ type mux struct {
6767

6868
func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux {
6969
dead := deadFn()
70-
connFn := func() (net.Conn, error) {
71-
return dialFn(dst, option)
70+
connFn := func(ctx context.Context) (net.Conn, error) {
71+
return dialFn(ctx, dst, option)
7272
}
73-
wireFn := func(pipeFn pipeFn) func() wire {
74-
return func() (w wire) {
75-
w, err := pipeFn(connFn, option)
73+
wireFn := func(pipeFn pipeFn) func(context.Context) wire {
74+
return func(ctx context.Context) (w wire) {
75+
w, err := pipeFn(ctx, connFn, option)
7676
if err != nil {
7777
dead.error.Store(&errs{error: err})
7878
w = dead
@@ -152,7 +152,7 @@ func (m *mux) Override(cc conn) {
152152
}
153153
}
154154

155-
func (m *mux) _pipe(i uint16) (w wire, err error) {
155+
func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) {
156156
if w = m.wire[i].Load().(wire); w != m.init {
157157
return w, nil
158158
}
@@ -171,7 +171,7 @@ func (m *mux) _pipe(i uint16) (w wire, err error) {
171171
}
172172

173173
if w = m.wire[i].Load().(wire); w == m.init {
174-
if w = m.wireFn(); w != m.dead {
174+
if w = m.wireFn(ctx); w != m.dead {
175175
m.setCloseHookOnWire(i, w)
176176
m.wire[i].Store(w)
177177
} else {
@@ -193,39 +193,39 @@ func (m *mux) _pipe(i uint16) (w wire, err error) {
193193
return w, err
194194
}
195195

196-
func (m *mux) pipe(i uint16) wire {
197-
w, _ := m._pipe(i)
196+
func (m *mux) pipe(ctx context.Context, i uint16) wire {
197+
w, _ := m._pipe(ctx, i)
198198
return w // this should never be nil
199199
}
200200

201201
func (m *mux) Dial() error {
202-
_, err := m._pipe(0)
202+
_, err := m._pipe(context.Background(), 0)
203203
return err
204204
}
205205

206206
func (m *mux) Info() map[string]RedisMessage {
207-
return m.pipe(0).Info()
207+
return m.pipe(context.Background(), 0).Info()
208208
}
209209

210210
func (m *mux) Version() int {
211-
return m.pipe(0).Version()
211+
return m.pipe(context.Background(), 0).Version()
212212
}
213213

214214
func (m *mux) AZ() string {
215-
return m.pipe(0).AZ()
215+
return m.pipe(context.Background(), 0).AZ()
216216
}
217217

218218
func (m *mux) Error() error {
219-
return m.pipe(0).Error()
219+
return m.pipe(context.Background(), 0).Error()
220220
}
221221

222222
func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream {
223-
wire := m.spool.Acquire()
223+
wire := m.spool.Acquire(ctx)
224224
return wire.DoStream(ctx, m.spool, cmd)
225225
}
226226

227227
func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream {
228-
wire := m.spool.Acquire()
228+
wire := m.spool.Acquire(ctx)
229229
return wire.DoMultiStream(ctx, m.spool, multi...)
230230
}
231231

@@ -262,7 +262,7 @@ block:
262262
}
263263

264264
func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) {
265-
wire := pool.Acquire()
265+
wire := pool.Acquire(ctx)
266266
resp = wire.Do(ctx, cmd)
267267
if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
268268
wire.Close()
@@ -272,7 +272,7 @@ func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp Red
272272
}
273273

274274
func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) {
275-
wire := pool.Acquire()
275+
wire := pool.Acquire(ctx)
276276
resp = wire.DoMulti(ctx, cmd...)
277277
for _, res := range resp.s {
278278
if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
@@ -286,7 +286,7 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r
286286

287287
func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {
288288
slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply())
289-
wire := m.pipe(slot)
289+
wire := m.pipe(ctx, slot)
290290
if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) {
291291
m.wire[slot].CompareAndSwap(wire, m.init)
292292
}
@@ -295,7 +295,7 @@ func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {
295295

296296
func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) {
297297
slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply())
298-
wire := m.pipe(slot)
298+
wire := m.pipe(ctx, slot)
299299
resp = wire.DoMulti(ctx, cmd...)
300300
for _, r := range resp.s {
301301
if isBroken(r.NonRedisError(), wire) {
@@ -308,7 +308,7 @@ func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisre
308308

309309
func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult {
310310
slot := cmd.Slot() & uint16(len(m.wire)-1)
311-
wire := m.pipe(slot)
311+
wire := m.pipe(ctx, slot)
312312
resp := wire.DoCache(ctx, cmd, ttl)
313313
if isBroken(resp.NonRedisError(), wire) {
314314
m.wire[slot].CompareAndSwap(wire, m.init)
@@ -366,7 +366,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results
366366
}
367367

368368
func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) {
369-
wire := m.pipe(slot)
369+
wire := m.pipe(ctx, slot)
370370
resps = wire.DoMultiCache(ctx, multi...)
371371
for _, r := range resps.s {
372372
if isBroken(r.NonRedisError(), wire) {
@@ -379,16 +379,16 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT
379379

380380
func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
381381
slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply())
382-
wire := m.pipe(slot)
382+
wire := m.pipe(ctx, slot)
383383
err := wire.Receive(ctx, subscribe, fn)
384384
if isBroken(err, wire) {
385385
m.wire[slot].CompareAndSwap(wire, m.init)
386386
}
387387
return err
388388
}
389389

390-
func (m *mux) Acquire() wire {
391-
return m.dpool.Acquire()
390+
func (m *mux) Acquire(ctx context.Context) wire {
391+
return m.dpool.Acquire(ctx)
392392
}
393393

394394
func (m *mux) Store(w wire) {

mux_test.go

+46-18
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) {
2323
func setupMuxWithOption(wires []*mockWire, option *ClientOption) (conn *mux, checkClean func(t *testing.T)) {
2424
var mu sync.Mutex
2525
var count = -1
26-
wfn := func() wire {
26+
wfn := func(_ context.Context) wire {
2727
mu.Lock()
2828
defer mu.Unlock()
2929
count++
@@ -43,26 +43,54 @@ func TestNewMuxDailErr(t *testing.T) {
4343
defer ShouldNotLeaked(SetupLeakDetection())
4444
c := 0
4545
e := errors.New("any")
46-
m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
46+
m := makeMux("", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) {
47+
timer := time.NewTimer(time.Millisecond*10) // delay time
48+
defer timer.Stop()
49+
select {
50+
case <-ctx.Done():
51+
return nil, ctx.Err()
52+
case <-timer.C:
53+
// noop
54+
}
4755
c++
4856
return nil, e
4957
})
5058
if err := m.Dial(); err != e {
5159
t.Fatalf("unexpected return %v", err)
5260
}
61+
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Millisecond)
62+
defer cancel1()
63+
if _, err := m._pipe(ctx1, 0); err != context.DeadlineExceeded {
64+
t.Fatalf("unexpected return %v", err)
65+
}
5366
if c != 1 {
5467
t.Fatalf("dialFn not called")
5568
}
56-
if w := m.pipe(0); w != m.dead { // c = 2
69+
if w := m.pipe(context.Background(), 0); w != m.dead { // c = 2
70+
t.Fatalf("unexpected wire %v", w)
71+
}
72+
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond)
73+
defer cancel2()
74+
if w := m.pipe(ctx2, 0); w != m.dead {
5775
t.Fatalf("unexpected wire %v", w)
5876
}
5977
if err := m.Dial(); err != e { // c = 3
6078
t.Fatalf("unexpected return %v", err)
6179
}
62-
if w := m.Acquire(); w != m.dead {
80+
if w := m.Acquire(context.Background()); w != m.dead {
81+
t.Fatalf("unexpected wire %v", w)
82+
}
83+
ctx3, cancel3 := context.WithTimeout(context.Background(), time.Millisecond)
84+
defer cancel3()
85+
if w := m.Acquire(ctx3); w != m.dead {
86+
t.Fatalf("unexpected wire %v", w)
87+
}
88+
ctx4, cancel4 := context.WithTimeout(context.Background(), 20*time.Millisecond)
89+
defer cancel4()
90+
if w := m.Acquire(ctx4); w != m.dead {
6391
t.Fatalf("unexpected wire %v", w)
6492
}
65-
if c != 4 {
93+
if c != 5 {
6694
t.Fatalf("dialFn not called %v", c)
6795
}
6896
}
@@ -89,15 +117,15 @@ func TestNewMux(t *testing.T) {
89117
mock.Expect("PING").ReplyString("OK")
90118
mock.Close()
91119
}()
92-
m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
120+
m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) {
93121
return n1, nil
94122
})
95123
if err := m.Dial(); err != nil {
96124
t.Fatalf("unexpected error %v", err)
97125
}
98126

99127
t.Run("Override with previous mux", func(t *testing.T) {
100-
m2 := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
128+
m2 := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) {
101129
return n1, nil
102130
})
103131
m2.Override(m)
@@ -111,7 +139,7 @@ func TestNewMux(t *testing.T) {
111139
func TestNewMuxPipelineMultiplex(t *testing.T) {
112140
defer ShouldNotLeaked(SetupLeakDetection())
113141
for _, v := range []int{-1, 0, 1, 2} {
114-
m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(dst string, opt *ClientOption) (net.Conn, error) { return nil, nil })
142+
m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil })
115143
if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1<<v) {
116144
t.Fatalf("unexpected len(m.wire): %v", len(m.wire))
117145
}
@@ -150,11 +178,11 @@ func TestMuxDialSuppress(t *testing.T) {
150178
defer ShouldNotLeaked(SetupLeakDetection())
151179
var wires, waits, done int64
152180
blocking := make(chan struct{})
153-
m := newMux("", &ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func() wire {
181+
m := newMux("", &ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(_ context.Context) wire {
154182
atomic.AddInt64(&wires, 1)
155183
<-blocking
156184
return &mockWire{}
157-
}, func() wire {
185+
}, func(_ context.Context) wire {
158186
return &mockWire{}
159187
})
160188
for i := 0; i < 1000; i++ {
@@ -222,7 +250,7 @@ func TestMuxReuseWire(t *testing.T) {
222250
t.Fatalf("unexpected dial error %v", err)
223251
}
224252

225-
wire1 := m.dpool.Acquire()
253+
wire1 := m.dpool.Acquire(context.Background())
226254

227255
go func() {
228256
// this should use the second wire
@@ -276,7 +304,7 @@ func TestMuxReuseWire(t *testing.T) {
276304
t.Fatalf("unexpected dial error %v", err)
277305
}
278306

279-
wire1 := m.spool.Acquire()
307+
wire1 := m.spool.Acquire(context.Background())
280308

281309
go func() {
282310
// this should use the second wire
@@ -337,7 +365,7 @@ func TestMuxReuseWire(t *testing.T) {
337365
t.Fatalf("unexpected dial error %v", err)
338366
}
339367

340-
wire1 := m.spool.Acquire()
368+
wire1 := m.spool.Acquire(context.Background())
341369

342370
go func() {
343371
// this should use the second wire
@@ -394,7 +422,7 @@ func TestMuxReuseWire(t *testing.T) {
394422
t.Fatalf("unexpected dial error %v", err)
395423
}
396424

397-
wire1 := m.dpool.Acquire()
425+
wire1 := m.dpool.Acquire(context.Background())
398426

399427
go func() {
400428
// this should use the second wire
@@ -438,7 +466,7 @@ func TestMuxReuseWire(t *testing.T) {
438466
t.Fatalf("unexpected dial error %v", err)
439467
}
440468

441-
wire1 := m.Acquire()
469+
wire1 := m.Acquire(context.Background())
442470
m.Store(wire1)
443471

444472
if !cleaned {
@@ -680,7 +708,7 @@ func TestMuxDelegation(t *testing.T) {
680708
defer m.Close()
681709

682710
for i := range wires {
683-
m._pipe(uint16(i))
711+
m._pipe(context.Background(), uint16(i))
684712
}
685713

686714
builder := cmds.NewBuilder(cmds.NoSlot)
@@ -722,7 +750,7 @@ func TestMuxDelegation(t *testing.T) {
722750
defer m.Close()
723751

724752
for i := range wires {
725-
m._pipe(uint16(i))
753+
m._pipe(context.Background(), uint16(i))
726754
}
727755

728756
builder := cmds.NewBuilder(cmds.NoSlot)
@@ -1058,7 +1086,7 @@ func TestMuxRegisterCloseHook(t *testing.T) {
10581086

10591087
func BenchmarkClientSideCaching(b *testing.B) {
10601088
setup := func(b *testing.B) *mux {
1061-
c := makeMux("127.0.0.1:6379", &ClientOption{CacheSizeEachConn: DefaultCacheBytes}, func(dst string, opt *ClientOption) (conn net.Conn, err error) {
1089+
c := makeMux("127.0.0.1:6379", &ClientOption{CacheSizeEachConn: DefaultCacheBytes}, func(_ context.Context, dst string, opt *ClientOption) (conn net.Conn, err error) {
10621090
return net.Dial("tcp", dst)
10631091
})
10641092
if err := c.Dial(); err != nil {

0 commit comments

Comments
 (0)