diff --git a/transport/client.go b/transport/client.go index b6f1d8c8..7170fe86 100644 --- a/transport/client.go +++ b/transport/client.go @@ -41,6 +41,13 @@ type Client struct { mutex sync.Mutex } +type ClientOptions struct { + Network string + Host string + DefaultPort int + Logger *logp.Logger +} + type Config struct { Proxy *ProxyConfig TLS *tlscommon.TLSConfig @@ -48,10 +55,20 @@ type Config struct { Stats IOStatser } +// Deprecated: use NewClientWithOptions func NewClient(c Config, network, host string, defaultPort int) (*Client, error) { + return NewClientWithOptions(c, ClientOptions{ + Network: network, + Host: host, + DefaultPort: defaultPort, + Logger: logp.NewLogger(""), + }) +} + +func NewClientWithOptions(c Config, opts ClientOptions) (*Client, error) { // do some sanity checks regarding network and Config matching + // address being parseable - switch network { + switch opts.Network { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": if c.TLS == nil && c.Proxy == nil { @@ -59,18 +76,22 @@ func NewClient(c Config, network, host string, defaultPort int) (*Client, error) } fallthrough default: - return nil, fmt.Errorf("unsupported network type %v", network) + return nil, fmt.Errorf("unsupported network type %v", opts.Network) + } + + if opts.Logger == nil { + opts.Logger = logp.NewNopLogger() } - dialer, err := MakeDialer(c) + dialer, err := MakeDialer(c, opts.Logger) if err != nil { return nil, err } - return NewClientWithDialer(dialer, c, network, host, defaultPort) + return NewClientWithDialer(dialer, c, opts.Network, opts.Host, opts.DefaultPort, opts.Logger) } -func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort int) (*Client, error) { +func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort int, logger *logp.Logger) (*Client, error) { // check address being parseable host = fullAddress(host, defaultPort) _, _, err := net.SplitHostPort(host) @@ -79,7 +100,7 @@ func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort i } client := &Client{ - log: logp.NewLogger(logSelector), + log: logger.Named(logSelector), dialer: d, network: network, host: host, @@ -231,7 +252,7 @@ func (c *Client) Test(d testing.Driver) { } else { d.Run("TLS", func(d testing.Driver) { netDialer := NetDialer(c.config.Timeout) - tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout) + tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout, c.log) _, err := tlsDialer.DialContext(context.Background(), "tcp", c.host) d.Fatal("dial up", err) }) diff --git a/transport/httpcommon/httpcommon.go b/transport/httpcommon/httpcommon.go index 3ef5498b..c703e962 100644 --- a/transport/httpcommon/httpcommon.go +++ b/transport/httpcommon/httpcommon.go @@ -185,8 +185,9 @@ func (settings *HTTPTransportSettings) Unpack(cfg *config.C) error { return err } - // TODO: use local logger here - _, err := tlscommon.LoadTLSConfig(tmp.TLS, logp.NewLogger("")) + // we use no-op logger here because we are only testing for errors + // if any while loading ssl config + _, err := tlscommon.LoadTLSConfig(tmp.TLS, logp.NewNopLogger()) if err != nil { return err } @@ -214,9 +215,8 @@ func (settings *HTTPTransportSettings) RoundTripper(opts ...TransportOption) (ht } } - logger := logp.NewLogger("") - if log := extra.logger; log != nil { - logger = log + if extra.logger == nil { + extra.logger = logp.NewLogger("") } for _, opt := range opts { @@ -229,12 +229,12 @@ func (settings *HTTPTransportSettings) RoundTripper(opts ...TransportOption) (ht dialer = transport.NetDialer(settings.Timeout) } - tls, err := tlscommon.LoadTLSConfig(settings.TLS, logger) + tls, err := tlscommon.LoadTLSConfig(settings.TLS, extra.logger) if err != nil { return nil, err } - tlsDialer := transport.TLSDialer(dialer, tls, settings.Timeout) + tlsDialer := transport.TLSDialer(dialer, tls, settings.Timeout, extra.logger) for _, opt := range opts { if dialOpt, ok := opt.(dialerModOption); ok { dialer = dialOpt.applyDialer(settings, dialer) diff --git a/transport/tls.go b/transport/tls.go index 24b3311f..d56c273e 100644 --- a/transport/tls.go +++ b/transport/tls.go @@ -26,12 +26,13 @@ import ( "sync" "time" + "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/testing" "github.com/elastic/elastic-agent-libs/transport/tlscommon" ) -func TLSDialer(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration) Dialer { - return TestTLSDialer(testing.NullDriver, forward, config, timeout) +func TLSDialer(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration, logger *logp.Logger) Dialer { + return TestTLSDialer(testing.NullDriver, forward, config, timeout, logger) } func TestTLSDialer( @@ -39,6 +40,7 @@ func TestTLSDialer( forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration, + logger *logp.Logger, ) Dialer { var lastTLSConfig *tls.Config var lastNetwork string @@ -63,7 +65,12 @@ func TestTLSDialer( tlsConfig = lastTLSConfig } if tlsConfig == nil { - tlsConfig = config.BuildModuleClientConfig(host) + // if tlsconfig is nil, set provided logger + if config == nil { + tlsConfig = config.BuildModuleClientConfig(host, tlscommon.WithLogger(logger)) + } else { + tlsConfig = config.BuildModuleClientConfig(host) + } lastNetwork = network lastAddress = address lastTLSConfig = tlsConfig diff --git a/transport/tlscommon/tls_config.go b/transport/tlscommon/tls_config.go index 647bc93c..12c1d942 100644 --- a/transport/tlscommon/tls_config.go +++ b/transport/tlscommon/tls_config.go @@ -94,6 +94,26 @@ var ( ErrMissingPeerCertificate = errors.New("missing peer certificates") ) +type tlsOptFunc func(t *TLSSettings) + +func (t tlsOptFunc) apply(c *TLSSettings) { + t(c) +} + +type TLSOption interface { + apply(t *TLSSettings) +} + +type TLSSettings struct { + logger *logp.Logger +} + +func WithLogger(logger *logp.Logger) TLSOption { + return tlsOptFunc(func(t *TLSSettings) { + t.logger = logger + }) +} + // ToConfig generates a tls.Config object. Note, you must use BuildModuleClientConfig to generate a config with // ServerName set, use that method for servers with SNI. // By default VerifyConnection is set to client mode. @@ -126,7 +146,16 @@ func (c *TLSConfig) ToConfig() *tls.Config { } // BuildModuleClientConfig takes the TLSConfig and transform it into a `tls.Config`. -func (c *TLSConfig) BuildModuleClientConfig(host string) *tls.Config { +func (c *TLSConfig) BuildModuleClientConfig(host string, options ...TLSOption) *tls.Config { + var settings TLSSettings + for _, opt := range options { + opt.apply(&settings) + } + + if settings.logger == nil { + settings.logger = logp.NewLogger("") + } + if c == nil { // use default TLS settings, if config is empty. return &tls.Config{ @@ -135,7 +164,7 @@ func (c *TLSConfig) BuildModuleClientConfig(host string) *tls.Config { VerifyConnection: makeVerifyConnection(&TLSConfig{ Verification: VerifyFull, ServerName: host, - }, logp.NewLogger("tls")), + }, settings.logger.Named("tls")), } } diff --git a/transport/transport.go b/transport/transport.go index 41037af4..f6a49497 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -45,17 +45,17 @@ func (d DialerFunc) DialContext(ctx context.Context, network, address string) (n } func DialContext(ctx context.Context, c Config, network, address string) (net.Conn, error) { - d, err := MakeDialer(c) + d, err := MakeDialer(c, logp.NewLogger("")) if err != nil { return nil, err } return d.DialContext(ctx, network, address) } -func MakeDialer(c Config) (Dialer, error) { +func MakeDialer(c Config, logger *logp.Logger) (Dialer, error) { var err error dialer := NetDialer(c.Timeout) - dialer, err = ProxyDialer(logp.NewLogger(logSelector), c.Proxy, dialer) + dialer, err = ProxyDialer(logger.Named(logSelector), c.Proxy, dialer) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func MakeDialer(c Config) (Dialer, error) { } if c.TLS != nil { - return TLSDialer(dialer, c.TLS, c.Timeout), nil + return TLSDialer(dialer, c.TLS, c.Timeout, logger), nil } return dialer, nil }