Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: respect request ctx for making tcp connections and add DialCtxFn option #803

Merged
merged 7 commits into from
Mar 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,15 @@ retry:
}

func (c *singleClient) Dedicated(fn func(DedicatedClient) error) (err error) {
wire := c.conn.Acquire()
wire := c.conn.Acquire(context.Background())
dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler}
err = fn(dsc)
dsc.release()
return err
}

func (c *singleClient) Dedicate() (DedicatedClient, func()) {
wire := c.conn.Acquire()
wire := c.conn.Acquire(context.Background())
dsc := &dedicatedSingleClient{cmd: c.cmd, conn: c.conn, wire: wire, retry: c.retry, retryHandler: c.retryHandler}
return dsc, dsc.release
}
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (m *mockConn) Dial() error {
return nil
}

func (m *mockConn) Acquire() wire {
func (m *mockConn) Acquire(ctx context.Context) wire {
if m.AcquireFn != nil {
return m.AcquireFn()
}
Expand Down
2 changes: 1 addition & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ func (c *dedicatedClusterClient) acquire(ctx context.Context, slot uint16) (wire
}
return nil, err
}
c.wire = c.conn.Acquire()
c.wire = c.conn.Acquire(ctx)
if p := c.pshks; p != nil {
c.pshks = nil
ch := c.wire.SetPubSubHooks(p.hooks)
Expand Down
56 changes: 28 additions & 28 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
)

type connFn func(dst string, opt *ClientOption) conn
type dialFn func(dst string, opt *ClientOption) (net.Conn, error)
type wireFn func() wire
type dialFn func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error)
type wireFn func(ctx context.Context) wire

type singleconnect struct {
w wire
Expand All @@ -37,7 +37,7 @@ type conn interface {
Close()
Dial() error
Override(conn)
Acquire() wire
Acquire(ctx context.Context) wire
Store(w wire)
Addr() string
SetOnCloseHook(func(error))
Expand All @@ -64,12 +64,12 @@ type mux struct {

func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux {
dead := deadFn()
connFn := func() (net.Conn, error) {
return dialFn(dst, option)
connFn := func(ctx context.Context) (net.Conn, error) {
return dialFn(ctx, dst, option)
}
wireFn := func(pipeFn pipeFn) func() wire {
return func() (w wire) {
w, err := pipeFn(connFn, option)
wireFn := func(pipeFn pipeFn) func(context.Context) wire {
return func(ctx context.Context) (w wire) {
w, err := pipeFn(ctx, connFn, option)
if err != nil {
dead.error.Store(&errs{error: err})
w = dead
Expand Down Expand Up @@ -132,7 +132,7 @@ func (m *mux) Override(cc conn) {
}
}

func (m *mux) _pipe(i uint16) (w wire, err error) {
func (m *mux) _pipe(ctx context.Context, i uint16) (w wire, err error) {
if w = m.wire[i].Load().(wire); w != m.init {
return w, nil
}
Expand All @@ -151,7 +151,7 @@ func (m *mux) _pipe(i uint16) (w wire, err error) {
}

if w = m.wire[i].Load().(wire); w == m.init {
if w = m.wireFn(); w != m.dead {
if w = m.wireFn(ctx); w != m.dead {
m.setCloseHookOnWire(i, w)
m.wire[i].Store(w)
} else {
Expand All @@ -173,39 +173,39 @@ func (m *mux) _pipe(i uint16) (w wire, err error) {
return w, err
}

func (m *mux) pipe(i uint16) wire {
w, _ := m._pipe(i)
func (m *mux) pipe(ctx context.Context, i uint16) wire {
w, _ := m._pipe(ctx, i)
return w // this should never be nil
}

func (m *mux) Dial() error {
_, err := m._pipe(0)
_, err := m._pipe(context.Background(), 0)
return err
}

func (m *mux) Info() map[string]RedisMessage {
return m.pipe(0).Info()
return m.pipe(context.Background(), 0).Info()
}

func (m *mux) Version() int {
return m.pipe(0).Version()
return m.pipe(context.Background(), 0).Version()
}

func (m *mux) AZ() string {
return m.pipe(0).AZ()
return m.pipe(context.Background(), 0).AZ()
}

func (m *mux) Error() error {
return m.pipe(0).Error()
return m.pipe(context.Background(), 0).Error()
}

func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream {
wire := m.spool.Acquire()
wire := m.spool.Acquire(ctx)
return wire.DoStream(ctx, m.spool, cmd)
}

func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream {
wire := m.spool.Acquire()
wire := m.spool.Acquire(ctx)
return wire.DoMultiStream(ctx, m.spool, multi...)
}

Expand Down Expand Up @@ -242,7 +242,7 @@ block:
}

func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) {
wire := pool.Acquire()
wire := pool.Acquire(ctx)
resp = wire.Do(ctx, cmd)
if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
wire.Close()
Expand All @@ -252,7 +252,7 @@ func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp Red
}

func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) {
wire := pool.Acquire()
wire := pool.Acquire(ctx)
resp = wire.DoMulti(ctx, cmd...)
for _, res := range resp.s {
if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
Expand All @@ -266,7 +266,7 @@ func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (r

func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {
slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply())
wire := m.pipe(slot)
wire := m.pipe(ctx, slot)
if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
}
Expand All @@ -275,7 +275,7 @@ func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {

func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) {
slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply())
wire := m.pipe(slot)
wire := m.pipe(ctx, slot)
resp = wire.DoMulti(ctx, cmd...)
for _, r := range resp.s {
if isBroken(r.NonRedisError(), wire) {
Expand All @@ -288,7 +288,7 @@ func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisre

func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult {
slot := cmd.Slot() & uint16(len(m.wire)-1)
wire := m.pipe(slot)
wire := m.pipe(ctx, slot)
resp := wire.DoCache(ctx, cmd, ttl)
if isBroken(resp.NonRedisError(), wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
Expand Down Expand Up @@ -346,7 +346,7 @@ func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results
}

func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) {
wire := m.pipe(slot)
wire := m.pipe(ctx, slot)
resps = wire.DoMultiCache(ctx, multi...)
for _, r := range resps.s {
if isBroken(r.NonRedisError(), wire) {
Expand All @@ -359,16 +359,16 @@ func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTT

func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply())
wire := m.pipe(slot)
wire := m.pipe(ctx, slot)
err := wire.Receive(ctx, subscribe, fn)
if isBroken(err, wire) {
m.wire[slot].CompareAndSwap(wire, m.init)
}
return err
}

func (m *mux) Acquire() wire {
return m.dpool.Acquire()
func (m *mux) Acquire(ctx context.Context) wire {
return m.dpool.Acquire(ctx)
}

func (m *mux) Store(w wire) {
Expand Down
64 changes: 46 additions & 18 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) {
func setupMuxWithOption(wires []*mockWire, option *ClientOption) (conn *mux, checkClean func(t *testing.T)) {
var mu sync.Mutex
var count = -1
wfn := func() wire {
wfn := func(_ context.Context) wire {
mu.Lock()
defer mu.Unlock()
count++
Expand All @@ -43,26 +43,54 @@ func TestNewMuxDailErr(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
c := 0
e := errors.New("any")
m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
m := makeMux("", &ClientOption{}, func(ctx context.Context, dst string, opt *ClientOption) (net.Conn, error) {
timer := time.NewTimer(time.Millisecond*10) // delay time
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
// noop
}
c++
return nil, e
})
if err := m.Dial(); err != e {
t.Fatalf("unexpected return %v", err)
}
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel1()
if _, err := m._pipe(ctx1, 0); err != context.DeadlineExceeded {
t.Fatalf("unexpected return %v", err)
}
if c != 1 {
t.Fatalf("dialFn not called")
}
if w := m.pipe(0); w != m.dead { // c = 2
if w := m.pipe(context.Background(), 0); w != m.dead { // c = 2
t.Fatalf("unexpected wire %v", w)
}
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel2()
if w := m.pipe(ctx2, 0); w != m.dead {
t.Fatalf("unexpected wire %v", w)
}
if err := m.Dial(); err != e { // c = 3
t.Fatalf("unexpected return %v", err)
}
if w := m.Acquire(); w != m.dead {
if w := m.Acquire(context.Background()); w != m.dead {
t.Fatalf("unexpected wire %v", w)
}
ctx3, cancel3 := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel3()
if w := m.Acquire(ctx3); w != m.dead {
t.Fatalf("unexpected wire %v", w)
}
ctx4, cancel4 := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel4()
if w := m.Acquire(ctx4); w != m.dead {
t.Fatalf("unexpected wire %v", w)
}
if c != 4 {
if c != 5 {
t.Fatalf("dialFn not called %v", c)
}
}
Expand All @@ -89,15 +117,15 @@ func TestNewMux(t *testing.T) {
mock.Expect("PING").ReplyString("OK")
mock.Close()
}()
m := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
m := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) {
return n1, nil
})
if err := m.Dial(); err != nil {
t.Fatalf("unexpected error %v", err)
}

t.Run("Override with previous mux", func(t *testing.T) {
m2 := makeMux("", &ClientOption{}, func(dst string, opt *ClientOption) (net.Conn, error) {
m2 := makeMux("", &ClientOption{}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) {
return n1, nil
})
m2.Override(m)
Expand All @@ -111,7 +139,7 @@ func TestNewMux(t *testing.T) {
func TestNewMuxPipelineMultiplex(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
for _, v := range []int{-1, 0, 1, 2} {
m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(dst string, opt *ClientOption) (net.Conn, error) { return nil, nil })
m := makeMux("", &ClientOption{PipelineMultiplex: v}, func(_ context.Context, dst string, opt *ClientOption) (net.Conn, error) { return nil, nil })
if (v < 0 && len(m.wire) != 1) || (v >= 0 && len(m.wire) != 1<<v) {
t.Fatalf("unexpected len(m.wire): %v", len(m.wire))
}
Expand All @@ -130,11 +158,11 @@ func TestMuxDialSuppress(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
var wires, waits, done int64
blocking := make(chan struct{})
m := newMux("", &ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func() wire {
m := newMux("", &ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(_ context.Context) wire {
atomic.AddInt64(&wires, 1)
<-blocking
return &mockWire{}
}, func() wire {
}, func(_ context.Context) wire {
return &mockWire{}
})
for i := 0; i < 1000; i++ {
Expand Down Expand Up @@ -202,7 +230,7 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.dpool.Acquire()
wire1 := m.dpool.Acquire(context.Background())

go func() {
// this should use the second wire
Expand Down Expand Up @@ -256,7 +284,7 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.spool.Acquire()
wire1 := m.spool.Acquire(context.Background())

go func() {
// this should use the second wire
Expand Down Expand Up @@ -317,7 +345,7 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.spool.Acquire()
wire1 := m.spool.Acquire(context.Background())

go func() {
// this should use the second wire
Expand Down Expand Up @@ -374,7 +402,7 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.dpool.Acquire()
wire1 := m.dpool.Acquire(context.Background())

go func() {
// this should use the second wire
Expand Down Expand Up @@ -418,7 +446,7 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.Acquire()
wire1 := m.Acquire(context.Background())
m.Store(wire1)

if !cleaned {
Expand Down Expand Up @@ -660,7 +688,7 @@ func TestMuxDelegation(t *testing.T) {
defer m.Close()

for i := range wires {
m._pipe(uint16(i))
m._pipe(context.Background(), uint16(i))
}

builder := cmds.NewBuilder(cmds.NoSlot)
Expand Down Expand Up @@ -702,7 +730,7 @@ func TestMuxDelegation(t *testing.T) {
defer m.Close()

for i := range wires {
m._pipe(uint16(i))
m._pipe(context.Background(), uint16(i))
}

builder := cmds.NewBuilder(cmds.NoSlot)
Expand Down Expand Up @@ -1038,7 +1066,7 @@ func TestMuxRegisterCloseHook(t *testing.T) {

func BenchmarkClientSideCaching(b *testing.B) {
setup := func(b *testing.B) *mux {
c := makeMux("127.0.0.1:6379", &ClientOption{CacheSizeEachConn: DefaultCacheBytes}, func(dst string, opt *ClientOption) (conn net.Conn, err error) {
c := makeMux("127.0.0.1:6379", &ClientOption{CacheSizeEachConn: DefaultCacheBytes}, func(_ context.Context, dst string, opt *ClientOption) (conn net.Conn, err error) {
return net.Dial("tcp", dst)
})
if err := c.Dial(); err != nil {
Expand Down
Loading
Loading