Skip to content

wip: add refreshable *tls.Config param support #761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions conjure-go-client/httpclient/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package httpclient

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -70,10 +69,12 @@ type clientBuilder struct {
}

type httpClientBuilder struct {
ServiceName refreshable.String
Timeout refreshable.Duration
DialerParams refreshingclient.RefreshableDialerParams
TLSConfig *tls.Config // If unset, config in TransportParams will be used.
ServiceName refreshable.String
Timeout refreshable.Duration
DialerParams refreshingclient.RefreshableDialerParams
// TLSConfig supplies the *tls.Config for the underlying transport to use.
// If unset, config in TransportParams will be used.
TLSConfig refreshingclient.RefreshableTLSConfig
TransportParams refreshingclient.RefreshableTransportParams
Middlewares []Middleware

Expand All @@ -97,11 +98,11 @@ func (b *httpClientBuilder) Build(ctx context.Context, params ...HTTPClientParam
}
}

var tlsProvider refreshingclient.TLSProvider
var tlsProvider refreshingclient.RefreshableTLSConfig
if b.TLSConfig != nil {
tlsProvider = refreshingclient.NewStaticTLSConfigProvider(b.TLSConfig)
tlsProvider = b.TLSConfig
} else {
refreshableProvider, err := refreshingclient.NewRefreshableTLSConfig(ctx, b.TransportParams.TLS())
refreshableProvider, err := refreshingclient.NewRefreshableTLSConfigFromParams(ctx, b.TransportParams.TLS())
if err != nil {
return nil, err
}
Expand Down
21 changes: 19 additions & 2 deletions conjure-go-client/httpclient/client_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,36 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam {
if conf == nil {
b.TLSConfig = nil
} else {
b.TLSConfig = conf.Clone()
b.TLSConfig = refreshingclient.NewStaticTLSConfigProvider(conf.Clone())
}
return nil
})
}

// WithRefreshableTLSConfig sets the SSL/TLS configuration for the HTTP client's Transport.
// Clients can update the TLS configuration of the underlying HTTP Transport using the 'updateFn'.
// This function does not accept a refreshable because 'reflect.DeepEqual' (which underpins refreshables)
// does not work for structs with functional fields, which *tls.Config uses extensively.
// The palantir/pkg/tlsconfig package is recommended to build a tls.Config from sane defaults.
func WithRefreshableTLSConfig(conf *tls.Config) (param ClientOrHTTPClientParam, updateFn func(*tls.Config)) {
m := refreshingclient.NewMappedRefreshableTLSConfig(conf)
return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error {
b.TLSConfig = m
return nil
}), m.Update
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this, the clients would be responsible for calling updateFn whenever they want to refresh the underlying *tls.Config of their clients.

}

// WithTLSInsecureSkipVerify sets the InsecureSkipVerify field for the HTTP client's tls config.
// This option should only be used in clients that have way to establish trust with servers.
// If WithTLSConfig is used, the config's InsecureSkipVerify is set to true.
func WithTLSInsecureSkipVerify() ClientOrHTTPClientParam {
return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error {
if b.TLSConfig != nil {
b.TLSConfig.InsecureSkipVerify = true
b.TLSConfig = refreshingclient.ConfigureTLSConfig(b.TLSConfig, func(conf *tls.Config) *tls.Config {
conf = conf.Clone()
conf.InsecureSkipVerify = true
return conf
})
}
b.TransportParams = refreshingclient.ConfigureTransport(b.TransportParams, func(p refreshingclient.TransportParams) refreshingclient.TransportParams {
p.TLS.InsecureSkipVerify = true
Expand Down
2 changes: 1 addition & 1 deletion conjure-go-client/httpclient/client_params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func unwrapTransport(rt http.RoundTripper) (*http.Transport, []Middleware) {
for {
switch v := unwrapped.(type) {
case *refreshingclient.RefreshableTransport:
unwrapped = v.Current().(http.RoundTripper)
unwrapped = v.Load()
case *wrappedClient:
unwrapped = v.baseTransport
middlewares = append(middlewares, v.middleware)
Expand Down
100 changes: 94 additions & 6 deletions conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,87 @@ package refreshingclient
import (
"context"
"crypto/tls"
"sync"
"sync/atomic"

"github.com/palantir/pkg/refreshable"
"github.com/palantir/pkg/tlsconfig"
werror "github.com/palantir/witchcraft-go-error"
"github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log"
)

type RefreshableTLSConfig interface {
GetTLSConfig(ctx context.Context) *tls.Config
SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func())
}

var _ RefreshableTLSConfig = (*MappedRefreshableTLSConfig)(nil)

func ConfigureTLSConfig(r RefreshableTLSConfig, mapFn func(conf *tls.Config) *tls.Config) RefreshableTLSConfig {
var m MappedRefreshableTLSConfig
r.SubscribeToTLSConfig(func(c *tls.Config) {
m.Update(mapFn(c))
})
return &m
}

type MappedRefreshableTLSConfig struct {
conf atomic.Pointer[tls.Config]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we reimplementing refreshable?


mu sync.Mutex // protects subscribers
subscribers []*func(*tls.Config)
}

// NewMappedRefreshableTLSConfig returns a new *MappedRefreshableTLSConfig.
func NewMappedRefreshableTLSConfig(conf *tls.Config) *MappedRefreshableTLSConfig {
var m MappedRefreshableTLSConfig
m.conf.Store(conf)
return &m
}

// GetTLSConfig implements RefreshableTLSConf.
func (m *MappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
return m.conf.Load()
}

func (m *MappedRefreshableTLSConfig) Update(conf *tls.Config) {
m.conf.Store(conf)

m.mu.Lock()
defer m.mu.Unlock()
for _, sub := range m.subscribers {
(*sub)(conf)
}
}

// SubscribeToTLSConfig implements RefreshableTLSConf.
func (m *MappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
m.mu.Lock()
defer m.mu.Unlock()

consumerFnPtr := &consumer
m.subscribers = append(m.subscribers, consumerFnPtr)
return func() {
m.unsubscribe(consumerFnPtr)
}
}

func (m *MappedRefreshableTLSConfig) unsubscribe(consumerFnPtr *func(*tls.Config)) {
m.mu.Lock()
defer m.mu.Unlock()

matchIdx := -1
for idx, currSub := range m.subscribers {
if currSub == consumerFnPtr {
matchIdx = idx
break
}
}
if matchIdx != -1 {
m.subscribers = append(m.subscribers[:matchIdx], m.subscribers[matchIdx+1:]...)
}
}

// TLSParams contains the parameters needed to build a *tls.Config.
// Its fields must all be compatible with reflect.DeepEqual.
type TLSParams struct {
Expand All @@ -37,6 +111,8 @@ type TLSProvider interface {
GetTLSConfig(ctx context.Context) *tls.Config
}

var _ RefreshableTLSConfig = (*StaticTLSConfigProvider)(nil)

// StaticTLSConfigProvider is a TLSProvider that always returns the same *tls.Config.
type StaticTLSConfigProvider tls.Config

Expand All @@ -48,37 +124,49 @@ func (p *StaticTLSConfigProvider) GetTLSConfig(context.Context) *tls.Config {
return (*tls.Config)(p)
}

type RefreshableTLSConfig struct {
// SubscribeToTLSConfig implements RefreshableTLSConf.
func (p *StaticTLSConfigProvider) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
return nil
}

type WrappedRefreshableTLSConfig struct {
r *refreshable.ValidatingRefreshable // contains *tls.Config
}

// NewRefreshableTLSConfig evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the
// NewRefreshableTLSConfigFromParams evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the
// underlying *tls.Config when the TLSParams change.
// IF the initial TLSParams are invalid, NewRefreshableTLSConfig will return an error.
// IF the initial TLSParams are invalid, NewRefreshableTLSConfigFromParams will return an error.
// If the updated TLSParams are invalid, the RefreshableTLSConfig will continue to use the previous value and log the error.
//
// N.B. This subscription only fires when the paths are updated, not when the contents of the files are updated.
// We could consider adding a file refreshable to watch the key and cert files.
func NewRefreshableTLSConfig(ctx context.Context, params RefreshableTLSParams) (TLSProvider, error) {
func NewRefreshableTLSConfigFromParams(ctx context.Context, params RefreshableTLSParams) (RefreshableTLSConfig, error) {
r, err := refreshable.NewMapValidatingRefreshable(params, func(i interface{}) (interface{}, error) {
return NewTLSConfig(ctx, i.(TLSParams))
})
if err != nil {
return nil, werror.WrapWithContextParams(ctx, err, "failed to build RefreshableTLSConfig")
}
return RefreshableTLSConfig{r: r}, nil
return WrappedRefreshableTLSConfig{r: r}, nil
}

// GetTLSConfig returns the most recent valid *tls.Config.
// If the last refreshable update resulted in an error, that error is logged and
// the previous value is returned.
func (r RefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
func (r WrappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
if err := r.r.LastValidateErr(); err != nil {
svc1log.FromContext(ctx).Warn("Invalid TLS config. Using previous value.", svc1log.Stacktrace(err))
}
return r.r.Current().(*tls.Config)
}

// SubscribeToTLSConfig implements RefreshableTLSConf.
func (r WrappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
return r.r.Subscribe(func(i interface{}) {
consumer(i.(*tls.Config))
})
}

// NewTLSConfig returns a *tls.Config built from the provided TLSParams.
func NewTLSConfig(ctx context.Context, p TLSParams) (*tls.Config, error) {
var tlsParams []tlsconfig.ClientParam
Expand Down
34 changes: 21 additions & 13 deletions conjure-go-client/httpclient/internal/refreshingclient/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ package refreshingclient

import (
"context"
"crypto/tls"
"net/http"
"net/url"
"sync/atomic"
"time"

"github.com/palantir/pkg/refreshable"
"github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log"
"golang.org/x/net/http2"
)
Expand All @@ -42,12 +43,21 @@ type TransportParams struct {
TLS TLSParams
}

func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, tlsProvider TLSProvider, dialer ContextDialer) http.RoundTripper {
return &RefreshableTransport{
Refreshable: p.MapTransportParams(func(p TransportParams) interface{} {
return newTransport(ctx, p, tlsProvider, dialer)
}),
}
func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, t RefreshableTLSConfig, dialer ContextDialer) http.RoundTripper {
var refreshingTransport RefreshableTransport

// initialize the transport the first time.
refreshingTransport.Update(ctx, p.CurrentTransportParams(), t.GetTLSConfig(ctx), dialer)

// also subscribe to updates on transport params and the tls provider.
p.SubscribeToTransportParams(func(tp TransportParams) {
refreshingTransport.Update(ctx, tp, t.GetTLSConfig(ctx), dialer)
})
t.SubscribeToTLSConfig(func(conf *tls.Config) {
refreshingTransport.Update(ctx, p.CurrentTransportParams(), conf, dialer)
})

return &refreshingTransport
}

// ConfigureTransport accepts a mapping function which will be applied to the params value as it is evaluated.
Expand All @@ -61,14 +71,14 @@ func ConfigureTransport(r RefreshableTransportParams, mapFn func(p TransportPara
// RefreshableTransport implements http.RoundTripper backed by a refreshable *http.Transport.
// The transport and internal dialer are each rebuilt when any of their respective parameters are updated.
type RefreshableTransport struct {
refreshable.Refreshable // contains *http.Transport
atomic.Pointer[http.Transport]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this?

}

func (r *RefreshableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return r.Current().(*http.Transport).RoundTrip(req)
return r.Load().RoundTrip(req)
}

func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvider, dialer ContextDialer) *http.Transport {
func (r *RefreshableTransport) Update(ctx context.Context, p TransportParams, tlsConfig *tls.Config, dialer ContextDialer) {
svc1log.FromContext(ctx).Debug("Reconstructing HTTP Transport")

var transportProxy func(*http.Request) (*url.URL, error)
Expand All @@ -78,7 +88,6 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide
transportProxy = http.ProxyFromEnvironment
}

tlsConfig := tlsProvider.GetTLSConfig(ctx)
transport := &http.Transport{
Proxy: transportProxy,
DialContext: dialer.DialContext,
Expand Down Expand Up @@ -115,6 +124,5 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide
http2Transport.PingTimeout = p.HTTP2PingTimeout
}
}

return transport
r.Store(transport)
}