From cc3f007aabc427d8dd39a97273173d0b7310f0a7 Mon Sep 17 00:00:00 2001 From: Rueian Date: Thu, 13 Mar 2025 08:37:31 -0700 Subject: [PATCH 1/5] feat: respect request ctx for making tcp connections and add DialCtxFn option Signed-off-by: Rueian --- client.go | 4 +-- client_test.go | 50 ++++++++++++++-------------- cluster.go | 2 +- cluster_test.go | 86 ++++++++++++++++++++++++------------------------ mux.go | 56 +++++++++++++++---------------- mux_test.go | 34 +++++++++---------- pipe.go | 16 ++++----- pipe_test.go | 48 +++++++++++++-------------- pool.go | 9 ++--- pool_test.go | 49 +++++++++++++-------------- rueidis.go | 13 ++++++-- sentinel.go | 4 +-- sentinel_test.go | 8 ++--- 13 files changed, 194 insertions(+), 185 deletions(-) diff --git a/client.go b/client.go index ad69c360..be9230ab 100644 --- a/client.go +++ b/client.go @@ -172,7 +172,7 @@ retry: } func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -180,7 +180,7 @@ func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { } func (c *singleClient) Dedicate() (DedicatedClient, func()) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } diff --git a/client_test.go b/client_test.go index 23a34aba..0258d54e 100644 --- a/client_test.go +++ b/client_test.go @@ -26,7 +26,7 @@ type mockConn struct { ErrorFn func() error CloseFn func() DialFn func() error - AcquireFn func() wire + AcquireFn func(ctx context.Context) wire StoreFn func(w wire) OverrideFn func(c conn) AddrFn func() string @@ -49,9 +49,9 @@ func (m *mockConn) Dial() error { return nil } -func (m *mockConn) Acquire() wire { +func (m *mockConn) Acquire(ctx context.Context) wire { if m.AcquireFn != nil { - return m.AcquireFn() + return m.AcquireFn(ctx) } return nil } @@ -399,7 +399,7 @@ func TestSingleClient(t *testing.T) { return e }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -434,7 +434,7 @@ func TestSingleClient(t *testing.T) { closed = true }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -500,7 +500,7 @@ func TestSingleClient(t *testing.T) { closed = true }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -544,7 +544,7 @@ func TestSingleClient(t *testing.T) { t.Run("Dedicate Delegate Release On Close", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func() wire { return w } + m.AcquireFn = func(_ context.Context) wire { return w } m.StoreFn = func(ww wire) { stored++ } c, _ := client.Dedicate() @@ -558,7 +558,7 @@ func TestSingleClient(t *testing.T) { t.Run("Dedicate Delegate No Duplicate Release", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func() wire { return w } + m.AcquireFn = func(_ context.Context) wire { return w } m.StoreFn = func(ww wire) { stored++ } c, cancel := client.Dedicate() @@ -573,7 +573,7 @@ func TestSingleClient(t *testing.T) { }) t.Run("Dedicate ErrDedicatedClientRecycled after released", func(t *testing.T) { - m.AcquireFn = func() wire { return &mockWire{} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } check := func(err error) { if !errors.Is(err, ErrDedicatedClientRecycled) { t.Fatalf("unexpected err %v", err) @@ -1020,7 +1020,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { newErrResult(ErrClosing), newResult(RedisMessage{typ: '+', string: "Do"}, nil), ) - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if v, err := cc.Do(context.Background(), c.B().Get().Key("Do").Build()).ToString(); err != nil || v != "Do" { t.Fatalf("unexpected response %v %v", v, err) @@ -1035,7 +1035,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Do(context.Background(), c.B().Get().Key("Do").Build()).Error() }); ret != ErrClosing { @@ -1046,7 +1046,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate Do ReadOnly NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1062,7 +1062,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { newErrResult(ErrClosing), newResult(RedisMessage{typ: '+', string: "Do"}, nil), ) - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1091,7 +1091,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate Do Write NoRetry", func(t *testing.T) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Do(context.Background(), c.B().Set().Key("Do").Value("Do").Build()).Error() }); ret != ErrClosing { @@ -1105,7 +1105,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { []RedisResult{newErrResult(ErrClosing)}, []RedisResult{newResult(RedisMessage{typ: '+', string: "Do"}, nil)}, ) - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if v, err := cc.DoMulti(context.Background(), c.B().Get().Key("Do").Build())[0].ToString(); err != nil || v != "Do" { t.Fatalf("unexpected response %v %v", v, err) @@ -1120,7 +1120,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.DoMulti(context.Background(), c.B().Get().Key("Do").Build())[0].Error() }); ret != ErrClosing { @@ -1131,7 +1131,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate DoMulti ReadOnly NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1147,7 +1147,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { []RedisResult{newErrResult(ErrClosing)}, []RedisResult{newResult(RedisMessage{typ: '+', string: "Do"}, nil)}, ) - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1176,7 +1176,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate DoMulti Write NoRetry", func(t *testing.T) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.DoMulti(context.Background(), c.B().Set().Key("Do").Value("Do").Build())[0].Error() }); ret != ErrClosing { @@ -1187,7 +1187,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive Retry", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing, nil) - m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if err := cc.Receive(context.Background(), c.B().Subscribe().Channel("Do").Build(), nil); err != nil { t.Fatalf("unexpected response %v", err) @@ -1202,7 +1202,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Receive(context.Background(), c.B().Subscribe().Channel("Do").Build(), nil) }); ret != ErrClosing { @@ -1213,7 +1213,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing) - m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1226,7 +1226,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive NoRetry - not retryable", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing, nil) - m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1377,7 +1377,7 @@ func TestSingleClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -1400,7 +1400,7 @@ func TestSingleClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build()) diff --git a/cluster.go b/cluster.go index 2ea2db08..84ca522b 100644 --- a/cluster.go +++ b/cluster.go @@ -1274,7 +1274,7 @@ func (c *dedicatedClusterClient) acquire(ctx context.Context, slot uint16) (wire } return nil, err } - c.wire = c.conn.Acquire() + c.wire = c.conn.Acquire(ctx) if p := c.pshks; p != nil { c.pshks = nil ch := c.wire.SetPubSubHooks(p.hooks) diff --git a/cluster_test.go b/cluster_test.go index dae2c788..9bc132e0 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1799,7 +1799,7 @@ func TestClusterClient(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - m.AcquireFn = func() wire { return &mockWire{} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -1812,7 +1812,7 @@ func TestClusterClient(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -1841,7 +1841,7 @@ func TestClusterClient(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - m.AcquireFn = func() wire { return &mockWire{} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -1865,7 +1865,7 @@ func TestClusterClient(t *testing.T) { return e }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -1914,7 +1914,7 @@ func TestClusterClient(t *testing.T) { closed = true }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -2022,7 +2022,7 @@ func TestClusterClient(t *testing.T) { closed = true }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -2092,7 +2092,7 @@ func TestClusterClient(t *testing.T) { t.Run("Dedicate Delegate Release On Close", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func() wire { return w } + m.AcquireFn = func(_ context.Context) wire { return w } m.StoreFn = func(ww wire) { stored++ } c, _ := client.Dedicate() c.Do(context.Background(), c.B().Get().Key("a").Build()) @@ -2107,7 +2107,7 @@ func TestClusterClient(t *testing.T) { t.Run("Dedicate Delegate No Duplicate Release", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func() wire { return w } + m.AcquireFn = func(_ context.Context) wire { return w } m.StoreFn = func(ww wire) { stored++ } c, cancel := client.Dedicate() c.Do(context.Background(), c.B().Get().Key("a").Build()) @@ -2399,7 +2399,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -2412,7 +2412,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -2441,7 +2441,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -2465,7 +2465,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { return e }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -2514,7 +2514,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -2622,7 +2622,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -2888,7 +2888,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -2901,7 +2901,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -2930,7 +2930,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -2954,10 +2954,10 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { return e }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } - replicaNodeConn.AcquireFn = func() wire { + replicaNodeConn.AcquireFn = func(_ context.Context) wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { @@ -3006,7 +3006,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -3114,7 +3114,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -3478,7 +3478,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -3491,7 +3491,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -3520,7 +3520,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -3544,10 +3544,10 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod return e }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } - replicaNodeConn.AcquireFn = func() wire { + replicaNodeConn.AcquireFn = func(_ context.Context) wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { @@ -3596,7 +3596,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -3704,7 +3704,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -6306,7 +6306,7 @@ func TestClusterClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -6329,7 +6329,7 @@ func TestClusterClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build()) @@ -6783,7 +6783,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -6796,7 +6796,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -6825,7 +6825,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -6849,8 +6849,8 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode return e }, } - primaryNodeConn.AcquireFn = func() wire { return w } - replicaNodeConn.AcquireFn = func() wire { return w } // Subscribe can work on replicas + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } + replicaNodeConn.AcquireFn = func(_ context.Context) wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { return c.Receive(context.Background(), c.B().Subscribe().Channel("a").Build(), func(msg PubSubMessage) {}) }); err != e { @@ -6897,7 +6897,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -7005,7 +7005,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -7284,7 +7284,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -7297,7 +7297,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -7326,7 +7326,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -7350,7 +7350,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T return e }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -7399,7 +7399,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -7507,7 +7507,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T closed = true }, } - primaryNodeConn.AcquireFn = func() wire { + primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } stored := false diff --git a/mux.go b/mux.go index dcc680a4..5da2598d 100644 --- a/mux.go +++ b/mux.go @@ -13,8 +13,8 @@ import ( ) type connFn func(dst string, opt *ClientOption) conn -type dialFn func(dst string, opt *ClientOption) (net.Conn, error) -type wireFn func() wire +type dialFn func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) +type wireFn func(ctx context.Context) wire type singleconnect struct { w wire @@ -37,7 +37,7 @@ type conn interface { Close() Dial() error Override(conn) - Acquire() wire + Acquire(ctx context.Context) wire Store(w wire) Addr() string SetOnCloseHook(func(error)) @@ -64,12 +64,12 @@ type mux struct { func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux { dead := deadFn() - connFn := func() (net.Conn, error) { - return dialFn(dst, option) + connFn := func(ctx context.Context) (net.Conn, error) { + return dialFn(ctx, dst, option) } - wireFn := func(pipeFn pipeFn) func() wire { - return func() (w wire) { - w, err := pipeFn(connFn, option) + wireFn := func(pipeFn pipeFn) func(context.Context) wire { + return func(ctx context.Context) (w wire) { + w, err := pipeFn(ctx, connFn, option) if err != nil { dead.error.Store(&errs{error: err}) w = dead @@ -132,7 +132,7 @@ func (m *mux) Override(cc conn) { } } -func (m *mux) _pipe(i uint16) (w wire, err error) { +func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) { if w = m.wire[i].Load().(wire); w != m.init { return w, nil } @@ -151,7 +151,7 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { } if w = m.wire[i].Load().(wire); w == m.init { - if w = m.wireFn(); w != m.dead { + if w = m.wireFn(ctx); w != m.dead { m.setCloseHookOnWire(i, w) m.wire[i].Store(w) } else { @@ -173,39 +173,39 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { return w, err } -func (m *mux) pipe(i uint16) wire { - w, _ := m._pipe(i) +func (m *mux) pipe(ctx context.Context, i uint16) wire { + w, _ := m._pipe(ctx, i) return w // this should never be nil } func (m *mux) Dial() error { - _, err := m._pipe(0) + _, err := m._pipe(context.Background(), 0) return err } func (m *mux) Info() map[string]RedisMessage { - return m.pipe(0).Info() + return m.pipe(context.Background(), 0).Info() } func (m *mux) Version() int { - return m.pipe(0).Version() + return m.pipe(context.Background(), 0).Version() } func (m *mux) AZ() string { - return m.pipe(0).AZ() + return m.pipe(context.Background(), 0).AZ() } func (m *mux) Error() error { - return m.pipe(0).Error() + return m.pipe(context.Background(), 0).Error() } func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoStream(ctx, m.spool, cmd) } func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoMultiStream(ctx, m.spool, multi...) } @@ -242,7 +242,7 @@ block: } func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.Do(ctx, cmd) if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) wire.Close() @@ -252,7 +252,7 @@ func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp Red } func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.DoMulti(ctx, cmd...) for _, res := range resp.s { if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) @@ -266,7 +266,7 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) } @@ -275,7 +275,7 @@ func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) { slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp = wire.DoMulti(ctx, cmd...) for _, r := range resp.s { if isBroken(r.NonRedisError(), wire) { @@ -288,7 +288,7 @@ func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisre func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult { slot := cmd.Slot() & uint16(len(m.wire)-1) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp := wire.DoCache(ctx, cmd, ttl) if isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -346,7 +346,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results } func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) { - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resps = wire.DoMultiCache(ctx, multi...) for _, r := range resps.s { if isBroken(r.NonRedisError(), wire) { @@ -359,7 +359,7 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) err := wire.Receive(ctx, subscribe, fn) if isBroken(err, wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -367,8 +367,8 @@ func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message return err } -func (m *mux) Acquire() wire { - return m.dpool.Acquire() +func (m *mux) Acquire(ctx context.Context) wire { + return m.dpool.Acquire(ctx) } func (m *mux) Store(w wire) { diff --git a/mux_test.go b/mux_test.go index 4f801322..10139413 100644 --- a/mux_test.go +++ b/mux_test.go @@ -23,7 +23,7 @@ func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) { func setupMuxWithOption(wires []*mockWire, option *ClientOption) (conn *mux, checkClean func(t *testing.T)) { var mu sync.Mutex var count = -1 - wfn := func() wire { + wfn := func(_ context.Context) wire { mu.Lock() defer mu.Unlock() count++ @@ -43,7 +43,7 @@ func TestNewMuxDailErr(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) c := 0 e := errors.New("any") - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { c++ return nil, e }) @@ -53,13 +53,13 @@ func TestNewMuxDailErr(t *testing.T) { if c != 1 { t.Fatalf("dialFn not called") } - if w := m.pipe(0); w != m.dead { // c = 2 + if w := m.pipe(context.Background(), 0); w != m.dead { // c = 2 t.Fatalf("unexpected wire %v", w) } if err := m.Dial(); err != e { // c = 3 t.Fatalf("unexpected return %v", err) } - if w := m.Acquire(); w != m.dead { + if w := m.Acquire(context.Background()); w != m.dead { t.Fatalf("unexpected wire %v", w) } if c != 4 { @@ -89,7 +89,7 @@ func TestNewMux(t *testing.T) { mock.Expect("PING").ReplyString("OK") mock.Close() }() - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) if err := m.Dial(); err != nil { @@ -97,7 +97,7 @@ func TestNewMux(t *testing.T) { } t.Run("Override with previous mux", func(t *testing.T) { - m2 := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m2 := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) m2.Override(m) @@ -111,7 +111,7 @@ func TestNewMux(t *testing.T) { func TestNewMuxPipelineMultiplex(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) for _, v := range []int{-1, 0, 1, 2} { - m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) + m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1< 100 { t.Fatalf("pool must not exceed the size limit") @@ -103,11 +104,11 @@ func TestPool(t *testing.T) { conn := make([]wire, 100) pool, _ := setup(len(conn)) for i := 0; i < len(conn); i++ { - w := pool.Acquire() + w := pool.Acquire(context.Background()) go pool.Store(w) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } for i := 0; i < len(conn); i++ { for j := i + 1; j < len(conn); j++ { @@ -120,8 +121,8 @@ func TestPool(t *testing.T) { t.Run("Close", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() - w2 := pool.Acquire() + w1 := pool.Acquire(context.Background()) + w2 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -137,7 +138,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool does not close existing wire after Close()") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -149,12 +150,12 @@ func TestPool(t *testing.T) { t.Run("Close Empty", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } pool.Close() - w2 := pool.Acquire() + w2 := pool.Acquire(context.Background()) if w2.Error() != ErrClosing { t.Fatalf("pool does not close wire after Close()") } @@ -162,7 +163,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool should not make new wire") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -174,7 +175,7 @@ func TestPool(t *testing.T) { t.Run("Close Waiting", func(t *testing.T) { pool, count := setup(1) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -182,7 +183,7 @@ func TestPool(t *testing.T) { for i := 0; i < 100; i++ { go func() { atomic.AddInt64(&pending, 1) - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Errorf("pool does not return the dead wire after Close()") } atomic.AddInt64(&pending, -1) @@ -209,7 +210,7 @@ func TestPoolError(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int) (*pool, *int32) { var count int32 - return newPool(size, dead, 0, 0, func() wire { + return newPool(size, dead, 0, 0, func(_ context.Context) wire { w := &pipe{} w.pshks.Store(emptypshks) c := atomic.AddInt32(&count, 1) @@ -224,7 +225,7 @@ func TestPoolError(t *testing.T) { conn := make([]wire, 100) pool, count := setup(len(conn)) for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)) { t.Fatalf("unexpected acquire count") @@ -233,7 +234,7 @@ func TestPoolError(t *testing.T) { pool.Store(conn[i]) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)+len(conn)/2) { t.Fatalf("unexpected acquire count") @@ -244,7 +245,7 @@ func TestPoolError(t *testing.T) { func TestPoolWithIdleTTL(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int, ttl time.Duration, minSize int) *pool { - return newPool(size, dead, ttl, minSize, func() wire { + return newPool(size, dead, ttl, minSize, func(_ context.Context) wire { closed := false return &mockWire{ CloseFn: func() { @@ -267,7 +268,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } @@ -301,7 +302,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } diff --git a/rueidis.go b/rueidis.go index 5c22a07a..fb148538 100644 --- a/rueidis.go +++ b/rueidis.go @@ -72,6 +72,9 @@ type ClientOption struct { // DialFn allows for a custom function to be used to create net.Conn connections DialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) + // DialCtxFn allows for a custom function to be used to create net.Conn connections + DialCtxFn func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) + // NewCacheStoreFn allows a custom client side caching store for each connection NewCacheStoreFn NewCacheStoreFn @@ -464,14 +467,18 @@ func makeConn(dst string, opt *ClientOption) conn { return makeMux(dst, opt, dial) } -func dial(dst string, opt *ClientOption) (conn net.Conn, err error) { +func dial(ctx context.Context, dst string, opt *ClientOption) (conn net.Conn, err error) { + if opt.DialCtxFn != nil { + return opt.DialCtxFn(ctx, dst, &opt.Dialer, opt.TLSConfig) + } if opt.DialFn != nil { return opt.DialFn(dst, &opt.Dialer, opt.TLSConfig) } if opt.TLSConfig != nil { - conn, err = tls.DialWithDialer(&opt.Dialer, "tcp", dst, opt.TLSConfig) + dialer := tls.Dialer{NetDialer: &opt.Dialer, Config: opt.TLSConfig} + conn, err = dialer.DialContext(ctx, "tcp", dst) } else { - conn, err = opt.Dialer.Dial("tcp", dst) + conn, err = opt.Dialer.DialContext(ctx, "tcp", dst) } return conn, err } diff --git a/sentinel.go b/sentinel.go index 93935841..fac5a028 100644 --- a/sentinel.go +++ b/sentinel.go @@ -198,7 +198,7 @@ func (c *sentinelClient) DoMultiStream(ctx context.Context, multi ...Completed) func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -207,7 +207,7 @@ func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { func (c *sentinelClient) Dedicate() (DedicatedClient, func()) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } diff --git a/sentinel_test.go b/sentinel_test.go index 9555e220..8b4d2810 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -927,7 +927,7 @@ func TestSentinelClientDelegate(t *testing.T) { return ErrClosing }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -976,7 +976,7 @@ func TestSentinelClientDelegate(t *testing.T) { return ErrClosing }, } - m.AcquireFn = func() wire { + m.AcquireFn = func(_ context.Context) wire { return w } stored := false @@ -1709,7 +1709,7 @@ func TestSentinelClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m1.AcquireFn = func() wire { return &mockWire{DoFn: m1.DoFn} } + m1.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m1.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -1732,7 +1732,7 @@ func TestSentinelClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m1.AcquireFn = func() wire { return &mockWire{DoMultiFn: m1.DoMultiFn} } + m1.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m1.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build()) From 13b81bb32bf2028f56c60585bcafbb91a2952caa Mon Sep 17 00:00:00 2001 From: zhaohuiliu Date: Fri, 14 Mar 2025 09:45:10 +0800 Subject: [PATCH 2/5] add context to the dial function and respect the context deadline --- client.go | 9 +-- client_test.go | 26 ++++---- cluster.go | 29 +++++---- cluster_test.go | 160 +++++++++++++++++++++++------------------------ helper_test.go | 38 +++++------ mux.go | 80 ++++++++++++------------ mux_test.go | 99 +++++++++++++++++------------ pipe.go | 18 +++--- pipe_test.go | 48 +++++++------- pool.go | 9 +-- pool_test.go | 49 ++++++++------- rueidis.go | 16 +++-- sentinel.go | 22 ++++--- sentinel_test.go | 32 +++++----- 14 files changed, 334 insertions(+), 301 deletions(-) diff --git a/client.go b/client.go index ad69c360..e79a007f 100644 --- a/client.go +++ b/client.go @@ -27,9 +27,10 @@ func newSingleClient(opt *ClientOption, prev conn, connFn connFn, retryer retryH return nil, ErrReplicaOnlyNotSupported } - conn := connFn(opt.InitAddress[0], opt) + ctx := context.Background() + conn := connFn(ctx, opt.InitAddress[0], opt) conn.Override(prev) - if err := conn.Dial(); err != nil { + if err := conn.Dial(ctx); err != nil { return nil, err } return newSingleClientWithConn(conn, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer), nil @@ -172,7 +173,7 @@ retry: } func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -180,7 +181,7 @@ func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { } func (c *singleClient) Dedicate() (DedicatedClient, func()) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } diff --git a/client_test.go b/client_test.go index 23a34aba..4d5adb89 100644 --- a/client_test.go +++ b/client_test.go @@ -42,14 +42,14 @@ func (m *mockConn) Override(c conn) { } } -func (m *mockConn) Dial() error { +func (m *mockConn) Dial(ctx context.Context) error { if m.DialFn != nil { return m.DialFn() } return nil } -func (m *mockConn) Acquire() wire { +func (m *mockConn) Acquire(ctx context.Context) wire { if m.AcquireFn != nil { return m.AcquireFn() } @@ -150,28 +150,28 @@ func (m *mockConn) SetOnCloseHook(func(error)) { } -func (m *mockConn) Info() map[string]RedisMessage { +func (m *mockConn) Info(ctx context.Context) map[string]RedisMessage { if m.InfoFn != nil { return m.InfoFn() } return nil } -func (m *mockConn) Version() int { +func (m *mockConn) Version(ctx context.Context) int { if m.VersionFn != nil { return m.VersionFn() } return 0 } -func (m *mockConn) AZ() string { +func (m *mockConn) AZ(ctx context.Context) string { if m.AZFn != nil { return m.AZFn() } return "" } -func (m *mockConn) Error() error { +func (m *mockConn) Error(ctx context.Context) error { if m.ErrorFn != nil { return m.ErrorFn() } @@ -194,7 +194,7 @@ func (m *mockConn) Addr() string { func TestNewSingleClientNoNode(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) if _, err := newSingleClient( - &ClientOption{}, nil, func(dst string, opt *ClientOption) conn { + &ClientOption{}, nil, func(ctx context.Context, dst string, opt *ClientOption) conn { return nil }, newRetryer(defaultRetryDelayFn), ); err != ErrNoAddr { @@ -205,7 +205,7 @@ func TestNewSingleClientNoNode(t *testing.T) { func TestNewSingleClientReplicaOnlyNotSupported(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) if _, err := newSingleClient( - &ClientOption{ReplicaOnly: true, InitAddress: []string{"localhost"}}, nil, func(dst string, opt *ClientOption) conn { return nil }, newRetryer(defaultRetryDelayFn), + &ClientOption{ReplicaOnly: true, InitAddress: []string{"localhost"}}, nil, func(ctx context.Context, dst string, opt *ClientOption) conn { return nil }, newRetryer(defaultRetryDelayFn), ); err != ErrReplicaOnlyNotSupported { t.Fatalf("unexpected err %v", err) } @@ -215,7 +215,7 @@ func TestNewSingleClientError(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) v := errors.New("dail err") if _, err := newSingleClient( - &ClientOption{InitAddress: []string{""}}, nil, func(dst string, opt *ClientOption) conn { return &mockConn{DialFn: func() error { return v }} }, newRetryer(defaultRetryDelayFn), + &ClientOption{InitAddress: []string{""}}, nil, func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DialFn: func() error { return v }} }, newRetryer(defaultRetryDelayFn), ); err != v { t.Fatalf("unexpected err %v", err) } @@ -228,7 +228,7 @@ func TestNewSingleClientOverride(t *testing.T) { if _, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m1, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{OverrideFn: func(c conn) { m2 = c }} }, newRetryer(defaultRetryDelayFn), @@ -249,7 +249,7 @@ func TestSingleClient(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -621,7 +621,7 @@ func TestSingleClientRetry(t *testing.T) { c, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -1261,7 +1261,7 @@ func TestSingleClientLoadingRetry(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { diff --git a/cluster.go b/cluster.go index 2ea2db08..4a01e2b5 100644 --- a/cluster.go +++ b/cluster.go @@ -79,8 +79,8 @@ func newClusterClient(opt *ClientOption, connFn connFn, retryer retryHandler) (* client.rOpt = &rOpt } - client.connFn = func(dst string, opt *ClientOption) conn { - cc := connFn(dst, opt) + client.connFn = func(ctx context.Context, dst string, opt *ClientOption) conn { + cc := connFn(ctx, dst, opt) cc.SetOnCloseHook(func(err error) { client.lazyRefresh() }) @@ -108,11 +108,12 @@ func (c *clusterClient) init() error { if len(c.opt.InitAddress) == 0 { return ErrNoAddr } + ctx := context.Background() results := make(chan error, len(c.opt.InitAddress)) for _, addr := range c.opt.InitAddress { - cc := c.connFn(addr, c.opt) + cc := c.connFn(ctx, addr, c.opt) go func(addr string, cc conn) { - if err := cc.Dial(); err == nil { + if err := cc.Dial(ctx); err == nil { c.mu.Lock() if _, ok := c.conns[addr]; ok { go cc.Close() // abort the new connection instead of closing the old one which may already been used @@ -169,7 +170,7 @@ func getClusterSlots(c conn, timeout time.Duration) clusterslots { } else { ctx = context.Background() } - v := c.Version() + v := c.Version(ctx) if v < 8 { return clusterslots{reply: c.Do(ctx, cmds.SlotCmd), addr: c.Addr(), ver: v} } @@ -185,6 +186,7 @@ func (c *clusterClient) _refresh() (err error) { } c.mu.RUnlock() + ctx := context.Background() var result clusterslots for i := 0; i < cap(results); i++ { if i&3 == 0 { // batch CLUSTER SLOTS/CLUSTER SHARDS for every 4 connections @@ -208,14 +210,14 @@ func (c *clusterClient) _refresh() (err error) { groups := result.parse(c.opt.TLSConfig != nil) conns := make(map[string]connrole, len(groups)) for master, g := range groups { - conns[master] = connrole{conn: c.connFn(master, c.opt)} + conns[master] = connrole{conn: c.connFn(ctx, master, c.opt)} if c.rOpt != nil { for _, nodeInfo := range g.nodes[1:] { - conns[nodeInfo.Addr] = connrole{conn: c.connFn(nodeInfo.Addr, c.rOpt)} + conns[nodeInfo.Addr] = connrole{conn: c.connFn(ctx, nodeInfo.Addr, c.rOpt)} } } else { for _, nodeInfo := range g.nodes[1:] { - conns[nodeInfo.Addr] = connrole{conn: c.connFn(nodeInfo.Addr, c.opt)} + conns[nodeInfo.Addr] = connrole{conn: c.connFn(ctx, nodeInfo.Addr, c.opt)} } } } @@ -223,7 +225,7 @@ func (c *clusterClient) _refresh() (err error) { for _, addr := range c.opt.InitAddress { if _, ok := conns[addr]; !ok { conns[addr] = connrole{ - conn: c.connFn(addr, c.opt), + conn: c.connFn(ctx, addr, c.opt), hidden: true, } } @@ -266,7 +268,7 @@ func (c *clusterClient) _refresh() (err error) { for j := i; j <= i+4 && j <= n; j++ { wg.Add(1) go func(wg *sync.WaitGroup, conn conn, info *ReplicaInfo) { - info.AZ = conn.AZ() + info.AZ = conn.AZ(ctx) wg.Done() }(&wg, conns[g.nodes[j].Addr].conn, &g.nodes[j]) } @@ -466,8 +468,9 @@ func (c *clusterClient) redirectOrNew(addr string, prev conn, slot uint16, mode return cc.conn } c.mu.Lock() + ctx := context.Background() if cc = c.conns[addr]; cc.conn == nil { - p := c.connFn(addr, c.opt) + p := c.connFn(ctx, addr, c.opt) cc = connrole{conn: p} c.conns[addr] = cc if mode == RedirectMove { @@ -481,7 +484,7 @@ func (c *clusterClient) redirectOrNew(addr string, prev conn, slot uint16, mode time.Sleep(time.Second * 5) prev.Close() }(prev) - p := c.connFn(addr, c.opt) + p := c.connFn(ctx, addr, c.opt) cc = connrole{conn: p} c.conns[addr] = cc if mode == RedirectMove { // MOVED should always point to the primary. @@ -1274,7 +1277,7 @@ func (c *dedicatedClusterClient) acquire(ctx context.Context, slot uint16) (wire } return nil, err } - c.wire = c.conn.Acquire() + c.wire = c.conn.Acquire(ctx) if p := c.pshks; p != nil { c.pshks = nil ch := c.wire.SetPubSubHooks(p.hooks) diff --git a/cluster_test.go b/cluster_test.go index dae2c788..88d072cb 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -701,7 +701,7 @@ func TestClusterClientInit(t *testing.T) { t.Run("Init no nodes", func(t *testing.T) { if _, err := newClusterClient( &ClientOption{InitAddress: []string{}}, - func(dst string, opt *ClientOption) conn { return nil }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return nil }, newRetryer(defaultRetryDelayFn), ); err != ErrNoAddr { t.Fatalf("unexpected err %v", err) @@ -712,7 +712,7 @@ func TestClusterClientInit(t *testing.T) { v := errors.New("dial err") if _, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DialFn: func() error { return v }} }, newRetryer(defaultRetryDelayFn), @@ -725,7 +725,7 @@ func TestClusterClientInit(t *testing.T) { v := errors.New("refresh err") if _, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return newErrResult(v) }} }, newRetryer(defaultRetryDelayFn), @@ -738,7 +738,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 if _, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0", "127.0.1.1:1"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.AddInt64(&first, 1) == 1 { @@ -758,7 +758,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 if _, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0", "127.0.1.1:1"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.AddInt64(&first, 1) == 1 { @@ -778,7 +778,7 @@ func TestClusterClientInit(t *testing.T) { t.Run("Refresh no slots cluster", func(t *testing.T) { if _, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return newResult(RedisMessage{typ: '*', values: []RedisMessage{}}, nil) @@ -794,7 +794,7 @@ func TestClusterClientInit(t *testing.T) { t.Run("Refresh no shards cluster", func(t *testing.T) { if _, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return newResult(RedisMessage{typ: '*', values: []RedisMessage{}}, nil) @@ -812,7 +812,7 @@ func TestClusterClientInit(t *testing.T) { getClient := func(version int) (client *clusterClient, err error) { return newClusterClient( &ClientOption{InitAddress: []string{"127.0.4.1:4"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { @@ -886,7 +886,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.1.1:1", "127.0.2.1:2"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.LoadInt64(&first) == 1 { @@ -908,7 +908,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.1.1:1", "127.0.2.1:2"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.LoadInt64(&first) == 1 { @@ -974,7 +974,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.1.1:1", "127.0.2.1:2", "redis.example.com"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.LoadInt64(&first) == 1 { @@ -996,7 +996,7 @@ func TestClusterClientInit(t *testing.T) { var first int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.1.1:1", "127.0.2.1:2", "redis.example.com"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.LoadInt64(&first) == 1 { @@ -1019,7 +1019,7 @@ func TestClusterClientInit(t *testing.T) { t.Run("Shards tls", func(t *testing.T) { client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0"}, TLSConfig: &tls.Config{}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return shardsRespTls @@ -1060,7 +1060,7 @@ func TestClusterClientInit(t *testing.T) { return true }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { copiedM := *m return &copiedM }, @@ -1122,7 +1122,7 @@ func TestClusterClientInit(t *testing.T) { return true }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { if opt.ReplicaOnly { t.Fatalf("unexpected replicaOnly option in primary node") @@ -1175,7 +1175,7 @@ func TestClusterClientInit(t *testing.T) { ShardsRefreshInterval: -1 * time.Millisecond, }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return singleSlotResp @@ -1209,7 +1209,7 @@ func TestClusterClientInit(t *testing.T) { return 0 }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { copiedM := *m return &copiedM }, @@ -1288,7 +1288,7 @@ func TestClusterClientInit(t *testing.T) { return 1 }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { switch { case dst == "127.0.0.2:1" || dst == "127.0.1.2:1": return replicaNodeConn1 @@ -1375,7 +1375,7 @@ func TestClusterClientInit(t *testing.T) { return -1 }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { switch { case dst == "127.0.0.2:1" || dst == "127.0.1.2:1": return replicaNodeConn1 @@ -1480,7 +1480,7 @@ func TestClusterClientInit(t *testing.T) { }, EnableReplicaAZInfo: true, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { switch { case dst == "127.0.0.2:1" || dst == "127.0.1.2:1": return replicaNodeConn1 @@ -1572,7 +1572,7 @@ func TestClusterClient(t *testing.T) { client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), @@ -2248,7 +2248,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { return false }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary nodes return primaryNodeConn } else { // replica nodes @@ -2747,7 +2747,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { return true }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary nodes return primaryNodeConn } else { // replica nodes @@ -3288,7 +3288,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod return cmd.IsReadOnly() }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary nodes return primaryNodeConn } else { // replica nodes @@ -3812,7 +3812,7 @@ func TestClusterClient_SendPrimaryNodeOnlyButOneSlotAssigned(t *testing.T) { return false }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return primaryNodeConn }, newRetryer(defaultRetryDelayFn), @@ -3880,7 +3880,7 @@ func TestClusterClientErr(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -3928,7 +3928,7 @@ func TestClusterClientErr(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -3973,7 +3973,7 @@ func TestClusterClientErr(t *testing.T) { }} client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -4006,7 +4006,7 @@ func TestClusterClientErr(t *testing.T) { }} client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -4030,7 +4030,7 @@ func TestClusterClientErr(t *testing.T) { }} client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -4057,7 +4057,7 @@ func TestClusterClientErr(t *testing.T) { var count, check int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { atomic.AddInt64(&check, 1) return &mockConn{DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { @@ -4086,7 +4086,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { return slotsMultiResp @@ -4111,7 +4111,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { @@ -4148,7 +4148,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4179,7 +4179,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -4223,7 +4223,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4321,7 +4321,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4421,7 +4421,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4514,7 +4514,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4610,7 +4610,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4709,7 +4709,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4811,7 +4811,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4876,7 +4876,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4941,7 +4941,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -4992,7 +4992,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -5029,7 +5029,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -5064,7 +5064,7 @@ func TestClusterClientErr(t *testing.T) { var count, check int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":2" { atomic.AddInt64(&check, 1) } @@ -5095,7 +5095,7 @@ func TestClusterClientErr(t *testing.T) { var count, check int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":2" { atomic.AddInt64(&check, 1) } @@ -5132,7 +5132,7 @@ func TestClusterClientErr(t *testing.T) { var count, check int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":2" { atomic.AddInt64(&check, 1) } @@ -5175,7 +5175,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -5210,7 +5210,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5237,7 +5237,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5273,7 +5273,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiCacheFn: func(multi ...CacheableTTL) *redisresults { @@ -5304,7 +5304,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5348,7 +5348,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiCacheFn: func(multi ...CacheableTTL) *redisresults { @@ -5385,7 +5385,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { @@ -5415,7 +5415,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5450,7 +5450,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5491,7 +5491,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5521,7 +5521,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5551,7 +5551,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5593,7 +5593,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { return slotsMultiResp @@ -5618,7 +5618,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -5644,7 +5644,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiFn: func(multi ...Completed) *redisresults { @@ -5676,7 +5676,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { return slotsMultiResp @@ -5703,7 +5703,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { return slotsMultiResp }, DoMultiCacheFn: func(multi ...CacheableTTL) *redisresults { @@ -5727,7 +5727,7 @@ func TestClusterClientErr(t *testing.T) { var count int64 client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DoFn: func(cmd Completed) RedisResult { if strings.Join(cmd.Commands(), " ") == "CLUSTER SLOTS" { return slotsMultiResp @@ -5765,7 +5765,7 @@ func TestClusterClientRetry(t *testing.T) { } c, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -5788,7 +5788,7 @@ func TestClusterClientReplicaOnly_PickReplica(t *testing.T) { client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0"}, ReplicaOnly: true}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { copiedM := *m return &copiedM }, @@ -5827,7 +5827,7 @@ func TestClusterClientReplicaOnly_PickMasterIfNoReplica(t *testing.T) { client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0"}, ReplicaOnly: true}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { copiedM := *m return &copiedM }, @@ -5863,7 +5863,7 @@ func TestClusterClientReplicaOnly_PickMasterIfNoReplica(t *testing.T) { client, err := newClusterClient( &ClientOption{InitAddress: []string{"127.0.0.1:0"}, ReplicaOnly: true}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { copiedM := *m return &copiedM }, @@ -5997,7 +5997,7 @@ func TestClusterTopologyRefreshment(t *testing.T) { ShardsRefreshInterval: 0, }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { // initial call @@ -6034,7 +6034,7 @@ func TestClusterTopologyRefreshment(t *testing.T) { ShardsRefreshInterval: time.Second, }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if atomic.AddInt64(&callCount, 1) >= 3 { @@ -6078,7 +6078,7 @@ func TestClusterTopologyRefreshment(t *testing.T) { ShardsRefreshInterval: time.Second, }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if c := atomic.AddInt64(&callCount, 1); c >= 6 { @@ -6127,7 +6127,7 @@ func TestClusterTopologyRefreshment(t *testing.T) { ShardsRefreshInterval: time.Second, }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoFn: func(cmd Completed) RedisResult { if c := atomic.AddInt64(&callCount, 1); c >= 6 { @@ -6181,7 +6181,7 @@ func TestClusterClientLoadingRetry(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -6361,7 +6361,7 @@ func TestClusterClientMovedRetry(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -6429,7 +6429,7 @@ func TestClusterClientCacheASKRetry(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -6593,7 +6593,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode return 0 }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary node return primaryNodeConn } else { // replica node @@ -7133,7 +7133,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T return -1 }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == "127.0.0.1:0" || dst == "127.0.2.1:0" { // primary nodes return primaryNodeConn } else { // replica nodes diff --git a/helper_test.go b/helper_test.go index 3142284c..5778eea4 100644 --- a/helper_test.go +++ b/helper_test.go @@ -15,7 +15,7 @@ func TestMGetCache(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -24,7 +24,7 @@ func TestMGetCache(t *testing.T) { disabledCacheClient, err := newSingleClient( &ClientOption{InitAddress: []string{""}, DisableCache: true}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -81,7 +81,7 @@ func TestMGetCache(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -89,7 +89,7 @@ func TestMGetCache(t *testing.T) { } disabledCacheClient, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}, DisableCache: true}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -178,7 +178,7 @@ func TestMGet(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -219,7 +219,7 @@ func TestMGet(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -277,7 +277,7 @@ func TestMDel(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -318,7 +318,7 @@ func TestMDel(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -372,7 +372,7 @@ func TestMSet(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -414,7 +414,7 @@ func TestMSet(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -477,7 +477,7 @@ func TestMSetNX(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -519,7 +519,7 @@ func TestMSetNX(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -582,7 +582,7 @@ func TestMSetNXNotSet(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -607,7 +607,7 @@ func TestJsonMGetCache(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -653,7 +653,7 @@ func TestJsonMGetCache(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -715,7 +715,7 @@ func TestJsonMGet(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -756,7 +756,7 @@ func TestJsonMGet(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -813,7 +813,7 @@ func TestJsonMSet(t *testing.T) { client, err := newSingleClient( &ClientOption{InitAddress: []string{""}}, m, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { @@ -855,7 +855,7 @@ func TestJsonMSet(t *testing.T) { } client, err := newClusterClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return m }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), ) if err != nil { diff --git a/mux.go b/mux.go index dcc680a4..03c87566 100644 --- a/mux.go +++ b/mux.go @@ -12,9 +12,9 @@ import ( "github.com/redis/rueidis/internal/util" ) -type connFn func(dst string, opt *ClientOption) conn -type dialFn func(dst string, opt *ClientOption) (net.Conn, error) -type wireFn func() wire +type connFn func(ctx context.Context, dst string, opt *ClientOption) conn +type dialFn func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) +type wireFn func(ctx context.Context) wire type singleconnect struct { w wire @@ -30,14 +30,14 @@ type conn interface { Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error DoStream(ctx context.Context, cmd Completed) RedisResultStream DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream - Info() map[string]RedisMessage - Version() int - AZ() string - Error() error + Info(context.Context) map[string]RedisMessage + Version(context.Context) int + AZ(context.Context) string + Error(context.Context) error Close() - Dial() error + Dial(context.Context) error Override(conn) - Acquire() wire + Acquire(context.Context) wire Store(w wire) Addr() string SetOnCloseHook(func(error)) @@ -62,14 +62,14 @@ type mux struct { usePool bool } -func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux { +func makeMux(ctx context.Context, dst string, option *ClientOption, dialFn dialFn) *mux { dead := deadFn() - connFn := func() (net.Conn, error) { - return dialFn(dst, option) + connFn := func(ctx context.Context) (net.Conn, error) { + return dialFn(ctx, dst, option) } - wireFn := func(pipeFn pipeFn) func() wire { - return func() (w wire) { - w, err := pipeFn(connFn, option) + wireFn := func(pipeFn pipeFn) func(context.Context) wire { + return func(ctx context.Context) (w wire) { + w, err := pipeFn(ctx, connFn, option) if err != nil { dead.error.Store(&errs{error: err}) w = dead @@ -132,7 +132,7 @@ func (m *mux) Override(cc conn) { } } -func (m *mux) _pipe(i uint16) (w wire, err error) { +func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) { if w = m.wire[i].Load().(wire); w != m.init { return w, nil } @@ -151,7 +151,7 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { } if w = m.wire[i].Load().(wire); w == m.init { - if w = m.wireFn(); w != m.dead { + if w = m.wireFn(ctx); w != m.dead { m.setCloseHookOnWire(i, w) m.wire[i].Store(w) } else { @@ -173,39 +173,39 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { return w, err } -func (m *mux) pipe(i uint16) wire { - w, _ := m._pipe(i) +func (m *mux) pipe(ctx context.Context, i uint16) wire { + w, _ := m._pipe(ctx, i) return w // this should never be nil } -func (m *mux) Dial() error { - _, err := m._pipe(0) +func (m *mux) Dial(ctx context.Context) error { + _, err := m._pipe(ctx, 0) return err } -func (m *mux) Info() map[string]RedisMessage { - return m.pipe(0).Info() +func (m *mux) Info(ctx context.Context) map[string]RedisMessage { + return m.pipe(ctx, 0).Info() } -func (m *mux) Version() int { - return m.pipe(0).Version() +func (m *mux) Version(ctx context.Context) int { + return m.pipe(ctx, 0).Version() } -func (m *mux) AZ() string { - return m.pipe(0).AZ() +func (m *mux) AZ(ctx context.Context) string { + return m.pipe(ctx, 0).AZ() } -func (m *mux) Error() error { - return m.pipe(0).Error() +func (m *mux) Error(ctx context.Context) error { + return m.pipe(ctx, 0).Error() } func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoStream(ctx, m.spool, cmd) } func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoMultiStream(ctx, m.spool, multi...) } @@ -242,7 +242,7 @@ block: } func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.Do(ctx, cmd) if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) wire.Close() @@ -252,7 +252,7 @@ func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp Red } func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.DoMulti(ctx, cmd...) for _, res := range resp.s { if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) @@ -266,7 +266,7 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) } @@ -275,7 +275,7 @@ func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) { slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp = wire.DoMulti(ctx, cmd...) for _, r := range resp.s { if isBroken(r.NonRedisError(), wire) { @@ -288,7 +288,7 @@ func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisre func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult { slot := cmd.Slot() & uint16(len(m.wire)-1) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp := wire.DoCache(ctx, cmd, ttl) if isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -346,7 +346,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results } func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) { - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resps = wire.DoMultiCache(ctx, multi...) for _, r := range resps.s { if isBroken(r.NonRedisError(), wire) { @@ -359,7 +359,7 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) err := wire.Receive(ctx, subscribe, fn) if isBroken(err, wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -367,8 +367,8 @@ func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message return err } -func (m *mux) Acquire() wire { - return m.dpool.Acquire() +func (m *mux) Acquire(ctx context.Context) wire { + return m.dpool.Acquire(ctx) } func (m *mux) Store(w wire) { diff --git a/mux_test.go b/mux_test.go index 4f801322..6d400b89 100644 --- a/mux_test.go +++ b/mux_test.go @@ -23,7 +23,7 @@ func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) { func setupMuxWithOption(wires []*mockWire, option *ClientOption) (conn *mux, checkClean func(t *testing.T)) { var mu sync.Mutex var count = -1 - wfn := func() wire { + wfn := func(context.Context) wire { mu.Lock() defer mu.Unlock() count++ @@ -43,23 +43,24 @@ func TestNewMuxDailErr(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) c := 0 e := errors.New("any") - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + ctx := context.Background() + m := makeMux(ctx, "", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { c++ return nil, e }) - if err := m.Dial(); err != e { + if err := m.Dial(ctx); err != e { t.Fatalf("unexpected return %v", err) } if c != 1 { t.Fatalf("dialFn not called") } - if w := m.pipe(0); w != m.dead { // c = 2 + if w := m.pipe(ctx, 0); w != m.dead { // c = 2 t.Fatalf("unexpected wire %v", w) } - if err := m.Dial(); err != e { // c = 3 + if err := m.Dial(ctx); err != e { // c = 3 t.Fatalf("unexpected return %v", err) } - if w := m.Acquire(); w != m.dead { + if w := m.Acquire(ctx); w != m.dead { t.Fatalf("unexpected wire %v", w) } if c != 4 { @@ -71,6 +72,7 @@ func TestNewMux(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) n1, n2 := net.Pipe() mock := &redisMock{t: t, buf: bufio.NewReader(n2), conn: n2} + ctx := context.Background() go func() { mock.Expect("HELLO", "3"). Reply(RedisMessage{ @@ -89,19 +91,19 @@ func TestNewMux(t *testing.T) { mock.Expect("PING").ReplyString("OK") mock.Close() }() - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux(ctx, "", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) - if err := m.Dial(); err != nil { + if err := m.Dial(ctx); err != nil { t.Fatalf("unexpected error %v", err) } t.Run("Override with previous mux", func(t *testing.T) { - m2 := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m2 := makeMux(ctx, "", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) m2.Override(m) - if err := m2.Dial(); err != nil { + if err := m2.Dial(ctx); err != nil { t.Fatalf("unexpected error %v", err) } m2.Close() @@ -110,8 +112,9 @@ func TestNewMux(t *testing.T) { func TestNewMuxPipelineMultiplex(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) + ctx := context.Background() for _, v := range []int{-1, 0, 1, 2} { - m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) + m := makeMux(ctx, "", &ClientOption{PipelineMultiplex: v}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1< 100 { t.Fatalf("pool must not exceed the size limit") @@ -103,11 +104,11 @@ func TestPool(t *testing.T) { conn := make([]wire, 100) pool, _ := setup(len(conn)) for i := 0; i < len(conn); i++ { - w := pool.Acquire() + w := pool.Acquire(context.Background()) go pool.Store(w) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } for i := 0; i < len(conn); i++ { for j := i + 1; j < len(conn); j++ { @@ -120,8 +121,8 @@ func TestPool(t *testing.T) { t.Run("Close", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() - w2 := pool.Acquire() + w1 := pool.Acquire(context.Background()) + w2 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -137,7 +138,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool does not close existing wire after Close()") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -149,12 +150,12 @@ func TestPool(t *testing.T) { t.Run("Close Empty", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } pool.Close() - w2 := pool.Acquire() + w2 := pool.Acquire(context.Background()) if w2.Error() != ErrClosing { t.Fatalf("pool does not close wire after Close()") } @@ -162,7 +163,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool should not make new wire") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -174,7 +175,7 @@ func TestPool(t *testing.T) { t.Run("Close Waiting", func(t *testing.T) { pool, count := setup(1) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -182,7 +183,7 @@ func TestPool(t *testing.T) { for i := 0; i < 100; i++ { go func() { atomic.AddInt64(&pending, 1) - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Errorf("pool does not return the dead wire after Close()") } atomic.AddInt64(&pending, -1) @@ -209,7 +210,7 @@ func TestPoolError(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int) (*pool, *int32) { var count int32 - return newPool(size, dead, 0, 0, func() wire { + return newPool(size, dead, 0, 0, func(context.Context) wire { w := &pipe{} w.pshks.Store(emptypshks) c := atomic.AddInt32(&count, 1) @@ -224,7 +225,7 @@ func TestPoolError(t *testing.T) { conn := make([]wire, 100) pool, count := setup(len(conn)) for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)) { t.Fatalf("unexpected acquire count") @@ -233,7 +234,7 @@ func TestPoolError(t *testing.T) { pool.Store(conn[i]) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)+len(conn)/2) { t.Fatalf("unexpected acquire count") @@ -244,7 +245,7 @@ func TestPoolError(t *testing.T) { func TestPoolWithIdleTTL(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int, ttl time.Duration, minSize int) *pool { - return newPool(size, dead, ttl, minSize, func() wire { + return newPool(size, dead, ttl, minSize, func(context.Context) wire { closed := false return &mockWire{ CloseFn: func() { @@ -267,7 +268,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } @@ -301,7 +302,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } diff --git a/rueidis.go b/rueidis.go index 5c22a07a..f2866147 100644 --- a/rueidis.go +++ b/rueidis.go @@ -71,6 +71,7 @@ type ClientOption struct { // DialFn allows for a custom function to be used to create net.Conn connections DialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) + DialFnCtx func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) // NewCacheStoreFn allows a custom client side caching store for each connection NewCacheStoreFn NewCacheStoreFn @@ -460,18 +461,21 @@ func singleClientMultiplex(multiplex int) int { return multiplex } -func makeConn(dst string, opt *ClientOption) conn { - return makeMux(dst, opt, dial) +func makeConn(ctx context.Context, dst string, opt *ClientOption) conn { + return makeMux(ctx, dst, opt, dial) } -func dial(dst string, opt *ClientOption) (conn net.Conn, err error) { - if opt.DialFn != nil { +func dial(ctx context.Context, dst string, opt *ClientOption) (conn net.Conn, err error) { + if opt.DialFnCtx != nil { + return opt.DialFnCtx(ctx, dst, &opt.Dialer, opt.TLSConfig) + } else if opt.DialFn != nil { // maintain it for compatability reason return opt.DialFn(dst, &opt.Dialer, opt.TLSConfig) } if opt.TLSConfig != nil { - conn, err = tls.DialWithDialer(&opt.Dialer, "tcp", dst, opt.TLSConfig) + tlsDialer := &tls.Dialer{NetDialer: &opt.Dialer, Config: opt.TLSConfig} + conn, err = tlsDialer.DialContext(ctx, "tcp", dst) } else { - conn, err = opt.Dialer.Dial("tcp", dst) + conn, err = opt.Dialer.DialContext(ctx, "tcp", dst) } return conn, err } diff --git a/sentinel.go b/sentinel.go index 93935841..ceb8c5cd 100644 --- a/sentinel.go +++ b/sentinel.go @@ -198,7 +198,7 @@ func (c *sentinelClient) DoMultiStream(ctx context.Context, multi ...Completed) func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -207,7 +207,7 @@ func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { func (c *sentinelClient) Dedicate() (DedicatedClient, func()) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } @@ -273,20 +273,21 @@ func (c *sentinelClient) _switchTarget(addr string) (err error) { if atomic.LoadUint32(&c.stop) == 1 { return nil } + ctx := context.Background() if c.mAddr == addr { target = c.mConn.Load().(conn) - if target.Error() != nil { + if target.Error(ctx) != nil { target = nil } } if target == nil { - target = c.connFn(addr, c.mOpt) - if err = target.Dial(); err != nil { + target = c.connFn(ctx, addr, c.mOpt) + if err = target.Dial(ctx); err != nil { return err } } - resp, err := target.Do(context.Background(), cmds.RoleCmd).ToArray() + resp, err := target.Do(ctx, cmds.RoleCmd).ToArray() if err != nil { target.Close() return err @@ -326,6 +327,7 @@ func (c *sentinelClient) _refresh() (err error) { c.mu.Lock() head := c.sentinels.Front() + ctx := context.Background() for e := head; e != nil; { if atomic.LoadUint32(&c.stop) == 1 { c.mu.Unlock() @@ -333,13 +335,13 @@ func (c *sentinelClient) _refresh() (err error) { } addr := e.Value.(string) - if c.sAddr != addr || c.sConn == nil || c.sConn.Error() != nil { + if c.sAddr != addr || c.sConn == nil || c.sConn.Error(ctx) != nil { if c.sConn != nil { c.sConn.Close() } c.sAddr = addr - c.sConn = c.connFn(addr, c.sOpt) - err = c.sConn.Dial() + c.sConn = c.connFn(ctx, addr, c.sOpt) + err = c.sConn.Dial(ctx) } if err == nil { // listWatch returns server address with sentinels. @@ -367,7 +369,7 @@ func (c *sentinelClient) _refresh() (err error) { if master := c.mConn.Load(); master == nil { err = ErrNoAddr } else { - err = master.(conn).Error() + err = master.(conn).Error(ctx) } } return err diff --git a/sentinel_test.go b/sentinel_test.go index 9555e220..2f8e569a 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -18,7 +18,7 @@ func TestSentinelClientInit(t *testing.T) { t.Run("Init no nodes", func(t *testing.T) { if _, err := newSentinelClient( &ClientOption{InitAddress: []string{}}, - func(dst string, opt *ClientOption) conn { return nil }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return nil }, newRetryer(defaultRetryDelayFn), ); err != ErrNoAddr { t.Fatalf("unexpected err %v", err) @@ -29,7 +29,7 @@ func TestSentinelClientInit(t *testing.T) { v := errors.New("dial err") if _, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { return &mockConn{DialFn: func() error { return v }} }, + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{DialFn: func() error { return v }} }, newRetryer(defaultRetryDelayFn), ); err != v { t.Fatalf("unexpected err %v", err) @@ -40,7 +40,7 @@ func TestSentinelClientInit(t *testing.T) { v := errors.New("refresh err") if _, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return &mockConn{ DoMultiFn: func(cmd ...Completed) *redisresults { return &redisresults{s: []RedisResult{newErrResult(v)}} }, } @@ -126,7 +126,7 @@ func TestSentinelClientInit(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0", ":1", ":2"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -334,7 +334,7 @@ func TestSentinelClientInit(t *testing.T) { client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0", ":1", ":2"}, ReplicaOnly: true}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return slaveWithMultiError } @@ -436,7 +436,7 @@ func TestSentinelClientInit(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -562,7 +562,7 @@ func TestSentinelClientInit(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -629,7 +629,7 @@ func TestSentinelRefreshAfterClose(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -674,7 +674,7 @@ func TestSentinelSwitchAfterClose(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -716,7 +716,7 @@ func TestSentinelClientDelegate(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -734,7 +734,7 @@ func TestSentinelClientDelegate(t *testing.T) { disabledCacheClient, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}, DisableCache: true}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -1082,7 +1082,7 @@ func TestSentinelClientDelegateRetry(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -1242,7 +1242,7 @@ func TestSentinelClientPubSub(t *testing.T) { MasterSet: "test", }, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -1411,7 +1411,7 @@ func TestSentinelReplicaOnlyClientPubSub(t *testing.T) { }, ReplicaOnly: true, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } @@ -1539,7 +1539,7 @@ func TestSentinelClientRetry(t *testing.T) { InitAddress: []string{":0"}, Sentinel: SentinelOption{MasterSet: "masters"}, }, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { return m }, newRetryer(defaultRetryDelayFn), @@ -1576,7 +1576,7 @@ func TestSentinelClientLoadingRetry(t *testing.T) { } client, err := newSentinelClient( &ClientOption{InitAddress: []string{":0"}}, - func(dst string, opt *ClientOption) conn { + func(ctx context.Context, dst string, opt *ClientOption) conn { if dst == ":0" { return s0 } From e2b062191e657370091d23fde5aeb236b03e14fd Mon Sep 17 00:00:00 2001 From: Rueian Date: Fri, 14 Mar 2025 10:46:48 -0700 Subject: [PATCH 3/5] feat: Deprecated DialFn Signed-off-by: Rueian --- rueidis.go | 1 + rueidis_test.go | 21 ++++++++++++++++++++ rueidisotel/metrics.go | 26 +++++++++++++++---------- rueidisotel/metrics_test.go | 39 ++++++++++++++++++++++++------------- url.go | 5 +++-- url_test.go | 7 ++++--- 6 files changed, 71 insertions(+), 28 deletions(-) diff --git a/rueidis.go b/rueidis.go index fb148538..de0cf623 100644 --- a/rueidis.go +++ b/rueidis.go @@ -70,6 +70,7 @@ type ClientOption struct { TLSConfig *tls.Config // DialFn allows for a custom function to be used to create net.Conn connections + // Deprecated: use DialCtxFn instead. DialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) // DialCtxFn allows for a custom function to be used to create net.Conn connections diff --git a/rueidis_test.go b/rueidis_test.go index 4da05a95..2d0e7582 100644 --- a/rueidis_test.go +++ b/rueidis_test.go @@ -407,6 +407,27 @@ func TestCustomDialFnIsCalled(t *testing.T) { } } +func TestCustomDialCtxFnIsCalled(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + isFnCalled := false + option := ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + DialCtxFn: func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + isFnCalled = true + return nil, errors.New("dial error") + }, + } + + _, err := NewClient(option) + + if !isFnCalled { + t.Fatalf("excepted ClientOption.DialFn to be called") + } + if err == nil { + t.Fatalf("expected dial error") + } +} + func ExampleIsRedisNil() { client, err := NewClient(ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) if err != nil { diff --git a/rueidisotel/metrics.go b/rueidisotel/metrics.go index 9176162e..75cfcd5e 100644 --- a/rueidisotel/metrics.go +++ b/rueidisotel/metrics.go @@ -70,8 +70,13 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien return nil, err } - if clientOption.DialFn == nil { - clientOption.DialFn = defaultDialFn + if clientOption.DialCtxFn == nil { + clientOption.DialCtxFn = defaultDialFn + if clientOption.DialFn != nil { + clientOption.DialCtxFn = func(_ context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + return clientOption.DialFn(s, dialer, config) + } + } } metrics := dialMetrics{ @@ -103,7 +108,8 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien return nil, err } - clientOption.DialFn = trackDialing(metrics, clientOption.DialFn) + clientOption.DialCtxFn = trackDialing(metrics, clientOption.DialCtxFn) + cli, err := rueidis.NewClient(clientOption) if err != nil { return nil, err @@ -146,14 +152,13 @@ func newClient(opts ...Option) (*otelclient, error) { return cli, nil } -func trackDialing(m dialMetrics, dialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) { - return func(network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) { - ctx := context.Background() +func trackDialing(m dialMetrics, dialFn func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) { + return func(ctx context.Context, network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) { m.attempt.Add(ctx, 1, m.addOpts...) start := time.Now() - conn, err = dialFn(network, dialer, tlsConfig) + conn, err = dialFn(ctx, network, dialer, tlsConfig) if err != nil { return nil, err } @@ -187,9 +192,10 @@ func (t *connTracker) Close() error { return t.Conn.Close() } -func defaultDialFn(dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) { +func defaultDialFn(ctx context.Context, dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) { if cfg != nil { - return tls.DialWithDialer(dialer, "tcp", dst, cfg) + td := tls.Dialer{NetDialer: dialer, Config: cfg} + return td.DialContext(ctx, "tcp", dst) } - return dialer.Dial("tcp", dst) + return dialer.DialContext(ctx, "tcp", dst) } diff --git a/rueidisotel/metrics_test.go b/rueidisotel/metrics_test.go index c2be606c..3766ce5c 100644 --- a/rueidisotel/metrics_test.go +++ b/rueidisotel/metrics_test.go @@ -15,7 +15,7 @@ import ( ) func TestNewClient(t *testing.T) { - t.Run("client option only", func(t *testing.T) { + t.Run("client option only (no ctx)", func(t *testing.T) { c, err := NewClient(rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { @@ -28,14 +28,27 @@ func TestNewClient(t *testing.T) { defer c.Close() }) + t.Run("client option only", func(t *testing.T) { + c, err := NewClient(rueidis.ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) + }, + }) + if err != nil { + t.Fatal(err) + } + defer c.Close() + }) + t.Run("meter provider", func(t *testing.T) { mr := metric.NewManualReader() meterProvider := metric.NewMeterProvider(metric.WithReader(mr)) c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -50,8 +63,8 @@ func TestNewClient(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithHistogramOption(HistogramOption{ @@ -79,8 +92,8 @@ func TestNewClientError(t *testing.T) { t.Run("invalid client option", func(t *testing.T) { _, err := NewClient(rueidis.ClientOption{ InitAddress: []string{""}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }) if err == nil { @@ -120,8 +133,8 @@ func TestTrackDialing(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -169,8 +182,8 @@ func TestTrackDialing(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -198,8 +211,8 @@ func TestTrackDialing(t *testing.T) { _, err := NewClient( rueidis.ClientOption{ InitAddress: []string{""}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), diff --git a/url.go b/url.go index dad0fc00..416402bb 100644 --- a/url.go +++ b/url.go @@ -1,6 +1,7 @@ package rueidis import ( + "context" "crypto/tls" "fmt" "net" @@ -37,8 +38,8 @@ func ParseURL(str string) (opt ClientOption, err error) { } switch u.Scheme { case "unix": - opt.DialFn = func(s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("unix", s) + opt.DialCtxFn = func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "unix", s) } opt.InitAddress = []string{strings.TrimSpace(u.Path)} case "rediss", "valkeys": diff --git a/url_test.go b/url_test.go index 45e3eed8..aa207b3a 100644 --- a/url_test.go +++ b/url_test.go @@ -1,6 +1,7 @@ package rueidis import ( + "context" "strings" "testing" ) @@ -18,7 +19,7 @@ func TestParseURL(t *testing.T) { if opt, err := ParseURL("valkeys://"); err != nil || opt.TLSConfig == nil { t.Fatalf("unexpected %v %v", opt, err) } - if opt, err := ParseURL("unix://"); err != nil || opt.DialFn == nil { + if opt, err := ParseURL("unix://"); err != nil || opt.DialCtxFn == nil { t.Fatalf("unexpected %v %v", opt, err) } if opt, err := ParseURL("valkey://"); err != nil { @@ -84,7 +85,7 @@ func TestParseURL(t *testing.T) { if opt, err := ParseURL("rediss://myhost:6379"); err != nil || opt.TLSConfig.ServerName != "myhost" { t.Fatalf("unexpected %v %v", opt, err) } - if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 { + if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialCtxFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 { t.Fatalf("unexpected %v %v", opt, err) } } @@ -100,7 +101,7 @@ func TestMustParseURL(t *testing.T) { func TestMustParseURLUnix(t *testing.T) { opt := MustParseURL("unix://") - if conn, err := opt.DialFn("", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") { + if conn, err := opt.DialCtxFn(context.Background(), "", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") { t.Fatalf("unexpected %v %v", conn, err) // the error should be "dial unix: missing address" } } From 4ed2376408cf77e1645ace83f029e8c7c39d217d Mon Sep 17 00:00:00 2001 From: zhaohuiliu Date: Fri, 14 Mar 2025 23:46:31 +0800 Subject: [PATCH 4/5] add testcase for conn ctx --- client_test.go | 68 +++++++++++++++++++++++++++- mux_test.go | 32 ++++++++++++- pool_test.go | 119 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index 0258d54e..2a3c0b71 100644 --- a/client_test.go +++ b/client_test.go @@ -14,8 +14,10 @@ import ( type mockConn struct { DoFn func(cmd Completed) RedisResult + DoCtxFn func(ctx context.Context, cmd Completed) RedisResult DoCacheFn func(cmd Cacheable, ttl time.Duration) RedisResult DoMultiFn func(multi ...Completed) *redisresults + DoMultiCtxFn func(ctx context.Context, multi ...Completed) *redisresults DoMultiCacheFn func(multi ...CacheableTTL) *redisresults ReceiveFn func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error DoStreamFn func(cmd Completed) RedisResultStream @@ -66,7 +68,9 @@ func (m *mockConn) Do(ctx context.Context, cmd Completed) RedisResult { if fn := m.DoOverride[strings.Join(cmd.Commands(), " ")]; fn != nil { return fn(cmd) } - if m.DoFn != nil { + if m.DoCtxFn != nil { + return m.DoCtxFn(ctx, cmd) + } else if m.DoFn != nil { return m.DoFn(cmd) } return RedisResult{} @@ -108,7 +112,9 @@ func (m *mockConn) DoMulti(ctx context.Context, multi ...Completed) *redisresult if len(overrides) == len(multi) { return &redisresults{s: overrides} } - if m.DoMultiFn != nil { + if m.DoMultiCtxFn != nil { + return m.DoMultiCtxFn(ctx, multi...) + } else if m.DoMultiFn != nil { return m.DoMultiFn(multi...) } return nil @@ -613,6 +619,64 @@ func TestSingleClient(t *testing.T) { } } }) + + t.Run("Acquire Exceed Context Deadline", func(t *testing.T) { + w := &mockWire{} + m.AcquireFn = func(ctx context.Context) wire { + timer := time.NewTimer(time.Millisecond*10) + defer timer.Stop() + select { + case <-ctx.Done(): + return epipeFn(ctx.Err()) + case <-timer.C: + // noop + } + return w + } + m.DoCtxFn = func(ctx context.Context, cmd Completed) RedisResult { + if ww := m.AcquireFn(ctx); ww != w { + return newErrResult(ww.Error()) + } + return newResult(RedisMessage{typ: '+', string: "Acquire"}, nil) + } + m.DoMultiCtxFn = func(ctx context.Context, cmd ...Completed) *redisresults { + if ww := m.AcquireFn(ctx); ww != w { + return &redisresults{s: []RedisResult{newErrResult(ww.Error())}} + } + return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "Acquire"}, nil)}} + } + + m.StoreFn = func(ww wire) { + if (err == nil && ww != dead) || (err != nil && ww != w) { + t.Fatalf("received unexpected wire %v", ww) + } + err = nil + } + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + if v, err := client.Do(ctx, client.B().Get().Key("a").Build()).ToString(); err == nil || v == "Acquire" { + t.Fatalf("unexpected response %v %v", v, err) + } + cancel() + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*20) + if v, err := client.Do(ctx, client.B().Get().Key("a").Build()).ToString(); err != nil || v != "Acquire" { + t.Fatalf("unexpected response %v %v", v, err) + } + cancel() + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond) + for _, resp := range client.DoMulti(ctx, client.B().Get().Key("a").Build()) { + if v, err := resp.ToString(); err == nil || v == "Acquire" { + t.Fatalf("unexpected response %v %v", v, err) + } + } + cancel() + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*20) + for _, resp := range client.DoMulti(ctx, client.B().Get().Key("a").Build()) { + if v, err := resp.ToString(); err != nil || v != "Acquire" { + t.Fatalf("unexpected response %v %v", v, err) + } + } + cancel() + }) } func TestSingleClientRetry(t *testing.T) { diff --git a/mux_test.go b/mux_test.go index 10139413..0df7113f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -43,26 +43,54 @@ func TestNewMuxDailErr(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) c := 0 e := errors.New("any") - m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux("", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { + timer := time.NewTimer(time.Millisecond*10) // delay time + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + // noop + } c++ return nil, e }) if err := m.Dial(); err != e { t.Fatalf("unexpected return %v", err) } + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel1() + if _, err := m._pipe(ctx1, 0); err != context.DeadlineExceeded { + t.Fatalf("unexpected return %v", err) + } if c != 1 { t.Fatalf("dialFn not called") } if w := m.pipe(context.Background(), 0); w != m.dead { // c = 2 t.Fatalf("unexpected wire %v", w) } + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel2() + if w := m.pipe(ctx2, 0); w != m.dead { + t.Fatalf("unexpected wire %v", w) + } if err := m.Dial(); err != e { // c = 3 t.Fatalf("unexpected return %v", err) } if w := m.Acquire(context.Background()); w != m.dead { t.Fatalf("unexpected wire %v", w) } - if c != 4 { + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel3() + if w := m.Acquire(ctx3); w != m.dead { + t.Fatalf("unexpected wire %v", w) + } + ctx4, cancel4 := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel4() + if w := m.Acquire(ctx4); w != m.dead { + t.Fatalf("unexpected wire %v", w) + } + if c != 5 { t.Fatalf("dialFn not called %v", c) } } diff --git a/pool_test.go b/pool_test.go index 4053381f..8851cb36 100644 --- a/pool_test.go +++ b/pool_test.go @@ -330,3 +330,122 @@ func TestPoolWithIdleTTL(t *testing.T) { p.Close() }) } + +func TestPoolWithAcquireCtx(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + setup := func(size int, delay time.Duration) *pool { + return newPool(size, dead, 0, 0, func(ctx context.Context) wire { + var err error + closed := false + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + err = ctx.Err() + closed = true + case <-timer.C: + // noop + } + + return &mockWire{ + CloseFn: func() { + closed = true + }, + ErrorFn: func() error { + if err != nil { + return err + } else if closed { + return ErrClosing + } + return nil + }, + } + }) + } + t.Run("Acquire connections, all exceed context deadline", func(t *testing.T) { + p := setup(10, time.Millisecond*5) + conns := make([]wire, 10) + + for i := 0; i < 2; i++ { + for i := range conns { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + + for _, w := range conns { + p.Store(w) + } + + p.cond.L.Lock() + if p.size != 0 { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to 0, actual: %d", p.size) + } + + if len(p.list) != 0 { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to 0, actual: %d", len(p.list)) + } + p.cond.L.Unlock() + } + + p.Close() + }) + + t.Run("Acquire connections, some exceed context deadline", func(t *testing.T) { + p := setup(10, time.Millisecond*5) + conns := make([]wire, 10) + + // size = 5 + for i := range conns { + d := time.Millisecond + if i % 2 == 0 { + d = time.Millisecond * 8 + } + ctx, cancel := context.WithTimeout(context.Background(), d) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + for _, w := range conns { + p.Store(w) + } + p.cond.L.Lock() + if p.size != len(conns)/2 { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to %d, actual: %d", len(conns)/2, p.size) + } + + if len(p.list) != len(conns)/2 { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to %d, actual: %d", len(conns)/2, len(p.list)) + } + p.cond.L.Unlock() + + // size = 10 + for i := range conns { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond * 8) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + for _, w := range conns { + p.Store(w) + } + p.cond.L.Lock() + if p.size != len(conns) { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to %d, actual: %d", len(conns), p.size) + } + + if len(p.list) != len(conns) { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to %d, actual: %d", len(conns), len(p.list)) + } + p.cond.L.Unlock() + + p.Close() + }) +} \ No newline at end of file From ee9d8edf94579dc4c01085e1f3629ff767c77a29 Mon Sep 17 00:00:00 2001 From: Rueian Date: Sat, 15 Mar 2025 09:22:19 -0700 Subject: [PATCH 5/5] chore: remove unnecessary changes Signed-off-by: Rueian --- client_test.go | 48 +++++++++++++-------------- cluster_test.go | 86 ++++++++++++++++++++++++------------------------ sentinel_test.go | 8 ++--- 3 files changed, 71 insertions(+), 71 deletions(-) diff --git a/client_test.go b/client_test.go index 0258d54e..6a9559f8 100644 --- a/client_test.go +++ b/client_test.go @@ -26,7 +26,7 @@ type mockConn struct { ErrorFn func() error CloseFn func() DialFn func() error - AcquireFn func(ctx context.Context) wire + AcquireFn func() wire StoreFn func(w wire) OverrideFn func(c conn) AddrFn func() string @@ -51,7 +51,7 @@ func (m *mockConn) Dial() error { func (m *mockConn) Acquire(ctx context.Context) wire { if m.AcquireFn != nil { - return m.AcquireFn(ctx) + return m.AcquireFn() } return nil } @@ -399,7 +399,7 @@ func TestSingleClient(t *testing.T) { return e }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -434,7 +434,7 @@ func TestSingleClient(t *testing.T) { closed = true }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -500,7 +500,7 @@ func TestSingleClient(t *testing.T) { closed = true }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -544,7 +544,7 @@ func TestSingleClient(t *testing.T) { t.Run("Dedicate Delegate Release On Close", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func(_ context.Context) wire { return w } + m.AcquireFn = func() wire { return w } m.StoreFn = func(ww wire) { stored++ } c, _ := client.Dedicate() @@ -558,7 +558,7 @@ func TestSingleClient(t *testing.T) { t.Run("Dedicate Delegate No Duplicate Release", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func(_ context.Context) wire { return w } + m.AcquireFn = func() wire { return w } m.StoreFn = func(ww wire) { stored++ } c, cancel := client.Dedicate() @@ -573,7 +573,7 @@ func TestSingleClient(t *testing.T) { }) t.Run("Dedicate ErrDedicatedClientRecycled after released", func(t *testing.T) { - m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + m.AcquireFn = func() wire { return &mockWire{} } check := func(err error) { if !errors.Is(err, ErrDedicatedClientRecycled) { t.Fatalf("unexpected err %v", err) @@ -1020,7 +1020,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { newErrResult(ErrClosing), newResult(RedisMessage{typ: '+', string: "Do"}, nil), ) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if v, err := cc.Do(context.Background(), c.B().Get().Key("Do").Build()).ToString(); err != nil || v != "Do" { t.Fatalf("unexpected response %v %v", v, err) @@ -1035,7 +1035,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Do(context.Background(), c.B().Get().Key("Do").Build()).Error() }); ret != ErrClosing { @@ -1046,7 +1046,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate Do ReadOnly NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1062,7 +1062,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { newErrResult(ErrClosing), newResult(RedisMessage{typ: '+', string: "Do"}, nil), ) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1091,7 +1091,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate Do Write NoRetry", func(t *testing.T) { c, m := setup() m.DoFn = makeDoFn(newErrResult(ErrClosing)) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Do(context.Background(), c.B().Set().Key("Do").Value("Do").Build()).Error() }); ret != ErrClosing { @@ -1105,7 +1105,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { []RedisResult{newErrResult(ErrClosing)}, []RedisResult{newResult(RedisMessage{typ: '+', string: "Do"}, nil)}, ) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if v, err := cc.DoMulti(context.Background(), c.B().Get().Key("Do").Build())[0].ToString(); err != nil || v != "Do" { t.Fatalf("unexpected response %v %v", v, err) @@ -1120,7 +1120,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.DoMulti(context.Background(), c.B().Get().Key("Do").Build())[0].Error() }); ret != ErrClosing { @@ -1131,7 +1131,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate DoMulti ReadOnly NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1147,7 +1147,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { []RedisResult{newErrResult(ErrClosing)}, []RedisResult{newResult(RedisMessage{typ: '+', string: "Do"}, nil)}, ) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1176,7 +1176,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Dedicate Delegate DoMulti Write NoRetry", func(t *testing.T) { c, m := setup() m.DoMultiFn = makeDoMultiFn([]RedisResult{newErrResult(ErrClosing)}) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.DoMulti(context.Background(), c.B().Set().Key("Do").Value("Do").Build())[0].Error() }); ret != ErrClosing { @@ -1187,7 +1187,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive Retry", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing, nil) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if err := cc.Receive(context.Background(), c.B().Subscribe().Channel("Do").Build(), nil); err != nil { t.Fatalf("unexpected response %v", err) @@ -1202,7 +1202,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing) m.ErrorFn = func() error { return ErrClosing } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn, ErrorFn: m.ErrorFn} } + m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn, ErrorFn: m.ErrorFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { return cc.Receive(context.Background(), c.B().Subscribe().Channel("Do").Build(), nil) }); ret != ErrClosing { @@ -1213,7 +1213,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive NoRetry - ctx done", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } ctx, cancel := context.WithCancel(context.Background()) cancel() if ret := c.Dedicated(func(cc DedicatedClient) error { @@ -1226,7 +1226,7 @@ func SetupClientRetry(t *testing.T, fn func(mock *mockConn) Client) { t.Run("Delegate Receive NoRetry - not retryable", func(t *testing.T) { c, m := setup() m.ReceiveFn = makeReceiveFn(ErrClosing, nil) - m.AcquireFn = func(_ context.Context) wire { return &mockWire{ReceiveFn: m.ReceiveFn} } + m.AcquireFn = func() wire { return &mockWire{ReceiveFn: m.ReceiveFn} } if ret := c.Dedicated(func(cc DedicatedClient) error { if cli, ok := cc.(*dedicatedClusterClient); ok { cli.retryHandler = &mockRetryHandler{ @@ -1377,7 +1377,7 @@ func TestSingleClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -1400,7 +1400,7 @@ func TestSingleClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build()) diff --git a/cluster_test.go b/cluster_test.go index 9bc132e0..dae2c788 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1799,7 +1799,7 @@ func TestClusterClient(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + m.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -1812,7 +1812,7 @@ func TestClusterClient(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -1841,7 +1841,7 @@ func TestClusterClient(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - m.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + m.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -1865,7 +1865,7 @@ func TestClusterClient(t *testing.T) { return e }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -1914,7 +1914,7 @@ func TestClusterClient(t *testing.T) { closed = true }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -2022,7 +2022,7 @@ func TestClusterClient(t *testing.T) { closed = true }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -2092,7 +2092,7 @@ func TestClusterClient(t *testing.T) { t.Run("Dedicate Delegate Release On Close", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func(_ context.Context) wire { return w } + m.AcquireFn = func() wire { return w } m.StoreFn = func(ww wire) { stored++ } c, _ := client.Dedicate() c.Do(context.Background(), c.B().Get().Key("a").Build()) @@ -2107,7 +2107,7 @@ func TestClusterClient(t *testing.T) { t.Run("Dedicate Delegate No Duplicate Release", func(t *testing.T) { stored := 0 w := &mockWire{} - m.AcquireFn = func(_ context.Context) wire { return w } + m.AcquireFn = func() wire { return w } m.StoreFn = func(ww wire) { stored++ } c, cancel := client.Dedicate() c.Do(context.Background(), c.B().Get().Key("a").Build()) @@ -2399,7 +2399,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -2412,7 +2412,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -2441,7 +2441,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -2465,7 +2465,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { return e }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -2514,7 +2514,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -2622,7 +2622,7 @@ func TestClusterClient_SendToOnlyPrimaryNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -2888,7 +2888,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -2901,7 +2901,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -2930,7 +2930,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -2954,10 +2954,10 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { return e }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } - replicaNodeConn.AcquireFn = func(_ context.Context) wire { + replicaNodeConn.AcquireFn = func() wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { @@ -3006,7 +3006,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -3114,7 +3114,7 @@ func TestClusterClient_SendToOnlyReplicaNodes(t *testing.T) { closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -3478,7 +3478,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -3491,7 +3491,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -3520,7 +3520,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -3544,10 +3544,10 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod return e }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } - replicaNodeConn.AcquireFn = func(_ context.Context) wire { + replicaNodeConn.AcquireFn = func() wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { @@ -3596,7 +3596,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -3704,7 +3704,7 @@ func TestClusterClient_SendReadOperationToReplicaNodesWriteOperationToPrimaryNod closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -6306,7 +6306,7 @@ func TestClusterClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m.DoFn} } + m.AcquireFn = func() wire { return &mockWire{DoFn: m.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -6329,7 +6329,7 @@ func TestClusterClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m.DoMultiFn} } + m.AcquireFn = func() wire { return &mockWire{DoMultiFn: m.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build()) @@ -6783,7 +6783,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -6796,7 +6796,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -6825,7 +6825,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -6849,8 +6849,8 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode return e }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return w } - replicaNodeConn.AcquireFn = func(_ context.Context) wire { return w } // Subscribe can work on replicas + primaryNodeConn.AcquireFn = func() wire { return w } + replicaNodeConn.AcquireFn = func() wire { return w } // Subscribe can work on replicas if err := client.Dedicated(func(c DedicatedClient) error { return c.Receive(context.Background(), c.B().Subscribe().Channel("a").Build(), func(msg PubSubMessage) {}) }); err != e { @@ -6897,7 +6897,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -7005,7 +7005,7 @@ func TestClusterClient_SendReadOperationToReplicaNodeWriteOperationToPrimaryNode closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -7284,7 +7284,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } client.Dedicated(func(c DedicatedClient) error { c.Do(context.Background(), c.B().Get().Key("a").Build()).Error() return c.Do(context.Background(), c.B().Get().Key("b").Build()).Error() @@ -7297,7 +7297,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T t.Errorf("Dedicated should panic if cross slots is used") } }() - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return &mockWire{ DoMultiFn: func(multi ...Completed) *redisresults { return &redisresults{s: []RedisResult{ @@ -7326,7 +7326,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T }) t.Run("Dedicated Multi Cross Slot Err", func(t *testing.T) { - primaryNodeConn.AcquireFn = func(_ context.Context) wire { return &mockWire{} } + primaryNodeConn.AcquireFn = func() wire { return &mockWire{} } err := client.Dedicated(func(c DedicatedClient) (err error) { defer func() { err = errors.New(recover().(string)) @@ -7350,7 +7350,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T return e }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } if err := client.Dedicated(func(c DedicatedClient) error { @@ -7399,7 +7399,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false @@ -7507,7 +7507,7 @@ func TestClusterClient_SendToOnlyPrimaryNodeWhenPrimaryNodeSelected(t *testing.T closed = true }, } - primaryNodeConn.AcquireFn = func(_ context.Context) wire { + primaryNodeConn.AcquireFn = func() wire { return w } stored := false diff --git a/sentinel_test.go b/sentinel_test.go index 8b4d2810..9555e220 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -927,7 +927,7 @@ func TestSentinelClientDelegate(t *testing.T) { return ErrClosing }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -976,7 +976,7 @@ func TestSentinelClientDelegate(t *testing.T) { return ErrClosing }, } - m.AcquireFn = func(_ context.Context) wire { + m.AcquireFn = func() wire { return w } stored := false @@ -1709,7 +1709,7 @@ func TestSentinelClientLoadingRetry(t *testing.T) { } return newResult(RedisMessage{typ: '+', string: "OK"}, nil) } - m1.AcquireFn = func(_ context.Context) wire { return &mockWire{DoFn: m1.DoFn} } + m1.AcquireFn = func() wire { return &mockWire{DoFn: m1.DoFn} } err := client.Dedicated(func(c DedicatedClient) error { if v, err := c.Do(context.Background(), c.B().Get().Key("test").Build()).ToString(); err != nil || v != "OK" { @@ -1732,7 +1732,7 @@ func TestSentinelClientLoadingRetry(t *testing.T) { } return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "OK"}, nil)}} } - m1.AcquireFn = func(_ context.Context) wire { return &mockWire{DoMultiFn: m1.DoMultiFn} } + m1.AcquireFn = func() wire { return &mockWire{DoMultiFn: m1.DoMultiFn} } err := client.Dedicated(func(c DedicatedClient) error { resps := c.DoMulti(context.Background(), c.B().Get().Key("test").Build())