Skip to content

Commit e2b0621

Browse files
committed
feat: Deprecated DialFn
Signed-off-by: Rueian <[email protected]>
1 parent cc3f007 commit e2b0621

File tree

6 files changed

+71
-28
lines changed

6 files changed

+71
-28
lines changed

rueidis.go

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ type ClientOption struct {
7070
TLSConfig *tls.Config
7171

7272
// DialFn allows for a custom function to be used to create net.Conn connections
73+
// Deprecated: use DialCtxFn instead.
7374
DialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error)
7475

7576
// DialCtxFn allows for a custom function to be used to create net.Conn connections

rueidis_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,27 @@ func TestCustomDialFnIsCalled(t *testing.T) {
407407
}
408408
}
409409

410+
func TestCustomDialCtxFnIsCalled(t *testing.T) {
411+
defer ShouldNotLeaked(SetupLeakDetection())
412+
isFnCalled := false
413+
option := ClientOption{
414+
InitAddress: []string{"127.0.0.1:0"},
415+
DialCtxFn: func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) {
416+
isFnCalled = true
417+
return nil, errors.New("dial error")
418+
},
419+
}
420+
421+
_, err := NewClient(option)
422+
423+
if !isFnCalled {
424+
t.Fatalf("excepted ClientOption.DialFn to be called")
425+
}
426+
if err == nil {
427+
t.Fatalf("expected dial error")
428+
}
429+
}
430+
410431
func ExampleIsRedisNil() {
411432
client, err := NewClient(ClientOption{InitAddress: []string{"127.0.0.1:6379"}})
412433
if err != nil {

rueidisotel/metrics.go

+16-10
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien
7070
return nil, err
7171
}
7272

73-
if clientOption.DialFn == nil {
74-
clientOption.DialFn = defaultDialFn
73+
if clientOption.DialCtxFn == nil {
74+
clientOption.DialCtxFn = defaultDialFn
75+
if clientOption.DialFn != nil {
76+
clientOption.DialCtxFn = func(_ context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) {
77+
return clientOption.DialFn(s, dialer, config)
78+
}
79+
}
7580
}
7681

7782
metrics := dialMetrics{
@@ -103,7 +108,8 @@ func NewClient(clientOption rueidis.ClientOption, opts ...Option) (rueidis.Clien
103108
return nil, err
104109
}
105110

106-
clientOption.DialFn = trackDialing(metrics, clientOption.DialFn)
111+
clientOption.DialCtxFn = trackDialing(metrics, clientOption.DialCtxFn)
112+
107113
cli, err := rueidis.NewClient(clientOption)
108114
if err != nil {
109115
return nil, err
@@ -146,14 +152,13 @@ func newClient(opts ...Option) (*otelclient, error) {
146152
return cli, nil
147153
}
148154

149-
func trackDialing(m dialMetrics, dialFn func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(string, *net.Dialer, *tls.Config) (conn net.Conn, err error) {
150-
return func(network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) {
151-
ctx := context.Background()
155+
func trackDialing(m dialMetrics, dialFn func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error)) func(context.Context, string, *net.Dialer, *tls.Config) (conn net.Conn, err error) {
156+
return func(ctx context.Context, network string, dialer *net.Dialer, tlsConfig *tls.Config) (conn net.Conn, err error) {
152157
m.attempt.Add(ctx, 1, m.addOpts...)
153158

154159
start := time.Now()
155160

156-
conn, err = dialFn(network, dialer, tlsConfig)
161+
conn, err = dialFn(ctx, network, dialer, tlsConfig)
157162
if err != nil {
158163
return nil, err
159164
}
@@ -187,9 +192,10 @@ func (t *connTracker) Close() error {
187192
return t.Conn.Close()
188193
}
189194

190-
func defaultDialFn(dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) {
195+
func defaultDialFn(ctx context.Context, dst string, dialer *net.Dialer, cfg *tls.Config) (conn net.Conn, err error) {
191196
if cfg != nil {
192-
return tls.DialWithDialer(dialer, "tcp", dst, cfg)
197+
td := tls.Dialer{NetDialer: dialer, Config: cfg}
198+
return td.DialContext(ctx, "tcp", dst)
193199
}
194-
return dialer.Dial("tcp", dst)
200+
return dialer.DialContext(ctx, "tcp", dst)
195201
}

rueidisotel/metrics_test.go

+26-13
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
)
1616

1717
func TestNewClient(t *testing.T) {
18-
t.Run("client option only", func(t *testing.T) {
18+
t.Run("client option only (no ctx)", func(t *testing.T) {
1919
c, err := NewClient(rueidis.ClientOption{
2020
InitAddress: []string{"127.0.0.1:6379"},
2121
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
@@ -28,14 +28,27 @@ func TestNewClient(t *testing.T) {
2828
defer c.Close()
2929
})
3030

31+
t.Run("client option only", func(t *testing.T) {
32+
c, err := NewClient(rueidis.ClientOption{
33+
InitAddress: []string{"127.0.0.1:6379"},
34+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
35+
return dialer.DialContext(ctx, "tcp", dst)
36+
},
37+
})
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
defer c.Close()
42+
})
43+
3144
t.Run("meter provider", func(t *testing.T) {
3245
mr := metric.NewManualReader()
3346
meterProvider := metric.NewMeterProvider(metric.WithReader(mr))
3447
c, err := NewClient(
3548
rueidis.ClientOption{
3649
InitAddress: []string{"127.0.0.1:6379"},
37-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
38-
return dialer.Dial("tcp", dst)
50+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
51+
return dialer.DialContext(ctx, "tcp", dst)
3952
},
4053
},
4154
WithMeterProvider(meterProvider),
@@ -50,8 +63,8 @@ func TestNewClient(t *testing.T) {
5063
c, err := NewClient(
5164
rueidis.ClientOption{
5265
InitAddress: []string{"127.0.0.1:6379"},
53-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
54-
return dialer.Dial("tcp", dst)
66+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
67+
return dialer.DialContext(ctx, "tcp", dst)
5568
},
5669
},
5770
WithHistogramOption(HistogramOption{
@@ -79,8 +92,8 @@ func TestNewClientError(t *testing.T) {
7992
t.Run("invalid client option", func(t *testing.T) {
8093
_, err := NewClient(rueidis.ClientOption{
8194
InitAddress: []string{""},
82-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
83-
return dialer.Dial("tcp", dst)
95+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
96+
return dialer.DialContext(ctx, "tcp", dst)
8497
},
8598
})
8699
if err == nil {
@@ -120,8 +133,8 @@ func TestTrackDialing(t *testing.T) {
120133
c, err := NewClient(
121134
rueidis.ClientOption{
122135
InitAddress: []string{"127.0.0.1:6379"},
123-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
124-
return dialer.Dial("tcp", dst)
136+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
137+
return dialer.DialContext(ctx, "tcp", dst)
125138
},
126139
},
127140
WithMeterProvider(meterProvider),
@@ -169,8 +182,8 @@ func TestTrackDialing(t *testing.T) {
169182
c, err := NewClient(
170183
rueidis.ClientOption{
171184
InitAddress: []string{"127.0.0.1:6379"},
172-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
173-
return dialer.Dial("tcp", dst)
185+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
186+
return dialer.DialContext(ctx, "tcp", dst)
174187
},
175188
},
176189
WithMeterProvider(meterProvider),
@@ -198,8 +211,8 @@ func TestTrackDialing(t *testing.T) {
198211
_, err := NewClient(
199212
rueidis.ClientOption{
200213
InitAddress: []string{""},
201-
DialFn: func(dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
202-
return dialer.Dial("tcp", dst)
214+
DialCtxFn: func(ctx context.Context, dst string, dialer *net.Dialer, _ *tls.Config) (conn net.Conn, err error) {
215+
return dialer.DialContext(ctx, "tcp", dst)
203216
},
204217
},
205218
WithMeterProvider(meterProvider),

url.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package rueidis
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"net"
@@ -37,8 +38,8 @@ func ParseURL(str string) (opt ClientOption, err error) {
3738
}
3839
switch u.Scheme {
3940
case "unix":
40-
opt.DialFn = func(s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) {
41-
return dialer.Dial("unix", s)
41+
opt.DialCtxFn = func(ctx context.Context, s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) {
42+
return dialer.DialContext(ctx, "unix", s)
4243
}
4344
opt.InitAddress = []string{strings.TrimSpace(u.Path)}
4445
case "rediss", "valkeys":

url_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package rueidis
22

33
import (
4+
"context"
45
"strings"
56
"testing"
67
)
@@ -18,7 +19,7 @@ func TestParseURL(t *testing.T) {
1819
if opt, err := ParseURL("valkeys://"); err != nil || opt.TLSConfig == nil {
1920
t.Fatalf("unexpected %v %v", opt, err)
2021
}
21-
if opt, err := ParseURL("unix://"); err != nil || opt.DialFn == nil {
22+
if opt, err := ParseURL("unix://"); err != nil || opt.DialCtxFn == nil {
2223
t.Fatalf("unexpected %v %v", opt, err)
2324
}
2425
if opt, err := ParseURL("valkey://"); err != nil {
@@ -84,7 +85,7 @@ func TestParseURL(t *testing.T) {
8485
if opt, err := ParseURL("rediss://myhost:6379"); err != nil || opt.TLSConfig.ServerName != "myhost" {
8586
t.Fatalf("unexpected %v %v", opt, err)
8687
}
87-
if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 {
88+
if opt, err := ParseURL("unix:///path/to/redis.sock?db=1"); opt.DialCtxFn == nil || opt.InitAddress[0] != "/path/to/redis.sock" || opt.SelectDB != 1 {
8889
t.Fatalf("unexpected %v %v", opt, err)
8990
}
9091
}
@@ -100,7 +101,7 @@ func TestMustParseURL(t *testing.T) {
100101

101102
func TestMustParseURLUnix(t *testing.T) {
102103
opt := MustParseURL("unix://")
103-
if conn, err := opt.DialFn("", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") {
104+
if conn, err := opt.DialCtxFn(context.Background(), "", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") {
104105
t.Fatalf("unexpected %v %v", conn, err) // the error should be "dial unix: missing address"
105106
}
106107
}

0 commit comments

Comments
 (0)