Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions transport/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,57 @@ type Client struct {
mutex sync.Mutex
}

type ClientOptions struct {
Network string
Host string
DefaultPort int
Logger *logp.Logger
}

type Config struct {
Proxy *ProxyConfig
TLS *tlscommon.TLSConfig
Timeout time.Duration
Stats IOStatser
}

// Deprecated: use NewClientWithOptions
func NewClient(c Config, network, host string, defaultPort int) (*Client, error) {
return NewClientWithOptions(c, ClientOptions{
Network: network,
Host: host,
DefaultPort: defaultPort,
Logger: logp.NewLogger(""),
})
}

func NewClientWithOptions(c Config, opts ClientOptions) (*Client, error) {
// do some sanity checks regarding network and Config matching +
// address being parseable
switch network {
switch opts.Network {
case "tcp", "tcp4", "tcp6":
case "udp", "udp4", "udp6":
if c.TLS == nil && c.Proxy == nil {
break
}
fallthrough
default:
return nil, fmt.Errorf("unsupported network type %v", network)
return nil, fmt.Errorf("unsupported network type %v", opts.Network)
}

if opts.Logger == nil {
opts.Logger = logp.NewNopLogger()
}

dialer, err := MakeDialer(c)
dialer, err := MakeDialer(c, opts.Logger)
if err != nil {
return nil, err
}

return NewClientWithDialer(dialer, c, network, host, defaultPort)
return NewClientWithDialer(dialer, c, opts.Network, opts.Host, opts.DefaultPort, opts.Logger)
}

func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort int) (*Client, error) {
func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort int, logger *logp.Logger) (*Client, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this makes it a breaking change 😅

// check address being parseable
host = fullAddress(host, defaultPort)
_, _, err := net.SplitHostPort(host)
Expand All @@ -79,7 +100,7 @@ func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort i
}

client := &Client{
log: logp.NewLogger(logSelector),
log: logger.Named(logSelector),
dialer: d,
network: network,
host: host,
Expand Down Expand Up @@ -231,7 +252,7 @@ func (c *Client) Test(d testing.Driver) {
} else {
d.Run("TLS", func(d testing.Driver) {
netDialer := NetDialer(c.config.Timeout)
tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout)
tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout, c.log)
_, err := tlsDialer.DialContext(context.Background(), "tcp", c.host)
d.Fatal("dial up", err)
})
Expand Down
14 changes: 7 additions & 7 deletions transport/httpcommon/httpcommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@
return err
}

// TODO: use local logger here
_, err := tlscommon.LoadTLSConfig(tmp.TLS, logp.NewLogger(""))
// we use no-op logger here because we are only testing for errors
// if any while loading ssl config
_, err := tlscommon.LoadTLSConfig(tmp.TLS, logp.NewNopLogger())
if err != nil {
return err
}
Expand Down Expand Up @@ -214,9 +215,8 @@
}
}

logger := logp.NewLogger("")
if log := extra.logger; log != nil {
logger = log
if extra.logger == nil {
extra.logger = logp.NewLogger("")
}

for _, opt := range opts {
Expand All @@ -229,12 +229,12 @@
dialer = transport.NetDialer(settings.Timeout)
}

tls, err := tlscommon.LoadTLSConfig(settings.TLS, logger)
tls, err := tlscommon.LoadTLSConfig(settings.TLS, extra.logger)
if err != nil {
return nil, err
}

tlsDialer := transport.TLSDialer(dialer, tls, settings.Timeout)
tlsDialer := transport.TLSDialer(dialer, tls, settings.Timeout, extra.logger)
for _, opt := range opts {
if dialOpt, ok := opt.(dialerModOption); ok {
dialer = dialOpt.applyDialer(settings, dialer)
Expand Down Expand Up @@ -270,7 +270,7 @@
dialer, tlsDialer transport.Dialer,
opts ...TransportOption,
) *http.Transport {
t := http.DefaultTransport.(*http.Transport).Clone()

Check failure on line 273 in transport/httpcommon/httpcommon.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

Error return value is not checked (errcheck)

Check failure on line 273 in transport/httpcommon/httpcommon.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

Error return value is not checked (errcheck)
t.DialContext = dialer.DialContext
t.DialTLSContext = tlsDialer.DialContext
t.TLSClientConfig = tls.ToConfig()
Expand Down
13 changes: 10 additions & 3 deletions transport/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,21 @@ import (
"sync"
"time"

"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/testing"
"github.com/elastic/elastic-agent-libs/transport/tlscommon"
)

func TLSDialer(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration) Dialer {
return TestTLSDialer(testing.NullDriver, forward, config, timeout)
func TLSDialer(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration, logger *logp.Logger) Dialer {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

return TestTLSDialer(testing.NullDriver, forward, config, timeout, logger)
}

func TestTLSDialer(
d testing.Driver,
forward Dialer,
config *tlscommon.TLSConfig,
timeout time.Duration,
logger *logp.Logger,
) Dialer {
var lastTLSConfig *tls.Config
var lastNetwork string
Expand All @@ -63,7 +65,12 @@ func TestTLSDialer(
tlsConfig = lastTLSConfig
}
if tlsConfig == nil {
tlsConfig = config.BuildModuleClientConfig(host)
// if tlsconfig is nil, set provided logger
if config == nil {
tlsConfig = config.BuildModuleClientConfig(host, tlscommon.WithLogger(logger))
} else {
tlsConfig = config.BuildModuleClientConfig(host)
}
lastNetwork = network
lastAddress = address
lastTLSConfig = tlsConfig
Expand Down
33 changes: 31 additions & 2 deletions transport/tlscommon/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ var (
ErrMissingPeerCertificate = errors.New("missing peer certificates")
)

type tlsOptFunc func(t *TLSSettings)

func (t tlsOptFunc) apply(c *TLSSettings) {
t(c)
}

type TLSOption interface {
apply(t *TLSSettings)
}

type TLSSettings struct {
logger *logp.Logger
}

func WithLogger(logger *logp.Logger) TLSOption {
return tlsOptFunc(func(t *TLSSettings) {
t.logger = logger
})
}

// ToConfig generates a tls.Config object. Note, you must use BuildModuleClientConfig to generate a config with
// ServerName set, use that method for servers with SNI.
// By default VerifyConnection is set to client mode.
Expand Down Expand Up @@ -126,7 +146,16 @@ func (c *TLSConfig) ToConfig() *tls.Config {
}

// BuildModuleClientConfig takes the TLSConfig and transform it into a `tls.Config`.
func (c *TLSConfig) BuildModuleClientConfig(host string) *tls.Config {
func (c *TLSConfig) BuildModuleClientConfig(host string, options ...TLSOption) *tls.Config {
var settings TLSSettings
for _, opt := range options {
opt.apply(&settings)
}

if settings.logger == nil {
settings.logger = logp.NewLogger("")
}

if c == nil {
// use default TLS settings, if config is empty.
return &tls.Config{
Expand All @@ -135,7 +164,7 @@ func (c *TLSConfig) BuildModuleClientConfig(host string) *tls.Config {
VerifyConnection: makeVerifyConnection(&TLSConfig{
Verification: VerifyFull,
ServerName: host,
}, logp.NewLogger("tls")),
}, settings.logger.Named("tls")),
}
}

Expand Down
8 changes: 4 additions & 4 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ func (d DialerFunc) DialContext(ctx context.Context, network, address string) (n
}

func DialContext(ctx context.Context, c Config, network, address string) (net.Conn, error) {
d, err := MakeDialer(c)
d, err := MakeDialer(c, logp.NewLogger(""))
if err != nil {
return nil, err
}
return d.DialContext(ctx, network, address)
}

func MakeDialer(c Config) (Dialer, error) {
func MakeDialer(c Config, logger *logp.Logger) (Dialer, error) {
var err error
dialer := NetDialer(c.Timeout)
dialer, err = ProxyDialer(logp.NewLogger(logSelector), c.Proxy, dialer)
dialer, err = ProxyDialer(logger.Named(logSelector), c.Proxy, dialer)
if err != nil {
return nil, err
}
Expand All @@ -64,7 +64,7 @@ func MakeDialer(c Config) (Dialer, error) {
}

if c.TLS != nil {
return TLSDialer(dialer, c.TLS, c.Timeout), nil
return TLSDialer(dialer, c.TLS, c.Timeout, logger), nil
}
return dialer, nil
}
Loading