diff --git a/conjure-go-client/httpclient/backoff_middleware.go b/conjure-go-client/httpclient/backoff_middleware.go new file mode 100644 index 00000000..44a72d79 --- /dev/null +++ b/conjure-go-client/httpclient/backoff_middleware.go @@ -0,0 +1,86 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpclient + +import ( + "net/http" + "net/url" + + "github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient/internal" + "github.com/palantir/pkg/retry" +) + +type backoffMiddleware struct { + retrier retry.Retrier + attemptedURIs map[string]struct{} + backoffFunc func() +} + +// NewBackoffMiddleware returns middleware that uses a supplied Retrier to backoff before making requests if the client +// has attempted to reach the URI before or has sent too many requests. +func NewBackoffMiddleware(retrier retry.Retrier) Middleware { + return &backoffMiddleware{ + retrier: retrier, + attemptedURIs: map[string]struct{}{}, + } +} + +func (b *backoffMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + b.backoffRequest(req) + resp, err := next.RoundTrip(req) + b.handleResponse(err) + return resp, err +} + +func (b *backoffMiddleware) backoffRequest(req *http.Request) { + baseURI := getBaseURI(req.URL) + defer func() { + b.attemptedURIs[baseURI] = struct{}{} + }() + // Use backoffFunc if backoff behavior was determined by previous response e.g. throttle on 429 + if b.backoffFunc != nil { + b.backoffFunc() + b.backoffFunc = nil + return + } + // Trigger retrier on first attempt so that future attempts have backoff + if len(b.attemptedURIs) == 0 { + b.retrier.Next() + } + // Trigger retrier for backoff if URI was attempted before + if _, performBackoff := b.attemptedURIs[baseURI]; performBackoff { + b.retrier.Next() + } +} + +func (b *backoffMiddleware) handleResponse(err error) { + errCode, _ := StatusCodeFromError(err) + switch errCode { + case internal.StatusCodeRetryOther, internal.StatusCodeRetryTemporaryRedirect: + b.retrier.Reset() + case internal.StatusCodeThrottle: + b.backoffFunc = func() { b.retrier.Next() } + } +} + +func getBaseURI(u *url.URL) string { + uCopy := url.URL{ + Scheme: u.Scheme, + Opaque: u.Opaque, + User: u.User, + Host: u.Host, + } + return uCopy.String() +} diff --git a/conjure-go-client/httpclient/body_handler.go b/conjure-go-client/httpclient/body_handler.go index d80bf013..9c8ad0c0 100644 --- a/conjure-go-client/httpclient/body_handler.go +++ b/conjure-go-client/httpclient/body_handler.go @@ -29,9 +29,6 @@ type bodyMiddleware struct { requestInput interface{} requestEncoder codecs.Encoder - // if rawOutput is true, the body of the response is not drained before returning -- it is the responsibility of the - // caller to read from and properly close the response body. - rawOutput bool responseOutput interface{} responseDecoder codecs.Decoder @@ -62,13 +59,12 @@ func (b *bodyMiddleware) setRequestBody(req *http.Request) (func(), error) { return cleanup, nil } - // Special case: if the requestInput is an io.ReadCloser and the requestEncoder is nil, - // use the provided input directly as the request body. - if bodyReadCloser, ok := b.requestInput.(io.ReadCloser); ok && b.requestEncoder == nil { - req.Body = bodyReadCloser - // Use the same heuristic as http.NewRequest to generate the "GetBody" function. - if newReq, err := http.NewRequest("", "", bodyReadCloser); err == nil { - req.GetBody = newReq.GetBody + // Special case: if the requestInput is a getBody function and the requestEncoder is nil, + // use the provided function to directly as the request body. + if getBody, ok := b.requestInput.(func() io.ReadCloser); ok && b.requestEncoder == nil { + req.Body = getBody() + req.GetBody = func() (io.ReadCloser, error) { + return getBody(), nil } return cleanup, nil } @@ -101,11 +97,6 @@ func (b *bodyMiddleware) setRequestBody(req *http.Request) (func(), error) { } func (b *bodyMiddleware) readResponse(resp *http.Response, respErr error) error { - // If rawOutput is true, return response directly without draining or closing body - if b.rawOutput && respErr == nil { - return nil - } - if respErr != nil { return respErr } diff --git a/conjure-go-client/httpclient/client.go b/conjure-go-client/httpclient/client.go index 136b9752..c5045927 100644 --- a/conjure-go-client/httpclient/client.go +++ b/conjure-go-client/httpclient/client.go @@ -18,14 +18,13 @@ import ( "context" "net/http" "net/url" - "strings" "github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient/internal" "github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient/internal/refreshingclient" "github.com/palantir/pkg/bytesbuffers" "github.com/palantir/pkg/refreshable" + "github.com/palantir/pkg/retry" werror "github.com/palantir/witchcraft-go-error" - "github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log" ) // A Client executes requests to a configured service. @@ -87,47 +86,59 @@ func (c *clientImpl) Do(ctx context.Context, params ...RequestParam) (*http.Resp return nil, werror.ErrorWithContextParams(ctx, "no base URIs are configured") } - attempts := 2 * len(uris) + maxAttempts := 2 * len(uris) if c.maxAttempts != nil { if confMaxAttempts := c.maxAttempts.CurrentIntPtr(); confMaxAttempts != nil { - attempts = *confMaxAttempts + maxAttempts = *confMaxAttempts } } - - var err error + b, err := applyRequestParams(c.bufferPool, params...) + if err != nil { + return nil, err + } + for _, c := range b.configureCtx { + ctx = c(ctx) + } + req, err := getRequest(ctx, b) + if err != nil { + return nil, err + } + cancelled := false + cancelFunc := func() { cancelled = true } + retrier := c.backoffOptions.CurrentRetryParams().Start(ctx) + clientCopy := c.getClientCopyWithMiddleware(b.errorDecoderMiddleware, b.bodyMiddleware, uris, retrier, cancelFunc) + attempts := 0 var resp *http.Response - - retrier := internal.NewRequestRetrier(uris, c.backoffOptions.CurrentRetryParams().Start(ctx), attempts) - for { - uri, isRelocated := retrier.GetNextURI(resp, err) - if uri == "" { - break + for !cancelled && (maxAttempts == 0 || attempts < maxAttempts) { + reqCopy := req.Clone(ctx) + resp, err = clientCopy.Do(reqCopy) + err = unwrapURLError(ctx, err) + // unless this is exactly the scenario where the caller has opted into being responsible for draining and closing + // the response body, be sure to do so here. + if !b.rawOutput { + internal.DrainBody(resp) } - if err != nil { - svc1log.FromContext(ctx).Debug("Retrying request", svc1log.Stacktrace(err)) + attempts++ + if resp != nil && isSuccessfulOrBadRequest(resp.StatusCode) { + break } - resp, err = c.doOnce(ctx, uri, isRelocated, params...) } if err != nil { return nil, err } - return resp, nil + return resp, err } -func (c *clientImpl) doOnce( - ctx context.Context, - baseURI string, - useBaseURIOnly bool, - params ...RequestParam, -) (*http.Response, error) { +func isSuccessfulOrBadRequest(statusCode int) bool { + return statusCode < 300 || (statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError) +} - // 1. create the request +func applyRequestParams(bufferPool bytesbuffers.Pool, params ...RequestParam) (*requestBuilder, error) { b := &requestBuilder{ headers: make(http.Header), query: make(url.Values), - bodyMiddleware: &bodyMiddleware{bufferPool: c.bufferPool}, + bodyMiddleware: &bodyMiddleware{bufferPool: bufferPool}, } - for _, p := range params { if p == nil { continue @@ -136,29 +147,25 @@ func (c *clientImpl) doOnce( return nil, err } } - if useBaseURIOnly { - b.path = "" - } - - for _, c := range b.configureCtx { - ctx = c(ctx) - } + return b, nil +} +func getRequest(ctx context.Context, b *requestBuilder) (*http.Request, error) { if b.method == "" { return nil, werror.ErrorWithContextParams(ctx, "httpclient: use WithRequestMethod() to specify HTTP method") } - reqURI := joinURIAndPath(baseURI, b.path) - req, err := http.NewRequest(b.method, reqURI, nil) + req, err := http.NewRequestWithContext(ctx, b.method, b.path, nil) if err != nil { return nil, werror.WrapWithContextParams(ctx, err, "failed to build new HTTP request") } - req = req.WithContext(ctx) req.Header = b.headers if q := b.query.Encode(); q != "" { req.URL.RawQuery = q } + return req, nil +} - // 2. create the transport and client +func (c *clientImpl) getClientCopyWithMiddleware(errorDecoderMiddleware Middleware, bodyMiddleware *bodyMiddleware, uris []string, backoffRetrier retry.Retrier, cancelFunc func()) http.Client { // shallow copy so we can overwrite the Transport with a wrapped one. clientCopy := *c.client.CurrentHTTPClient() transport := clientCopy.Transport // start with the client's transport configured with default middleware @@ -167,28 +174,22 @@ func (c *clientImpl) doOnce( transport = wrapTransport(transport, c.uriScorer.CurrentURIScoringMiddleware()) // request decoder must precede the client decoder // must precede the body middleware to read the response body - transport = wrapTransport(transport, b.errorDecoderMiddleware, c.errorDecoderMiddleware) + transport = wrapTransport(transport, errorDecoderMiddleware, c.errorDecoderMiddleware) // must precede the body middleware to read the request body transport = wrapTransport(transport, c.middlewares...) // must wrap inner middlewares to mutate the return values - transport = wrapTransport(transport, b.bodyMiddleware) + transport = wrapTransport(transport, bodyMiddleware) + // must precede URI middleware to track attempted URIs + transport = wrapTransport(transport, NewBackoffMiddleware(backoffRetrier)) + // must wrap inner middlewares to update request with resolved URL + transport = wrapTransport(transport, NewURIMiddleware(uris, cancelFunc)) // must be the outermost middleware to recover panics in the rest of the request flow // there is a second, inner recoveryMiddleware in the client's default middlewares so that panics // inside the inner-most RoundTrip benefit from traceIDs and loggers set on the context. transport = wrapTransport(transport, c.recoveryMiddleware) clientCopy.Transport = transport - - // 3. execute the request using the client to get and handle the response - resp, respErr := clientCopy.Do(req) - - // unless this is exactly the scenario where the caller has opted into being responsible for draining and closing - // the response body, be sure to do so here. - if !(respErr == nil && b.bodyMiddleware.rawOutput) { - internal.DrainBody(resp) - } - - return resp, unwrapURLError(ctx, respErr) + return clientCopy } // unwrapURLError converts a *url.Error to a werror. We need this because all @@ -205,21 +206,5 @@ func unwrapURLError(ctx context.Context, respErr error) error { // We don't recognize this as a url.Error, just return the original. return respErr } - params := []werror.Param{werror.SafeParam("requestMethod", urlErr.Op)} - - if parsedURL, _ := url.Parse(urlErr.URL); parsedURL != nil { - params = append(params, - werror.SafeParam("requestHost", parsedURL.Host), - werror.UnsafeParam("requestPath", parsedURL.Path)) - } - - return werror.WrapWithContextParams(ctx, urlErr.Err, "httpclient request failed", params...) -} - -func joinURIAndPath(baseURI, reqPath string) string { - fullURI := strings.TrimRight(baseURI, "/") - if reqPath != "" { - fullURI += "/" + strings.TrimLeft(reqPath, "/") - } - return fullURI + return urlErr.Err } diff --git a/conjure-go-client/httpclient/client_params.go b/conjure-go-client/httpclient/client_params.go index e0cef56f..2e1140ad 100644 --- a/conjure-go-client/httpclient/client_params.go +++ b/conjure-go-client/httpclient/client_params.go @@ -474,8 +474,7 @@ func WithInitialBackoff(initialBackoff time.Duration) ClientParam { }) } -// WithMaxRetries sets the maximum number of retries on transport errors for every request. Backoffs are -// also capped at this. +// WithMaxRetries sets the maximum number of retries on transport errors for every request. // If unset, the client defaults to 2 * size of URIs // TODO (#151): Rename to WithMaxAttempts and set maxAttempts directly using the argument provided to the function. func WithMaxRetries(maxTransportRetries int) ClientParam { diff --git a/conjure-go-client/httpclient/client_path_test.go b/conjure-go-client/httpclient/client_path_test.go deleted file mode 100644 index 744a8819..00000000 --- a/conjure-go-client/httpclient/client_path_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2019 Palantir Technologies. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httpclient - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestJoinURIandPath(t *testing.T) { - for _, test := range []struct { - baseURI string - reqPath string - expected string - }{ - { - "https://localhost", - "/api", - "https://localhost/api", - }, - { - "https://localhost:443", - "/api", - "https://localhost:443/api", - }, - { - "https://localhost:443", - "api", - "https://localhost:443/api", - }, - { - "https://localhost:443/", - "api", - "https://localhost:443/api", - }, - { - "https://localhost:443/foo/", - "/api", - "https://localhost:443/foo/api", - }, - { - "https://localhost:443/foo//////", - "////api/", - "https://localhost:443/foo/api/", - }, - { - "https://localhost:443/foo/", - "/api", - "https://localhost:443/foo/api", - }, - { - "https://localhost", - "", - "https://localhost", - }, - { - "https://localhost/api", - "", - "https://localhost/api", - }, - { - "https://localhost", - "/api/+ti%2FQojjmKJxpxmY%2FA=", - "https://localhost/api/+ti%2FQojjmKJxpxmY%2FA=", - }, - } { - t.Run("", func(t *testing.T) { - actual := joinURIAndPath(test.baseURI, test.reqPath) - assert.Equal(t, test.expected, actual) - }) - } -} diff --git a/conjure-go-client/httpclient/internal/request_retrier.go b/conjure-go-client/httpclient/internal/request_retrier.go index e7c738f7..37c5942d 100644 --- a/conjure-go-client/httpclient/internal/request_retrier.go +++ b/conjure-go-client/httpclient/internal/request_retrier.go @@ -76,13 +76,13 @@ func (r *RequestRetrier) GetNextURI(resp *http.Response, respErr error) (uri str // but ignore the returned value to ensure that the client can instrument the request even // if the context is done. r.retrier.Next() - return r.removeMeshSchemeIfPresent(r.currentURI), false + return removeMeshSchemeIfPresent(r.currentURI), false } if !r.attemptsRemaining() { // Retries exhausted return "", false } - if r.isMeshURI(r.currentURI) { + if isMeshURI(r.currentURI) { // Mesh uris don't get retried return "", false } @@ -166,14 +166,14 @@ func (r *RequestRetrier) markFailedAndMoveToNextURI() { r.offset = nextURIOffset } -func (r *RequestRetrier) removeMeshSchemeIfPresent(uri string) string { - if r.isMeshURI(uri) { +func removeMeshSchemeIfPresent(uri string) string { + if isMeshURI(uri) { return strings.Replace(uri, meshSchemePrefix, "", 1) } return uri } -func (r *RequestRetrier) isMeshURI(uri string) bool { +func isMeshURI(uri string) bool { return strings.HasPrefix(uri, meshSchemePrefix) } diff --git a/conjure-go-client/httpclient/internal/retry.go b/conjure-go-client/httpclient/internal/retry.go index f4c6b484..4253ad6a 100644 --- a/conjure-go-client/httpclient/internal/retry.go +++ b/conjure-go-client/httpclient/internal/retry.go @@ -60,24 +60,24 @@ const ( ) func isRetryOtherResponse(resp *http.Response, err error, errCode int) (bool, *url.URL) { - if errCode == StatusCodeRetryOther || errCode == StatusCodeRetryTemporaryRedirect { + if isRetryOtherStatusCode(errCode) { locationStr, ok := LocationFromError(err) if ok { return true, parseLocationURL(locationStr) } } - if resp == nil { - return false, nil - } - if resp.StatusCode != StatusCodeRetryOther && - resp.StatusCode != StatusCodeRetryTemporaryRedirect { + if resp == nil || !isRetryOtherStatusCode(resp.StatusCode) { return false, nil } locationStr := resp.Header.Get("Location") return true, parseLocationURL(locationStr) } +func isRetryOtherStatusCode(statusCode int) bool { + return statusCode == StatusCodeRetryOther || statusCode == StatusCodeRetryTemporaryRedirect +} + func parseLocationURL(locationStr string) *url.URL { if locationStr == "" { return nil diff --git a/conjure-go-client/httpclient/request_builder.go b/conjure-go-client/httpclient/request_builder.go index 88998f04..e0b0189f 100644 --- a/conjure-go-client/httpclient/request_builder.go +++ b/conjure-go-client/httpclient/request_builder.go @@ -30,6 +30,7 @@ type requestBuilder struct { query url.Values bodyMiddleware *bodyMiddleware bufferPool bytesbuffers.Pool + rawOutput bool errorDecoderMiddleware Middleware configureCtx []func(context.Context) context.Context diff --git a/conjure-go-client/httpclient/request_params.go b/conjure-go-client/httpclient/request_params.go index c2fffeb6..795365d6 100644 --- a/conjure-go-client/httpclient/request_params.go +++ b/conjure-go-client/httpclient/request_params.go @@ -125,7 +125,7 @@ func WithRawRequestBodyProvider(getBody func() io.ReadCloser) RequestParam { if getBody == nil { return werror.Error("getBody can not be nil") } - b.bodyMiddleware.requestInput = getBody() + b.bodyMiddleware.requestInput = getBody b.bodyMiddleware.requestEncoder = nil b.headers.Set("Content-Type", "application/octet-stream") return nil @@ -169,7 +169,7 @@ func WithResponseBody(output interface{}, decoder codecs.Decoder) RequestParam { // In the case of an empty response, output will be unmodified (left nil). func WithRawResponseBody() RequestParam { return requestParamFunc(func(b *requestBuilder) error { - b.bodyMiddleware.rawOutput = true + b.rawOutput = true b.bodyMiddleware.responseOutput = nil b.bodyMiddleware.responseDecoder = nil b.headers.Set("Accept", "application/octet-stream") diff --git a/conjure-go-client/httpclient/response_error_decoder_middleware_test.go b/conjure-go-client/httpclient/response_error_decoder_middleware_test.go index 846b3afb..53e1564b 100644 --- a/conjure-go-client/httpclient/response_error_decoder_middleware_test.go +++ b/conjure-go-client/httpclient/response_error_decoder_middleware_test.go @@ -94,7 +94,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify404(t, err) assert.EqualError(t, err, "httpclient request failed: 404 Not Found") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get", "statusCode": 404}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET", "statusCode": 404}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path", "responseBody": "404 page not found\n"}, unsafeParams) }, }, @@ -107,7 +107,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify404(t, err) assert.EqualError(t, err, "httpclient request failed: 404 Not Found") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get", "statusCode": 404}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET", "statusCode": 404}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path"}, unsafeParams) }, }, @@ -122,7 +122,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify404(t, err) assert.EqualError(t, err, "httpclient request failed: 404 Not Found") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get", "statusCode": 404}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET", "statusCode": 404}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path", "responseBody": "route does not exist"}, unsafeParams) }, }, @@ -137,7 +137,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify404(t, err) assert.EqualError(t, err, "httpclient request failed: failed to unmarshal body using registered type: errors: error name does not match regexp `^(([A-Z][a-z0-9]+)+):(([A-Z][a-z0-9]+)+)$`") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get", "statusCode": 404, "type": "errors.genericError"}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET", "statusCode": 404, "type": "errors.genericError"}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path", "responseBody": `{"foo":"bar"}`}, unsafeParams) }, }, @@ -159,7 +159,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { assert.Equal(t, errors.DefaultNotFound.Name(), conjureErr.Name()) safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get", "errorInstanceId": id, "errorName": "Default:NotFound", "statusCode": 404}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET", "errorInstanceId": id, "errorName": "Default:NotFound", "statusCode": 404}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path", "stringParam": "stringValue"}, unsafeParams) }, }, @@ -170,7 +170,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify: func(t *testing.T, u *url.URL, err error) { assert.EqualError(t, err, "httpclient request failed: foo error") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get"}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET"}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path"}, unsafeParams) }, }, @@ -181,7 +181,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { verify: func(t *testing.T, u *url.URL, err error) { assert.EqualError(t, err, "httpclient request failed: error from body: 404 page not found\n") safeParams, unsafeParams := werror.ParamsFromError(err) - assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "Get"}, safeParams) + assert.Equal(t, map[string]interface{}{"requestHost": u.Host, "requestMethod": "GET"}, safeParams) assert.Equal(t, map[string]interface{}{"requestPath": "/path"}, unsafeParams) }, }, diff --git a/conjure-go-client/httpclient/uri_middleware.go b/conjure-go-client/httpclient/uri_middleware.go new file mode 100644 index 00000000..501ef5f3 --- /dev/null +++ b/conjure-go-client/httpclient/uri_middleware.go @@ -0,0 +1,132 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpclient + +import ( + "net/http" + urlpkg "net/url" + "strings" + + werror "github.com/palantir/witchcraft-go-error" +) + +const ( + meshSchemePrefix = "mesh-" +) + +type uriMiddleware struct { + uris []string + offset int + redirectURL *urlpkg.URL + cancelFunc func() +} + +func NewURIMiddleware(uris []string, cancelFunc func()) Middleware { + offset := 0 + return &uriMiddleware{ + uris: uris, + offset: offset, + redirectURL: nil, + cancelFunc: cancelFunc, + } +} + +func (u *uriMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + url, err := u.getURL(req) + if err != nil { + return nil, err + } + req.URL = url + req.Host = url.Host + resp, err := next.RoundTrip(req) + if _, redirectURL := isRedirectError(err); redirectURL != nil { + if !redirectURL.IsAbs() { + redirectURL = req.URL.ResolveReference(redirectURL) + } + u.redirectURL = redirectURL + } + if err != nil { + params := []werror.Param{ + werror.SafeParam("requestMethod", req.Method), + werror.SafeParam("requestHost", url.Host), + werror.UnsafeParam("requestPath", url.Path), + } + return nil, werror.Wrap(err, "httpclient request failed", params...) + } + return resp, err +} + +func (u *uriMiddleware) getURL(req *http.Request) (*urlpkg.URL, error) { + if u.redirectURL != nil { + defer func() { + u.redirectURL = nil + }() + return u.redirectURL, nil + } + uri := u.uris[u.offset] + u.offset = (u.offset + 1) % len(u.uris) + if isMeshURI(uri) { + // Mesh URIs should not be retried + u.cancelFunc() + uri = strings.Replace(uri, meshSchemePrefix, "", 1) + } + parsedURL, err := urlpkg.Parse(uri) + if err != nil { + return nil, err + } + removeEmptyPort(parsedURL) + return parsedURL.ResolveReference(req.URL), nil +} + +// removeEmptyPort() strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(url *urlpkg.URL) { + if url.Port() == "" { + url.Host = strings.TrimSuffix(url.Host, ":") + } +} + +func isMeshURI(uri string) bool { + return strings.HasPrefix(uri, meshSchemePrefix) +} + +func isRedirectError(err error) (bool, *urlpkg.URL) { + errCode, _ := StatusCodeFromError(err) + if !isRedirectStatusCode(errCode) { + return false, nil + } + _, unsafeParams := werror.ParamsFromError(err) + locationStr, ok := unsafeParams["location"].(string) + if !ok { + return true, nil + } + return true, parseLocationURL(locationStr) +} + +func isRedirectStatusCode(statusCode int) bool { + return statusCode == http.StatusTemporaryRedirect || statusCode == http.StatusPermanentRedirect +} + +func parseLocationURL(locationStr string) *urlpkg.URL { + if locationStr == "" { + return nil + } + locationURL, err := urlpkg.Parse(locationStr) + if err != nil { + // Unable to parse location as something we recognize + return nil + } + return locationURL +}