Skip to content

Commit

Permalink
Add CheckRedirect callback (#269)
Browse files Browse the repository at this point in the history
This commit adds a CheckRedirect callback that opamp-go will call before
following a redirect from the server it's trying to connect to. Like in
net/http, CheckRedirect can be used to observe the request chain that
the client is taking while attempting to make a connection.

The user can optionally terminate redirect following by returning an
error from CheckRedirect.

Unlike in net/http, the via parameter for CheckRedirect is a slice of
responses. Since the user would have no other way to access these in the
context of opamp-go, CheckRedirect makes them available so that users
can know exactly what status codes and headers are set in the response.

Another small improvement is that the error callback is no longer called
when redirecting. This should help to prevent undue error logging by
opamp-go consumers. Since the CheckRedirect callback is now available,
it also doesn't represent any loss in functionality to opamp-go
consumers.
  • Loading branch information
echlebek authored Jan 6, 2025
1 parent bfdb952 commit f75c3c9
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 18 deletions.
88 changes: 88 additions & 0 deletions client/httpclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package client
import (
"compress/gzip"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"

"github.com/open-telemetry/opamp-go/client/internal"
Expand Down Expand Up @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) {
// Shutdown the Server.
srv.Close()
}

func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock {
m := &checkRedirectMock{
t: t,
viaLen: viaLen,
http: true,
}
m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err)
return m
}

func TestRedirectHTTP(t *testing.T) {
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
MockRedirect *checkRedirectMock
}{
{
Name: "simple redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
},
{
Name: "check redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirectHTTP(t, 1, nil),
},
{
Name: "check redirect returns error",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")),
ExpError: true,
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var connectErr atomic.Value
var connected atomic.Value

settings := &types.StartSettings{
Callbacks: types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
},
}
if test.MockRedirect != nil {
settings.Callbacks = types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
CheckRedirect: test.MockRedirect.CheckRedirect,
}
}
reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil
settings.OpAMPServerURL = reURL.String()
client := NewHTTP(nil)
prepareClient(t, settings, client)

err := client.Start(context.Background(), *settings)
if err != nil {
t.Fatal(err)
}
defer client.Stop(context.Background())
// Wait for connection to be established.
eventually(t, func() bool {
return connected.Load() != nil || connectErr.Load() != nil
})
if test.ExpError && connectErr.Load() == nil {
t.Error("expected non-nil error")
} else if err := connectErr.Load(); !test.ExpError && err != nil {
t.Fatal(err)
}
})
}
}
8 changes: 8 additions & 0 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func (h *HTTPSender) Run(
h.callbacks = callbacks
h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex)

// we need to detect if the redirect was ever set, if not, we want default behaviour
if callbacks.CheckRedirect != nil {
h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// viaResp only non-nil for ws client
return callbacks.CheckRedirect(req, via, nil)
}
}

for {
pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs)))
select {
Expand Down
14 changes: 14 additions & 0 deletions client/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types

import (
"context"
"net/http"

"github.com/open-telemetry/opamp-go/protobufs"
)
Expand Down Expand Up @@ -116,6 +117,19 @@ type Callbacks struct {

// OnCommand is called when the Server requests that the connected Agent perform a command.
OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error

// CheckRedirect is called before following a redirect, allowing the client
// the opportunity to observe the redirect chain, and optionally terminate
// following redirects early.
//
// CheckRedirect is intended to be similar, although not exactly equivalent,
// to net/http.Client's CheckRedirect feature. Unlike in net/http, the via
// parameter is a slice of HTTP responses, instead of requests. This gives
// an opportunity to users to know what the exact response headers and
// status were. The request itself can be obtained from the response.
//
// The responses in the via parameter are passed with their bodies closed.
CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error
}

func (c *Callbacks) SetDefaults() {
Expand Down
88 changes: 74 additions & 14 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type wsClient struct {
// Network connection timeout used for the WebSocket closing handshake.
// This field is currently only modified during testing.
connShutdownTimeout time.Duration

// responseChain is used for the "via" argument in CheckRedirect.
// It is appended to with every redirect followed, and zeroed on a succesful
// connection. responseChain should only be referred to by the goroutine that
// runs tryConnectOnce and its synchronous callees.
responseChain []*http.Response
}

// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
Expand Down Expand Up @@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
return c.common.SendCustomMessage(message)
}

func viaReq(resps []*http.Response) []*http.Request {
reqs := make([]*http.Request, 0, len(resps))
for _, resp := range resps {
reqs = append(reqs, resp.Request)
}
return reqs
}

// handleRedirect checks a failed websocket upgrade response for a 3xx response
// and a Location header. If found, it sets the URL to the location found in the
// header so that it is tried on the next retry, instead of the current URL.
func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error {
// append to the responseChain so that subsequent redirects will have access
c.responseChain = append(c.responseChain, resp)

// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
return err
}

// It's slightly tricky to make CheckRedirect work. The WS HTTP request is
// formed within the websocket library. To work around that, copy the
// previous request, available in the response, and set the URL to the new
// location. It should then result in the same URL that the websocket
// library will form.
nextRequest := resp.Request.Clone(ctx)
nextRequest.URL = redirect

// if CheckRedirect results in an error, it gets returned, terminating
// redirection. As with stdlib, the error is wrapped in url.Error.
if c.common.Callbacks.CheckRedirect != nil {
if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil {
return &url.Error{
Op: "Get",
URL: nextRequest.URL.String(),
Err: err,
}
}
}

// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)

// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect

return nil
}

// Try to connect once. Returns an error if connection fails and optional retryAfter
// duration to indicate to the caller to retry after the specified time as instructed
// by the Server.
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
var resp *http.Response
var redirecting bool
defer func() {
if err != nil && !redirecting {
c.responseChain = nil
if !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
}
}
}()
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
if err != nil {
if !c.common.IsStopping() {
Expand All @@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna
if resp != nil {
duration := sharedinternal.ExtractRetryAfterHeader(resp)
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
redirecting = true
if err := c.handleRedirect(ctx, resp); err != nil {
return duration, err
}
// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect
} else {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
}
Expand Down
Loading

0 comments on commit f75c3c9

Please sign in to comment.