From f75c3c9cdfd895b0d84ef22b521fe9a8369a736e Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Mon, 6 Jan 2025 07:25:55 -0800 Subject: [PATCH] Add CheckRedirect callback (#269) This commit adds a CheckRedirect callback that opamp-go will call before following a redirect from the server it's trying to connect to. Like in net/http, CheckRedirect can be used to observe the request chain that the client is taking while attempting to make a connection. The user can optionally terminate redirect following by returning an error from CheckRedirect. Unlike in net/http, the via parameter for CheckRedirect is a slice of responses. Since the user would have no other way to access these in the context of opamp-go, CheckRedirect makes them available so that users can know exactly what status codes and headers are set in the response. Another small improvement is that the error callback is no longer called when redirecting. This should help to prevent undue error logging by opamp-go consumers. Since the CheckRedirect callback is now available, it also doesn't represent any loss in functionality to opamp-go consumers. --- client/httpclient_test.go | 88 ++++++++++++++++++++++++ client/internal/httpsender.go | 8 +++ client/types/callbacks.go | 14 ++++ client/wsclient.go | 88 ++++++++++++++++++++---- client/wsclient_test.go | 123 +++++++++++++++++++++++++++++++++- go.mod | 3 +- go.sum | 2 + 7 files changed, 308 insertions(+), 18 deletions(-) diff --git a/client/httpclient_test.go b/client/httpclient_test.go index a3845c45..fc670411 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -3,13 +3,17 @@ package client import ( "compress/gzip" "context" + "errors" "io" "net/http" + "net/http/httptest" + "net/url" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) { // Shutdown the Server. srv.Close() } + +func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + http: true, + } + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) + return m +} + +func TestRedirectHTTP(t *testing.T) { + redirectee := internal.StartMockServer(t) + tests := []struct { + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock + }{ + { + Name: "simple redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + }, + { + Name: "check redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirectHTTP(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")), + ExpError: true, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + var connectErr atomic.Value + var connected atomic.Value + + settings := &types.StartSettings{ + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + }, + } + if test.MockRedirect != nil { + settings.Callbacks = types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + CheckRedirect: test.MockRedirect.CheckRedirect, + } + } + reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil + settings.OpAMPServerURL = reURL.String() + client := NewHTTP(nil) + prepareClient(t, settings, client) + + err := client.Start(context.Background(), *settings) + if err != nil { + t.Fatal(err) + } + defer client.Stop(context.Background()) + // Wait for connection to be established. + eventually(t, func() bool { + return connected.Load() != nil || connectErr.Load() != nil + }) + if test.ExpError && connectErr.Load() == nil { + t.Error("expected non-nil error") + } else if err := connectErr.Load(); !test.ExpError && err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/client/internal/httpsender.go b/client/internal/httpsender.go index 502bf7e4..a97e1311 100644 --- a/client/internal/httpsender.go +++ b/client/internal/httpsender.go @@ -98,6 +98,14 @@ func (h *HTTPSender) Run( h.callbacks = callbacks h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex) + // we need to detect if the redirect was ever set, if not, we want default behaviour + if callbacks.CheckRedirect != nil { + h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // viaResp only non-nil for ws client + return callbacks.CheckRedirect(req, via, nil) + } + } + for { pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs))) select { diff --git a/client/types/callbacks.go b/client/types/callbacks.go index a5dc02ce..48d5f832 100644 --- a/client/types/callbacks.go +++ b/client/types/callbacks.go @@ -2,6 +2,7 @@ package types import ( "context" + "net/http" "github.com/open-telemetry/opamp-go/protobufs" ) @@ -116,6 +117,19 @@ type Callbacks struct { // OnCommand is called when the Server requests that the connected Agent perform a command. OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error + + // CheckRedirect is called before following a redirect, allowing the client + // the opportunity to observe the redirect chain, and optionally terminate + // following redirects early. + // + // CheckRedirect is intended to be similar, although not exactly equivalent, + // to net/http.Client's CheckRedirect feature. Unlike in net/http, the via + // parameter is a slice of HTTP responses, instead of requests. This gives + // an opportunity to users to know what the exact response headers and + // status were. The request itself can be obtained from the response. + // + // The responses in the via parameter are passed with their bodies closed. + CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error } func (c *Callbacks) SetDefaults() { diff --git a/client/wsclient.go b/client/wsclient.go index 6219a28f..f19d8ab4 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -48,6 +48,12 @@ type wsClient struct { // Network connection timeout used for the WebSocket closing handshake. // This field is currently only modified during testing. connShutdownTimeout time.Duration + + // responseChain is used for the "via" argument in CheckRedirect. + // It is appended to with every redirect followed, and zeroed on a succesful + // connection. responseChain should only be referred to by the goroutine that + // runs tryConnectOnce and its synchronous callees. + responseChain []*http.Response } // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. @@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS return c.common.SendCustomMessage(message) } +func viaReq(resps []*http.Response) []*http.Request { + reqs := make([]*http.Request, 0, len(resps)) + for _, resp := range resps { + reqs = append(reqs, resp.Request) + } + return reqs +} + +// handleRedirect checks a failed websocket upgrade response for a 3xx response +// and a Location header. If found, it sets the URL to the location found in the +// header so that it is tried on the next retry, instead of the current URL. +func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error { + // append to the responseChain so that subsequent redirects will have access + c.responseChain = append(c.responseChain, resp) + + // very liberal handling of 3xx that largely ignores HTTP semantics + redirect, err := resp.Location() + if err != nil { + c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + return err + } + + // It's slightly tricky to make CheckRedirect work. The WS HTTP request is + // formed within the websocket library. To work around that, copy the + // previous request, available in the response, and set the URL to the new + // location. It should then result in the same URL that the websocket + // library will form. + nextRequest := resp.Request.Clone(ctx) + nextRequest.URL = redirect + + // if CheckRedirect results in an error, it gets returned, terminating + // redirection. As with stdlib, the error is wrapped in url.Error. + if c.common.Callbacks.CheckRedirect != nil { + if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil { + return &url.Error{ + Op: "Get", + URL: nextRequest.URL.String(), + Err: err, + } + } + } + + // rewrite the scheme for the sake of tolerance + if redirect.Scheme == "http" { + redirect.Scheme = "ws" + } else if redirect.Scheme == "https" { + redirect.Scheme = "wss" + } + c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) + + // Set the URL to the redirect, so that it connects to it on the + // next cycle. + c.url = redirect + + return nil +} + // Try to connect once. Returns an error if connection fails and optional retryAfter // duration to indicate to the caller to retry after the specified time as instructed // by the Server. func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) { var resp *http.Response + var redirecting bool + defer func() { + if err != nil && !redirecting { + c.responseChain = nil + if !c.common.IsStopping() { + c.common.Callbacks.OnConnectFailed(ctx, err) + } + } + }() conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader()) if err != nil { if !c.common.IsStopping() { @@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna if resp != nil { duration := sharedinternal.ExtractRetryAfterHeader(resp) if resp.StatusCode >= 300 && resp.StatusCode < 400 { - // very liberal handling of 3xx that largely ignores HTTP semantics - redirect, err := resp.Location() - if err != nil { - c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + redirecting = true + if err := c.handleRedirect(ctx, resp); err != nil { return duration, err } - // rewrite the scheme for the sake of tolerance - if redirect.Scheme == "http" { - redirect.Scheme = "ws" - } else if redirect.Scheme == "https" { - redirect.Scheme = "wss" - } - c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) - // Set the URL to the redirect, so that it connects to it on the - // next cycle. - c.url = redirect } else { c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status) } diff --git a/client/wsclient_test.go b/client/wsclient_test.go index cc9fd87d..436ceb55 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -13,6 +14,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -322,12 +324,54 @@ func errServer() *httptest.Server { })) } +type checkRedirectMock struct { + mock.Mock + t testing.TB + viaLen int + http bool +} + +func (c *checkRedirectMock) CheckRedirect(req *http.Request, viaReq []*http.Request, via []*http.Response) error { + if req == nil { + c.t.Error("nil request in CheckRedirect") + return errors.New("nil request in CheckRedirect") + } + if len(viaReq) > c.viaLen { + c.t.Error("viaReq should be shorter than viaLen") + } + if !c.http { + // websocket transport + if len(via) > c.viaLen { + c.t.Error("via should be shorter than viaLen") + } + } + if !c.http && len(via) > 0 { + location, err := via[len(via)-1].Location() + if err != nil { + c.t.Error(err) + } + // the URL of the request should match the location header of the last response + assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response") + } + return c.Called(req, via).Error(0) +} + +func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + } + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) + return m +} + func TestRedirectWS(t *testing.T) { redirectee := internal.StartMockServer(t) tests := []struct { - Name string - Redirector *httptest.Server - ExpError bool + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock }{ { Name: "redirect ws scheme", @@ -342,6 +386,17 @@ func TestRedirectWS(t *testing.T) { Redirector: errServer(), ExpError: true, }, + { + Name: "check redirect", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, errors.New("hello")), + ExpError: true, + }, } for _, test := range tests { @@ -366,6 +421,9 @@ func TestRedirectWS(t *testing.T) { }, }, } + if test.MockRedirect != nil { + settings.Callbacks.CheckRedirect = test.MockRedirect.CheckRedirect + } reURL, err := url.Parse(test.Redirector.URL) assert.NoError(t, err) reURL.Scheme = "ws" @@ -388,10 +446,69 @@ func TestRedirectWS(t *testing.T) { // Stop the client. err = client.Stop(context.Background()) assert.NoError(t, err) + + if test.MockRedirect != nil { + test.MockRedirect.AssertCalled(t, "CheckRedirect", mock.Anything, mock.Anything) + } }) } } +func TestRedirectWSFollowChain(t *testing.T) { + // test that redirect following is recursive + redirectee := internal.StartMockServer(t) + middle := redirectServer("http://"+redirectee.Endpoint, 302) + middleURL, err := url.Parse(middle.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + redirector := redirectServer("http://"+middleURL.Host, 302) + + var conn atomic.Value + redirectee.OnWSConnect = func(c *websocket.Conn) { + conn.Store(c) + } + + // Start an OpAMP/WebSocket client. + var connected int64 + var connectErr atomic.Value + mr := mockRedirect(t, 2, nil) + settings := types.StartSettings{ + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { + atomic.StoreInt64(&connected, 1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + if err != websocket.ErrBadHandshake { + connectErr.Store(err) + } + }, + CheckRedirect: mr.CheckRedirect, + }, + } + reURL, err := url.Parse(redirector.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + reURL.Scheme = "ws" + settings.OpAMPServerURL = reURL.String() + client := NewWebSocket(nil) + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { + return conn.Load() != nil || connectErr.Load() != nil || client.lastInternalErr.Load() != nil + }) + + assert.True(t, connectErr.Load() == nil) + + // Stop the client. + err = client.Stop(context.Background()) + assert.NoError(t, err) +} + func TestHandlesStopBeforeStart(t *testing.T) { client := NewWebSocket(nil) require.Error(t, client.Stop(context.Background())) diff --git a/go.mod b/go.mod index 2742c8a9..4b9746d1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/cenkalti/backoff/v4 v4.3.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/stretchr/testify v1.10.0 @@ -12,8 +13,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/go-cmp v0.5.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3390120c..ea122d10 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=