diff --git a/client.go b/client.go index ad69c360..be9230ab 100644 --- a/client.go +++ b/client.go @@ -172,7 +172,7 @@ retry: } func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -180,7 +180,7 @@ func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) { } func (c *singleClient) Dedicate() (DedicatedClient, func()) { - wire := c.conn.Acquire() + wire := c.conn.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } diff --git a/client_test.go b/client_test.go index 23a34aba..6a9559f8 100644 --- a/client_test.go +++ b/client_test.go @@ -49,7 +49,7 @@ func (m *mockConn) Dial() error { return nil } -func (m *mockConn) Acquire() wire { +func (m *mockConn) Acquire(ctx context.Context) wire { if m.AcquireFn != nil { return m.AcquireFn() } diff --git a/cluster.go b/cluster.go index 2ea2db08..84ca522b 100644 --- a/cluster.go +++ b/cluster.go @@ -1274,7 +1274,7 @@ func (c *dedicatedClusterClient) acquire(ctx context.Context, slot uint16) (wire } return nil, err } - c.wire = c.conn.Acquire() + c.wire = c.conn.Acquire(ctx) if p := c.pshks; p != nil { c.pshks = nil ch := c.wire.SetPubSubHooks(p.hooks) diff --git a/mux.go b/mux.go index dcc680a4..5da2598d 100644 --- a/mux.go +++ b/mux.go @@ -13,8 +13,8 @@ import ( ) type connFn func(dst string, opt *ClientOption) conn -type dialFn func(dst string, opt *ClientOption) (net.Conn, error) -type wireFn func() wire +type dialFn func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) +type wireFn func(ctx context.Context) wire type singleconnect struct { w wire @@ -37,7 +37,7 @@ type conn interface { Close() Dial() error Override(conn) - Acquire() wire + Acquire(ctx context.Context) wire Store(w wire) Addr() string SetOnCloseHook(func(error)) @@ -64,12 +64,12 @@ type mux struct { func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux { dead := deadFn() - connFn := func() (net.Conn, error) { - return dialFn(dst, option) + connFn := func(ctx context.Context) (net.Conn, error) { + return dialFn(ctx, dst, option) } - wireFn := func(pipeFn pipeFn) func() wire { - return func() (w wire) { - w, err := pipeFn(connFn, option) + wireFn := func(pipeFn pipeFn) func(context.Context) wire { + return func(ctx context.Context) (w wire) { + w, err := pipeFn(ctx, connFn, option) if err != nil { dead.error.Store(&errs{error: err}) w = dead @@ -132,7 +132,7 @@ func (m *mux) Override(cc conn) { } } -func (m *mux) _pipe(i uint16) (w wire, err error) { +func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) { if w = m.wire[i].Load().(wire); w != m.init { return w, nil } @@ -151,7 +151,7 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { } if w = m.wire[i].Load().(wire); w == m.init { - if w = m.wireFn(); w != m.dead { + if w = m.wireFn(ctx); w != m.dead { m.setCloseHookOnWire(i, w) m.wire[i].Store(w) } else { @@ -173,39 +173,39 @@ func (m *mux) _pipe(i uint16) (w wire, err error) { return w, err } -func (m *mux) pipe(i uint16) wire { - w, _ := m._pipe(i) +func (m *mux) pipe(ctx context.Context, i uint16) wire { + w, _ := m._pipe(ctx, i) return w // this should never be nil } func (m *mux) Dial() error { - _, err := m._pipe(0) + _, err := m._pipe(context.Background(), 0) return err } func (m *mux) Info() map[string]RedisMessage { - return m.pipe(0).Info() + return m.pipe(context.Background(), 0).Info() } func (m *mux) Version() int { - return m.pipe(0).Version() + return m.pipe(context.Background(), 0).Version() } func (m *mux) AZ() string { - return m.pipe(0).AZ() + return m.pipe(context.Background(), 0).AZ() } func (m *mux) Error() error { - return m.pipe(0).Error() + return m.pipe(context.Background(), 0).Error() } func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoStream(ctx, m.spool, cmd) } func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream { - wire := m.spool.Acquire() + wire := m.spool.Acquire(ctx) return wire.DoMultiStream(ctx, m.spool, multi...) } @@ -242,7 +242,7 @@ block: } func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.Do(ctx, cmd) if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) wire.Close() @@ -252,7 +252,7 @@ func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp Red } func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) { - wire := pool.Acquire() + wire := pool.Acquire(ctx) resp = wire.DoMulti(ctx, cmd...) for _, res := range resp.s { if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) @@ -266,7 +266,7 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) } @@ -275,7 +275,7 @@ func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) { func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) { slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp = wire.DoMulti(ctx, cmd...) for _, r := range resp.s { if isBroken(r.NonRedisError(), wire) { @@ -288,7 +288,7 @@ func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisre func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult { slot := cmd.Slot() & uint16(len(m.wire)-1) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resp := wire.DoCache(ctx, cmd, ttl) if isBroken(resp.NonRedisError(), wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -346,7 +346,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results } func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) { - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) resps = wire.DoMultiCache(ctx, multi...) for _, r := range resps.s { if isBroken(r.NonRedisError(), wire) { @@ -359,7 +359,7 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error { slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply()) - wire := m.pipe(slot) + wire := m.pipe(ctx, slot) err := wire.Receive(ctx, subscribe, fn) if isBroken(err, wire) { m.wire[slot].CompareAndSwap(wire, m.init) @@ -367,8 +367,8 @@ func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message return err } -func (m *mux) Acquire() wire { - return m.dpool.Acquire() +func (m *mux) Acquire(ctx context.Context) wire { + return m.dpool.Acquire(ctx) } func (m *mux) Store(w wire) { diff --git a/mux_test.go b/mux_test.go index 4f801322..0df7113f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -23,7 +23,7 @@ func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) { func setupMuxWithOption(wires []*mockWire, option *ClientOption) (conn *mux, checkClean func(t *testing.T)) { var mu sync.Mutex var count = -1 - wfn := func() wire { + wfn := func(_ context.Context) wire { mu.Lock() defer mu.Unlock() count++ @@ -43,26 +43,54 @@ func TestNewMuxDailErr(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) c := 0 e := errors.New("any") - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux("", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) { + timer := time.NewTimer(time.Millisecond*10) // delay time + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + // noop + } c++ return nil, e }) if err := m.Dial(); err != e { t.Fatalf("unexpected return %v", err) } + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel1() + if _, err := m._pipe(ctx1, 0); err != context.DeadlineExceeded { + t.Fatalf("unexpected return %v", err) + } if c != 1 { t.Fatalf("dialFn not called") } - if w := m.pipe(0); w != m.dead { // c = 2 + if w := m.pipe(context.Background(), 0); w != m.dead { // c = 2 + t.Fatalf("unexpected wire %v", w) + } + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel2() + if w := m.pipe(ctx2, 0); w != m.dead { t.Fatalf("unexpected wire %v", w) } if err := m.Dial(); err != e { // c = 3 t.Fatalf("unexpected return %v", err) } - if w := m.Acquire(); w != m.dead { + if w := m.Acquire(context.Background()); w != m.dead { + t.Fatalf("unexpected wire %v", w) + } + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel3() + if w := m.Acquire(ctx3); w != m.dead { + t.Fatalf("unexpected wire %v", w) + } + ctx4, cancel4 := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel4() + if w := m.Acquire(ctx4); w != m.dead { t.Fatalf("unexpected wire %v", w) } - if c != 4 { + if c != 5 { t.Fatalf("dialFn not called %v", c) } } @@ -89,7 +117,7 @@ func TestNewMux(t *testing.T) { mock.Expect("PING").ReplyString("OK") mock.Close() }() - m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) if err := m.Dial(); err != nil { @@ -97,7 +125,7 @@ func TestNewMux(t *testing.T) { } t.Run("Override with previous mux", func(t *testing.T) { - m2 := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) { + m2 := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return n1, nil }) m2.Override(m) @@ -111,7 +139,7 @@ func TestNewMux(t *testing.T) { func TestNewMuxPipelineMultiplex(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) for _, v := range []int{-1, 0, 1, 2} { - m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) + m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil }) if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1< 100 { t.Fatalf("pool must not exceed the size limit") @@ -103,11 +104,11 @@ func TestPool(t *testing.T) { conn := make([]wire, 100) pool, _ := setup(len(conn)) for i := 0; i < len(conn); i++ { - w := pool.Acquire() + w := pool.Acquire(context.Background()) go pool.Store(w) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } for i := 0; i < len(conn); i++ { for j := i + 1; j < len(conn); j++ { @@ -120,8 +121,8 @@ func TestPool(t *testing.T) { t.Run("Close", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() - w2 := pool.Acquire() + w1 := pool.Acquire(context.Background()) + w2 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -137,7 +138,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool does not close existing wire after Close()") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -149,12 +150,12 @@ func TestPool(t *testing.T) { t.Run("Close Empty", func(t *testing.T) { pool, count := setup(2) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } pool.Close() - w2 := pool.Acquire() + w2 := pool.Acquire(context.Background()) if w2.Error() != ErrClosing { t.Fatalf("pool does not close wire after Close()") } @@ -162,7 +163,7 @@ func TestPool(t *testing.T) { t.Fatalf("pool should not make new wire") } for i := 0; i < 100; i++ { - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Fatalf("pool does not return the dead wire after Close()") } } @@ -174,7 +175,7 @@ func TestPool(t *testing.T) { t.Run("Close Waiting", func(t *testing.T) { pool, count := setup(1) - w1 := pool.Acquire() + w1 := pool.Acquire(context.Background()) if w1.Error() != nil { t.Fatalf("unexpected err %v", w1.Error()) } @@ -182,7 +183,7 @@ func TestPool(t *testing.T) { for i := 0; i < 100; i++ { go func() { atomic.AddInt64(&pending, 1) - if rw := pool.Acquire(); rw != dead { + if rw := pool.Acquire(context.Background()); rw != dead { t.Errorf("pool does not return the dead wire after Close()") } atomic.AddInt64(&pending, -1) @@ -209,7 +210,7 @@ func TestPoolError(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int) (*pool, *int32) { var count int32 - return newPool(size, dead, 0, 0, func() wire { + return newPool(size, dead, 0, 0, func(_ context.Context) wire { w := &pipe{} w.pshks.Store(emptypshks) c := atomic.AddInt32(&count, 1) @@ -224,7 +225,7 @@ func TestPoolError(t *testing.T) { conn := make([]wire, 100) pool, count := setup(len(conn)) for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)) { t.Fatalf("unexpected acquire count") @@ -233,7 +234,7 @@ func TestPoolError(t *testing.T) { pool.Store(conn[i]) } for i := 0; i < len(conn); i++ { - conn[i] = pool.Acquire() + conn[i] = pool.Acquire(context.Background()) } if atomic.LoadInt32(count) != int32(len(conn)+len(conn)/2) { t.Fatalf("unexpected acquire count") @@ -244,7 +245,7 @@ func TestPoolError(t *testing.T) { func TestPoolWithIdleTTL(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int, ttl time.Duration, minSize int) *pool { - return newPool(size, dead, ttl, minSize, func() wire { + return newPool(size, dead, ttl, minSize, func(_ context.Context) wire { closed := false return &mockWire{ CloseFn: func() { @@ -267,7 +268,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } @@ -301,7 +302,7 @@ func TestPoolWithIdleTTL(t *testing.T) { for i := 0; i < 2; i++ { for i := range conns { - w := p.Acquire() + w := p.Acquire(context.Background()) conns[i] = w } @@ -329,3 +330,122 @@ func TestPoolWithIdleTTL(t *testing.T) { p.Close() }) } + +func TestPoolWithAcquireCtx(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + setup := func(size int, delay time.Duration) *pool { + return newPool(size, dead, 0, 0, func(ctx context.Context) wire { + var err error + closed := false + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + err = ctx.Err() + closed = true + case <-timer.C: + // noop + } + + return &mockWire{ + CloseFn: func() { + closed = true + }, + ErrorFn: func() error { + if err != nil { + return err + } else if closed { + return ErrClosing + } + return nil + }, + } + }) + } + t.Run("Acquire connections, all exceed context deadline", func(t *testing.T) { + p := setup(10, time.Millisecond*5) + conns := make([]wire, 10) + + for i := 0; i < 2; i++ { + for i := range conns { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + + for _, w := range conns { + p.Store(w) + } + + p.cond.L.Lock() + if p.size != 0 { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to 0, actual: %d", p.size) + } + + if len(p.list) != 0 { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to 0, actual: %d", len(p.list)) + } + p.cond.L.Unlock() + } + + p.Close() + }) + + t.Run("Acquire connections, some exceed context deadline", func(t *testing.T) { + p := setup(10, time.Millisecond*5) + conns := make([]wire, 10) + + // size = 5 + for i := range conns { + d := time.Millisecond + if i % 2 == 0 { + d = time.Millisecond * 8 + } + ctx, cancel := context.WithTimeout(context.Background(), d) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + for _, w := range conns { + p.Store(w) + } + p.cond.L.Lock() + if p.size != len(conns)/2 { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to %d, actual: %d", len(conns)/2, p.size) + } + + if len(p.list) != len(conns)/2 { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to %d, actual: %d", len(conns)/2, len(p.list)) + } + p.cond.L.Unlock() + + // size = 10 + for i := range conns { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond * 8) + w := p.Acquire(ctx) + conns[i] = w + cancel() + } + for _, w := range conns { + p.Store(w) + } + p.cond.L.Lock() + if p.size != len(conns) { + defer p.cond.L.Unlock() + t.Fatalf("size must be equal to %d, actual: %d", len(conns), p.size) + } + + if len(p.list) != len(conns) { + defer p.cond.L.Unlock() + t.Fatalf("pool len must equal to %d, actual: %d", len(conns), len(p.list)) + } + p.cond.L.Unlock() + + p.Close() + }) +} \ No newline at end of file diff --git a/rueidis.go b/rueidis.go index 5c22a07a..de0cf623 100644 --- a/rueidis.go +++ b/rueidis.go @@ -70,8 +70,12 @@ type ClientOption struct { TLSConfig *tls.Config // DialFn allows for a custom function to be used to create net.Conn connections + // Deprecated: use DialCtxFn instead. DialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) + // DialCtxFn allows for a custom function to be used to create net.Conn connections + DialCtxFn func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) + // NewCacheStoreFn allows a custom client side caching store for each connection NewCacheStoreFn NewCacheStoreFn @@ -464,14 +468,18 @@ func makeConn(dst string, opt *ClientOption) conn { return makeMux(dst, opt, dial) } -func dial(dst string, opt *ClientOption) (conn net.Conn, err error) { +func dial(ctx context.Context, dst string, opt *ClientOption) (conn net.Conn, err error) { + if opt.DialCtxFn != nil { + return opt.DialCtxFn(ctx, dst, &opt.Dialer, opt.TLSConfig) + } if opt.DialFn != nil { return opt.DialFn(dst, &opt.Dialer, opt.TLSConfig) } if opt.TLSConfig != nil { - conn, err = tls.DialWithDialer(&opt.Dialer, "tcp", dst, opt.TLSConfig) + dialer := tls.Dialer{NetDialer: &opt.Dialer, Config: opt.TLSConfig} + conn, err = dialer.DialContext(ctx, "tcp", dst) } else { - conn, err = opt.Dialer.Dial("tcp", dst) + conn, err = opt.Dialer.DialContext(ctx, "tcp", dst) } return conn, err } diff --git a/rueidis_test.go b/rueidis_test.go index 4da05a95..2d0e7582 100644 --- a/rueidis_test.go +++ b/rueidis_test.go @@ -407,6 +407,27 @@ func TestCustomDialFnIsCalled(t *testing.T) { } } +func TestCustomDialCtxFnIsCalled(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + isFnCalled := false + option := ClientOption{ + InitAddress: []string{"127.0.0.1:0"}, + DialCtxFn: func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + isFnCalled = true + return nil, errors.New("dial error") + }, + } + + _, err := NewClient(option) + + if !isFnCalled { + t.Fatalf("excepted ClientOption.DialFn to be called") + } + if err == nil { + t.Fatalf("expected dial error") + } +} + func ExampleIsRedisNil() { client, err := NewClient(ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) if err != nil { diff --git a/rueidisotel/metrics.go b/rueidisotel/metrics.go index 9176162e..75cfcd5e 100644 --- a/rueidisotel/metrics.go +++ b/rueidisotel/metrics.go @@ -70,8 +70,13 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien return nil, err } - if clientOption.DialFn == nil { - clientOption.DialFn = defaultDialFn + if clientOption.DialCtxFn == nil { + clientOption.DialCtxFn = defaultDialFn + if clientOption.DialFn != nil { + clientOption.DialCtxFn = func(_ context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + return clientOption.DialFn(s, dialer, config) + } + } } metrics := dialMetrics{ @@ -103,7 +108,8 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien return nil, err } - clientOption.DialFn = trackDialing(metrics, clientOption.DialFn) + clientOption.DialCtxFn = trackDialing(metrics, clientOption.DialCtxFn) + cli, err := rueidis.NewClient(clientOption) if err != nil { return nil, err @@ -146,14 +152,13 @@ func newClient(opts ...Option) (*otelclient, error) { return cli, nil } -func trackDialing(m dialMetrics, dialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) { - return func(network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) { - ctx := context.Background() +func trackDialing(m dialMetrics, dialFn func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) { + return func(ctx context.Context, network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) { m.attempt.Add(ctx, 1, m.addOpts...) start := time.Now() - conn, err = dialFn(network, dialer, tlsConfig) + conn, err = dialFn(ctx, network, dialer, tlsConfig) if err != nil { return nil, err } @@ -187,9 +192,10 @@ func (t *connTracker) Close() error { return t.Conn.Close() } -func defaultDialFn(dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) { +func defaultDialFn(ctx context.Context, dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) { if cfg != nil { - return tls.DialWithDialer(dialer, "tcp", dst, cfg) + td := tls.Dialer{NetDialer: dialer, Config: cfg} + return td.DialContext(ctx, "tcp", dst) } - return dialer.Dial("tcp", dst) + return dialer.DialContext(ctx, "tcp", dst) } diff --git a/rueidisotel/metrics_test.go b/rueidisotel/metrics_test.go index c2be606c..3766ce5c 100644 --- a/rueidisotel/metrics_test.go +++ b/rueidisotel/metrics_test.go @@ -15,7 +15,7 @@ import ( ) func TestNewClient(t *testing.T) { - t.Run("client option only", func(t *testing.T) { + t.Run("client option only (no ctx)", func(t *testing.T) { c, err := NewClient(rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { @@ -28,14 +28,27 @@ func TestNewClient(t *testing.T) { defer c.Close() }) + t.Run("client option only", func(t *testing.T) { + c, err := NewClient(rueidis.ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) + }, + }) + if err != nil { + t.Fatal(err) + } + defer c.Close() + }) + t.Run("meter provider", func(t *testing.T) { mr := metric.NewManualReader() meterProvider := metric.NewMeterProvider(metric.WithReader(mr)) c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -50,8 +63,8 @@ func TestNewClient(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithHistogramOption(HistogramOption{ @@ -79,8 +92,8 @@ func TestNewClientError(t *testing.T) { t.Run("invalid client option", func(t *testing.T) { _, err := NewClient(rueidis.ClientOption{ InitAddress: []string{""}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }) if err == nil { @@ -120,8 +133,8 @@ func TestTrackDialing(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -169,8 +182,8 @@ func TestTrackDialing(t *testing.T) { c, err := NewClient( rueidis.ClientOption{ InitAddress: []string{"127.0.0.1:6379"}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), @@ -198,8 +211,8 @@ func TestTrackDialing(t *testing.T) { _, err := NewClient( rueidis.ClientOption{ InitAddress: []string{""}, - DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("tcp", dst) + DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp", dst) }, }, WithMeterProvider(meterProvider), diff --git a/sentinel.go b/sentinel.go index 93935841..fac5a028 100644 --- a/sentinel.go +++ b/sentinel.go @@ -198,7 +198,7 @@ func (c *sentinelClient) DoMultiStream(ctx context.Context, multi ...Completed) func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} err = fn(dsc) dsc.release() @@ -207,7 +207,7 @@ func (c *sentinelClient) Dedicated(fn func(DedicatedClient) error) (err error) { func (c *sentinelClient) Dedicate() (DedicatedClient, func()) { master := c.mConn.Load().(conn) - wire := master.Acquire() + wire := master.Acquire(context.Background()) dsc := &dedicatedSingleClient{cmd: c.cmd, conn: master, wire: wire, retry: c.retry, retryHandler: c.retryHandler} return dsc, dsc.release } diff --git a/url.go b/url.go index dad0fc00..416402bb 100644 --- a/url.go +++ b/url.go @@ -1,6 +1,7 @@ package rueidis import ( + "context" "crypto/tls" "fmt" "net" @@ -37,8 +38,8 @@ func ParseURL(str string) (opt ClientOption, err error) { } switch u.Scheme { case "unix": - opt.DialFn = func(s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { - return dialer.Dial("unix", s) + opt.DialCtxFn = func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "unix", s) } opt.InitAddress = []string{strings.TrimSpace(u.Path)} case "rediss", "valkeys": diff --git a/url_test.go b/url_test.go index 45e3eed8..aa207b3a 100644 --- a/url_test.go +++ b/url_test.go @@ -1,6 +1,7 @@ package rueidis import ( + "context" "strings" "testing" ) @@ -18,7 +19,7 @@ func TestParseURL(t *testing.T) { if opt, err := ParseURL("valkeys://"); err != nil || opt.TLSConfig == nil { t.Fatalf("unexpected %v %v", opt, err) } - if opt, err := ParseURL("unix://"); err != nil || opt.DialFn == nil { + if opt, err := ParseURL("unix://"); err != nil || opt.DialCtxFn == nil { t.Fatalf("unexpected %v %v", opt, err) } if opt, err := ParseURL("valkey://"); err != nil { @@ -84,7 +85,7 @@ func TestParseURL(t *testing.T) { if opt, err := ParseURL("rediss://myhost:6379"); err != nil || opt.TLSConfig.ServerName != "myhost" { t.Fatalf("unexpected %v %v", opt, err) } - if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 { + if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialCtxFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 { t.Fatalf("unexpected %v %v", opt, err) } } @@ -100,7 +101,7 @@ func TestMustParseURL(t *testing.T) { func TestMustParseURLUnix(t *testing.T) { opt := MustParseURL("unix://") - if conn, err := opt.DialFn("", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") { + if conn, err := opt.DialCtxFn(context.Background(), "", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") { t.Fatalf("unexpected %v %v", conn, err) // the error should be "dial unix: missing address" } }