Skip to content

Commit 1100e4e

Browse files
ItalyPaleAledapr-bot
andauthoredJul 17, 2023
Streaming: convert HTTP service invocation handler to net/http (dapr#6594)
* Streaming: convert HTTP service invocation handler to net/http Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Remove extra slashes in service invocation when using dapr-app-id header Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Re-enable test Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --------- Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
1 parent 7d26a5f commit 1100e4e

File tree

8 files changed

+125
-80
lines changed

8 files changed

+125
-80
lines changed
 

‎pkg/grpc/api.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ var (
249249
// Deprecated: Use proxy mode service invocation instead.
250250
func (a *api) InvokeService(ctx context.Context, in *runtimev1pb.InvokeServiceRequest) (*commonv1pb.InvokeResponse, error) {
251251
if a.directMessaging == nil {
252-
return nil, status.Errorf(codes.Internal, messages.ErrDirectInvokeNotReady)
252+
return nil, messages.ErrDirectInvokeNotReady
253253
}
254254

255255
if invokeServiceDeprecationNoticeShown.CompareAndSwap(false, true) {
@@ -285,7 +285,7 @@ func (a *api) InvokeService(ctx context.Context, in *runtimev1pb.InvokeServiceRe
285285
}
286286
}
287287
if rErr != nil {
288-
return rResp, status.Errorf(codes.Internal, messages.ErrDirectInvoke, in.Id, rErr)
288+
return rResp, messages.ErrDirectInvoke.WithFormat(in.Id, rErr)
289289
}
290290

291291
rResp.headers = invokev1.InternalMetadataToGrpcMetadata(ctx, imr.Headers(), true)

‎pkg/http/api.go

+45-40
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"io"
2424
nethttp "net/http"
2525
"net/url"
26+
"path"
2627
"strconv"
2728
"strings"
2829
"time"
@@ -57,6 +58,7 @@ import (
5758
"github.com/dapr/dapr/pkg/runtime/compstore"
5859
runtimePubsub "github.com/dapr/dapr/pkg/runtime/pubsub"
5960
"github.com/dapr/dapr/utils"
61+
"github.com/dapr/dapr/utils/responsewriter"
6062
)
6163

6264
// API returns a list of HTTP endpoints for Dapr.
@@ -278,7 +280,7 @@ func (a *api) constructDirectMessagingEndpoints() []Endpoint {
278280
IsFallback: true,
279281
Version: apiVersionV1,
280282
KeepWildcardUnescaped: true,
281-
FastHTTPHandler: a.onDirectMessage,
283+
Handler: a.onDirectMessage,
282284
},
283285
}
284286
}
@@ -1077,11 +1079,11 @@ func (a *api) getStateStoreName(reqCtx *fasthttp.RequestCtx) string {
10771079

10781080
type invokeError struct {
10791081
statusCode int
1080-
msg ErrorResponse
1082+
msg []byte
10811083
}
10821084

10831085
func (ie invokeError) Error() string {
1084-
return fmt.Sprintf("invokeError (statusCode='%d') msg.errorCode='%s' msg.message='%s'", ie.statusCode, ie.msg.ErrorCode, ie.msg.Message)
1086+
return fmt.Sprintf("invokeError (statusCode='%d') msg='%v'", ie.statusCode, string(ie.msg))
10851087
}
10861088

10871089
func (a *api) isHTTPEndpoint(appID string) bool {
@@ -1099,22 +1101,21 @@ func (a *api) getBaseURL(targetAppID string) string {
10991101
return ""
11001102
}
11011103

1102-
func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) {
1103-
targetID, invokeMethodName := findTargetIDAndMethod(string(reqCtx.URI().PathOriginal()), reqCtx.Request.Header.Peek)
1104+
func (a *api) onDirectMessage(w nethttp.ResponseWriter, r *nethttp.Request) {
1105+
targetID, invokeMethodName := findTargetIDAndMethod(r.URL.String(), r.Header)
11041106
if targetID == "" {
1105-
msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeNoAppID)
1106-
fasthttpRespond(reqCtx, fasthttpResponseWithError(nethttp.StatusNotFound, msg))
1107+
respondWithError(w, messages.ErrDirectInvokeNoAppID)
11071108
return
11081109
}
11091110

1110-
// Store target and method as user values so they can be picked up by the tracing library
1111-
reqCtx.SetUserValue("id", targetID)
1112-
reqCtx.SetUserValue("method", invokeMethodName)
1111+
// Store target and method as values in the context so they can be picked up by the tracing library
1112+
rw := responsewriter.EnsureResponseWriter(w)
1113+
rw.SetUserValue("id", targetID)
1114+
rw.SetUserValue("method", invokeMethodName)
11131115

1114-
verb := strings.ToUpper(string(reqCtx.Method()))
1116+
verb := strings.ToUpper(r.Method)
11151117
if a.directMessaging == nil {
1116-
msg := NewErrorResponse("ERR_DIRECT_INVOKE", messages.ErrDirectInvokeNotReady)
1117-
fasthttpRespond(reqCtx, fasthttpResponseWithError(nethttp.StatusInternalServerError, msg))
1118+
respondWithError(w, messages.ErrDirectInvokeNotReady)
11181119
return
11191120
}
11201121

@@ -1134,18 +1135,18 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) {
11341135
}
11351136

11361137
req := invokev1.NewInvokeMethodRequest(invokeMethodName).
1137-
WithHTTPExtension(verb, reqCtx.QueryArgs().String()).
1138-
WithRawDataBytes(reqCtx.Request.Body()).
1139-
WithContentType(string(reqCtx.Request.Header.ContentType())).
1138+
WithHTTPExtension(verb, r.URL.RawQuery).
1139+
WithRawData(r.Body).
1140+
WithContentType(r.Header.Get("content-type")).
11401141
// Save headers to internal metadata
1141-
WithFastHTTPHeaders(&reqCtx.Request.Header)
1142+
WithHTTPHeaders(r.Header)
11421143
if policyDef != nil {
11431144
req.WithReplay(policyDef.HasRetries())
11441145
}
11451146
defer req.Close()
11461147

11471148
policyRunner := resiliency.NewRunnerWithOptions(
1148-
reqCtx, policyDef,
1149+
r.Context(), policyDef,
11491150
resiliency.RunnerOpts[*invokev1.InvokeMethodResponse]{
11501151
Disposer: resiliency.DisposerCloser[*invokev1.InvokeMethodResponse],
11511152
},
@@ -1156,9 +1157,10 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) {
11561157
if rErr != nil {
11571158
// Allowlist policies that are applied on the callee side can return a Permission Denied error.
11581159
// For everything else, treat it as a gRPC transport error
1160+
apiErr := messages.ErrDirectInvoke.WithFormat(targetID, rErr)
11591161
invokeErr := invokeError{
1160-
statusCode: nethttp.StatusInternalServerError,
1161-
msg: NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, rErr)),
1162+
statusCode: apiErr.HTTPCode(),
1163+
msg: apiErr.JSONErrorValue(),
11621164
}
11631165

11641166
if status.Code(rErr) == codes.PermissionDenied {
@@ -1181,7 +1183,7 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) {
11811183
if rErr != nil {
11821184
return rResp, invokeError{
11831185
statusCode: nethttp.StatusInternalServerError,
1184-
msg: NewErrorResponse("ERR_MALFORMED_RESPONSE", rErr.Error()),
1186+
msg: NewErrorResponse("ERR_MALFORMED_RESPONSE", rErr.Error()).JSONErrorValue(),
11851187
}
11861188
}
11871189
} else {
@@ -1197,67 +1199,70 @@ func (a *api) onDirectMessage(reqCtx *fasthttp.RequestCtx) {
11971199

11981200
// Special case for timeouts/circuit breakers since they won't go through the rest of the logic.
11991201
if errors.Is(err, context.DeadlineExceeded) || breaker.IsErrorPermanent(err) {
1200-
fasthttpRespond(reqCtx, fasthttpResponseWithError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", err.Error())))
1202+
respondWithError(w, messages.ErrDirectInvoke.WithFormat(targetID, err))
12011203
return
12021204
}
12031205

12041206
if resp != nil {
12051207
headers := resp.Headers()
12061208
if len(headers) > 0 {
1207-
invokev1.InternalMetadataToHTTPHeader(reqCtx, headers, reqCtx.Response.Header.Add)
1209+
invokev1.InternalMetadataToHTTPHeader(r.Context(), headers, w.Header().Add)
12081210
}
12091211
}
12101212

12111213
invokeErr := invokeError{}
12121214
if errors.As(err, &invokeErr) {
1213-
fasthttpRespond(reqCtx, fasthttpResponseWithError(invokeErr.statusCode, invokeErr.msg))
1215+
respondWithData(w, invokeErr.statusCode, invokeErr.msg)
12141216
if resp != nil {
12151217
_ = resp.Close()
12161218
}
12171219
return
12181220
}
12191221

12201222
if resp == nil {
1221-
fasthttpRespond(reqCtx, fasthttpResponseWithError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, "response object is nil"))))
1223+
respondWithError(w, messages.ErrDirectInvoke.WithFormat(targetID, "response object is nil"))
12221224
return
12231225
}
12241226
defer resp.Close()
12251227

12261228
statusCode := int(resp.Status().Code)
12271229

1228-
body, err := resp.RawDataFull()
1230+
if ct := resp.ContentType(); ct != "" {
1231+
w.Header().Set("content-type", ct)
1232+
}
1233+
1234+
w.WriteHeader(statusCode)
1235+
1236+
_, err = io.Copy(w, resp.RawData())
12291237
if err != nil {
1230-
fasthttpRespond(reqCtx, fasthttpResponseWithError(nethttp.StatusInternalServerError, NewErrorResponse("ERR_DIRECT_INVOKE", fmt.Sprintf(messages.ErrDirectInvoke, targetID, err))))
1238+
respondWithError(w, messages.ErrDirectInvoke.WithFormat(targetID, err))
12311239
return
12321240
}
1233-
1234-
reqCtx.Response.Header.SetContentType(resp.ContentType())
1235-
fasthttpRespond(reqCtx, fasthttpResponseWith(statusCode, body))
12361241
}
12371242

12381243
// findTargetIDAndMethod finds ID of the target service and method from the following three places:
12391244
// 1. HTTP header 'dapr-app-id' (path is method)
12401245
// 2. Basic auth header: `http://dapr-app-id:<service-id>@localhost:3500/<method>`
12411246
// 3. URL parameter: `http://localhost:3500/v1.0/invoke/<app-id>/method/<method>`
1242-
func findTargetIDAndMethod(path string, peekHeader func(string) []byte) (targetID string, method string) {
1243-
if appID := peekHeader(daprAppID); len(appID) != 0 {
1244-
return string(appID), strings.TrimPrefix(path, "/")
1247+
func findTargetIDAndMethod(reqPath string, headers nethttp.Header) (targetID string, method string) {
1248+
if appID := headers.Get(daprAppID); appID != "" {
1249+
return appID, strings.TrimPrefix(path.Clean(reqPath), "/")
12451250
}
12461251

1247-
if auth := string(peekHeader(fasthttp.HeaderAuthorization)); strings.HasPrefix(auth, "Basic ") {
1252+
if auth := headers.Get(fasthttp.HeaderAuthorization); strings.HasPrefix(auth, "Basic ") {
12481253
if s, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic ")); err == nil {
12491254
pair := strings.Split(string(s), ":")
12501255
if len(pair) == 2 && pair[0] == daprAppID {
1251-
return pair[1], strings.TrimPrefix(path, "/")
1256+
return pair[1], strings.TrimPrefix(path.Clean(reqPath), "/")
12521257
}
12531258
}
12541259
}
12551260

12561261
// If we're here, the handler was probably invoked with /v1.0/invoke/ (or the invocation is invalid, missing the app id provided as header or Basic auth)
12571262
// However, we are not relying on wildcardParam because the URL may have been sanitized to remove `//``, so `http://` would have been turned into `http:/`
12581263
// First, check to make sure that the path has the prefix
1259-
if idx := pathHasPrefix(path, apiVersionV1, "invoke"); idx > 0 {
1260-
path = path[idx:]
1264+
if idx := pathHasPrefix(reqPath, apiVersionV1, "invoke"); idx > 0 {
1265+
reqPath = reqPath[idx:]
12611266

12621267
// Scan to find app ID and method
12631268
// Matches `<appid>/method/<method>`.
@@ -1266,9 +1271,9 @@ func findTargetIDAndMethod(path string, peekHeader func(string) []byte) (targetI
12661271
// - `http://example.com/method/mymethod`
12671272
// - `https://example.com/method/mymethod`
12681273
// - `http%3A%2F%2Fexample.com/method/mymethod`
1269-
if idx = strings.Index(path, "/method/"); idx > 0 {
1270-
targetID = path[:idx]
1271-
method = path[(idx + len("/method/")):]
1274+
if idx = strings.Index(reqPath, "/method/"); idx > 0 {
1275+
targetID = reqPath[:idx]
1276+
method = reqPath[(idx + len("/method/")):]
12721277
if t, _ := url.QueryUnescape(targetID); t != "" {
12731278
targetID = t
12741279
}

‎pkg/http/api_test.go

+21-17
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import (
6565
"github.com/dapr/dapr/pkg/messages"
6666
invokev1 "github.com/dapr/dapr/pkg/messaging/v1"
6767
httpMiddleware "github.com/dapr/dapr/pkg/middleware/http"
68+
commonv1 "github.com/dapr/dapr/pkg/proto/common/v1"
6869
"github.com/dapr/dapr/pkg/resiliency"
6970
"github.com/dapr/dapr/pkg/runtime/compstore"
7071
runtimePubsub "github.com/dapr/dapr/pkg/runtime/pubsub"
@@ -1091,7 +1092,16 @@ func TestV1DirectMessagingEndpoints(t *testing.T) {
10911092
mock.MatchedBy(func(b string) bool {
10921093
return b == "http://123.45.67.89:3000"
10931094
}),
1094-
mock.AnythingOfType("*v1.InvokeMethodRequest"),
1095+
mock.MatchedBy(func(req *invokev1.InvokeMethodRequest) bool {
1096+
msg := req.Message()
1097+
if msg.Method != "fakeMethod" {
1098+
return false
1099+
}
1100+
if msg.HttpExtension.Verb != commonv1.HTTPExtension_POST {
1101+
return false
1102+
}
1103+
return true
1104+
}),
10951105
).
10961106
Return(fakeDirectMessageResponse, nil).
10971107
Once()
@@ -1101,12 +1111,12 @@ func TestV1DirectMessagingEndpoints(t *testing.T) {
11011111

11021112
// assert
11031113
mockDirectMessaging.AssertNumberOfCalls(t, "Invoke", 1)
1104-
assert.Equal(t, 200, resp.StatusCode)
1114+
assert.Equal(t, gohttp.StatusOK, resp.StatusCode)
11051115
assert.Equal(t, "fakeDirectMessageResponse", string(resp.RawBody))
11061116
})
11071117

11081118
t.Run("Invoke direct messaging without querystring - 201 Created", func(t *testing.T) {
1109-
fakeDirectMessageResponse := getFakeDirectMessageResponseWithStatusCode(fasthttp.StatusCreated)
1119+
fakeDirectMessageResponse := getFakeDirectMessageResponseWithStatusCode(gohttp.StatusCreated)
11101120
defer fakeDirectMessageResponse.Close()
11111121

11121122
apiPath := "v1.0/invoke/fakeAppID/method/fakeMethod"
@@ -1364,7 +1374,7 @@ func TestV1DirectMessagingEndpoints(t *testing.T) {
13641374
// assert
13651375
mockDirectMessaging.AssertNumberOfCalls(t, "Invoke", 1)
13661376
assert.Equal(t, 500, resp.StatusCode)
1367-
assert.True(t, strings.HasPrefix(string(resp.RawBody), "{\"errorCode\":\"ERR_MALFORMED_RESPONSE\",\"message\":\""))
1377+
assert.Truef(t, strings.HasPrefix(string(resp.RawBody), "{\"errorCode\":\"ERR_MALFORMED_RESPONSE\",\"message\":\""), "code not found in response: %v", string(resp.RawBody))
13681378
})
13691379

13701380
t.Run("Invoke direct messaging with malformed status response for external invocation", func(t *testing.T) {
@@ -5610,33 +5620,27 @@ func TestFindTargetIDAndMethod(t *testing.T) {
56105620
tests := []struct {
56115621
name string
56125622
path string
5613-
headers map[string]string
5623+
headers gohttp.Header
56145624
wantTargetID string
56155625
wantMethod string
56165626
}{
5617-
{name: "dapr-app-id header", path: "/foo/bar", headers: map[string]string{"dapr-app-id": "myapp"}, wantTargetID: "myapp", wantMethod: "foo/bar"},
5618-
{name: "basic auth", path: "/foo/bar", headers: map[string]string{"Authorization": "Basic ZGFwci1hcHAtaWQ6YXV0aA=="}, wantTargetID: "auth", wantMethod: "foo/bar"},
5619-
{name: "dapr-app-id header has priority over basic auth", path: "/foo/bar", headers: map[string]string{"dapr-app-id": "myapp", "Authorization": "Basic ZGFwci1hcHAtaWQ6YXV0aA=="}, wantTargetID: "myapp", wantMethod: "foo/bar"},
5627+
{name: "dapr-app-id header", path: "/foo/bar", headers: gohttp.Header{"Dapr-App-Id": []string{"myapp"}}, wantTargetID: "myapp", wantMethod: "foo/bar"},
5628+
{name: "basic auth", path: "/foo/bar", headers: gohttp.Header{"Authorization": []string{"Basic ZGFwci1hcHAtaWQ6YXV0aA=="}}, wantTargetID: "auth", wantMethod: "foo/bar"},
5629+
{name: "dapr-app-id header has priority over basic auth", path: "/foo/bar", headers: gohttp.Header{"Dapr-App-Id": []string{"myapp"}, "Authorization": []string{"Basic ZGFwci1hcHAtaWQ6YXV0aA=="}}, wantTargetID: "myapp", wantMethod: "foo/bar"},
56205630
{name: "path with internal target", path: "/v1.0/invoke/myapp/method/foo", wantTargetID: "myapp", wantMethod: "foo"},
5621-
{name: "basic auth has priority over path", path: "/v1.0/invoke/myapp/method/foo", headers: map[string]string{"Authorization": "Basic ZGFwci1hcHAtaWQ6YXV0aA=="}, wantTargetID: "auth", wantMethod: "v1.0/invoke/myapp/method/foo"},
5631+
{name: "basic auth has priority over path", path: "/v1.0/invoke/myapp/method/foo", headers: gohttp.Header{"Authorization": []string{"Basic ZGFwci1hcHAtaWQ6YXV0aA=="}}, wantTargetID: "auth", wantMethod: "v1.0/invoke/myapp/method/foo"},
56225632
{name: "path with '/' method", path: "/v1.0/invoke/myapp/method/", wantTargetID: "myapp", wantMethod: ""},
56235633
{name: "path with missing method", path: "/v1.0/invoke/myapp/method", wantTargetID: "", wantMethod: ""},
56245634
{name: "path with http target unescaped", path: "/v1.0/invoke/http://example.com/method/foo", wantTargetID: "http://example.com", wantMethod: "foo"},
56255635
{name: "path with https target unescaped", path: "/v1.0/invoke/https://example.com/method/foo", wantTargetID: "https://example.com", wantMethod: "foo"},
56265636
{name: "path with http target escaped", path: "/v1.0/invoke/http%3A%2F%2Fexample.com/method/foo", wantTargetID: "http://example.com", wantMethod: "foo"},
56275637
{name: "path with https target escaped", path: "/v1.0/invoke/https%3A%2F%2Fexample.com/method/foo", wantTargetID: "https://example.com", wantMethod: "foo"},
56285638
{name: "path with https target partly escaped", path: "/v1.0/invoke/https%3A/%2Fexample.com/method/foo", wantTargetID: "https://example.com", wantMethod: "foo"},
5639+
{name: "extra slashes are removed", path: "///foo//bar", headers: gohttp.Header{"Dapr-App-Id": []string{"myapp"}}, wantTargetID: "myapp", wantMethod: "foo/bar"},
56295640
}
56305641
for _, tt := range tests {
56315642
t.Run(tt.name, func(t *testing.T) {
5632-
peekHeader := func(k string) []byte {
5633-
if len(tt.headers) == 0 {
5634-
return nil
5635-
}
5636-
return []byte(tt.headers[k])
5637-
}
5638-
5639-
gotTargetID, gotMethod := findTargetIDAndMethod(tt.path, peekHeader)
5643+
gotTargetID, gotMethod := findTargetIDAndMethod(tt.path, tt.headers)
56405644
if gotTargetID != tt.wantTargetID {
56415645
t.Errorf("findTargetIDAndMethod() gotTargetID = %v, want %v", gotTargetID, tt.wantTargetID)
56425646
}

‎pkg/messages/predefined.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,6 @@ const (
7272
ErrActorStateGet = "error getting actor state: %s"
7373
ErrActorStateTransactionSave = "error saving actor transaction state: %s"
7474

75-
// DirectMessaging.
76-
ErrDirectInvoke = "fail to invoke, id: %s, err: %s"
77-
ErrDirectInvokeNoAppID = "failed getting app id either from the URL path or the header dapr-app-id"
78-
ErrDirectInvokeNotReady = "invoke API is not ready"
79-
8075
// Configuration.
8176
ErrConfigurationStoresNotConfigured = "configuration stores not configured"
8277
ErrConfigurationStoreNotFound = "configuration store %s not found"
@@ -94,6 +89,11 @@ var (
9489
ErrBodyRead = APIError{"failed to read request body: %v", "ERR_BODY_READ", http.StatusBadRequest, grpcCodes.InvalidArgument}
9590
ErrMalformedRequest = APIError{"failed deserializing HTTP body: %v", "ERR_MALFORMED_REQUEST", http.StatusBadRequest, grpcCodes.InvalidArgument}
9691

92+
// DirectMessaging.
93+
ErrDirectInvoke = APIError{"fail to invoke, id: %s, err: %v", "ERR_DIRECT_INVOKE", http.StatusInternalServerError, grpcCodes.Internal}
94+
ErrDirectInvokeNoAppID = APIError{"failed getting app id either from the URL path or the header dapr-app-id", "ERR_DIRECT_INVOKE", http.StatusNotFound, grpcCodes.NotFound}
95+
ErrDirectInvokeNotReady = APIError{"invoke API is not ready", "ERR_DIRECT_INVOKE", http.StatusInternalServerError, grpcCodes.Internal}
96+
9797
// Healthz.
9898
ErrHealthNotReady = APIError{"dapr is not ready", "ERR_HEALTH_NOT_READY", http.StatusInternalServerError, grpcCodes.Internal}
9999
ErrOutboundHealthNotReady = APIError{"dapr outbound is not ready", "ERR_OUTBOUND_HEALTH_NOT_READY", http.StatusInternalServerError, grpcCodes.Internal}

‎pkg/messaging/v1/invoke_method_request_test.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package v1
1717
import (
1818
"bytes"
1919
"encoding/json"
20+
"errors"
2021
"fmt"
2122
"io"
2223
"net/http"
@@ -344,7 +345,7 @@ func TestRequestProto(t *testing.T) {
344345
ir, err := InternalInvokeRequest(&pb)
345346
assert.NoError(t, err)
346347
defer ir.Close()
347-
ir.data = io.NopCloser(strings.NewReader("test"))
348+
ir.data = newReaderCloser(strings.NewReader("test"))
348349
req2 := ir.Proto()
349350

350351
assert.Equal(t, "application/json", req2.GetMessage().ContentType)
@@ -391,7 +392,7 @@ func TestRequestProtoWithData(t *testing.T) {
391392
ir, err := InternalInvokeRequest(&pb)
392393
assert.NoError(t, err)
393394
defer ir.Close()
394-
ir.data = io.NopCloser(strings.NewReader("test"))
395+
ir.data = newReaderCloser(strings.NewReader("test"))
395396
req2, err := ir.ProtoWithData()
396397
assert.NoError(t, err)
397398

@@ -499,7 +500,7 @@ func TestRequestReplayable(t *testing.T) {
499500
const message = "Nel mezzo del cammin di nostra vita mi ritrovai per una selva oscura, che' la diritta via era smarrita."
500501
newReplayable := func() *InvokeMethodRequest {
501502
return NewInvokeMethodRequest("test_method").
502-
WithRawDataString(message).
503+
WithRawData(newReaderCloser(strings.NewReader(message))).
503504
WithReplay(true)
504505
}
505506

@@ -519,7 +520,7 @@ func TestRequestReplayable(t *testing.T) {
519520
buf := make([]byte, 9)
520521
n, err := io.ReadFull(req.data, buf)
521522
assert.Equal(t, 0, n)
522-
assert.ErrorIs(t, err, io.EOF)
523+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
523524
})
524525

525526
t.Run("replay buffer is full", func(t *testing.T) {
@@ -551,7 +552,7 @@ func TestRequestReplayable(t *testing.T) {
551552
buf := make([]byte, 9)
552553
n, err := io.ReadFull(req.data, buf)
553554
assert.Equal(t, 0, n)
554-
assert.ErrorIs(t, err, io.EOF)
555+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
555556
})
556557

557558
t.Run("replay buffer is full", func(t *testing.T) {
@@ -652,7 +653,7 @@ func TestRequestReplayable(t *testing.T) {
652653
buf := make([]byte, 9)
653654
n, err := io.ReadFull(req.data, buf)
654655
assert.Equal(t, 0, n)
655-
assert.ErrorIs(t, err, io.EOF)
656+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
656657
})
657658

658659
t.Run("replay buffer is full", func(t *testing.T) {

‎pkg/messaging/v1/replayable_request_test.go

+32-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ package v1
1616
import (
1717
"bytes"
1818
"crypto/rand"
19+
"errors"
1920
"io"
21+
"net/http"
2022
"testing"
2123

2224
"github.com/stretchr/testify/assert"
@@ -31,7 +33,7 @@ func TestReplayableRequest(t *testing.T) {
3133

3234
newReplayable := func() *replayableRequest {
3335
rr := &replayableRequest{}
34-
rr.WithRawData(bytes.NewReader(message))
36+
rr.WithRawData(newReaderCloser(bytes.NewReader(message)))
3537
rr.SetReplay(true)
3638
return rr
3739
}
@@ -50,7 +52,7 @@ func TestReplayableRequest(t *testing.T) {
5052
buf := make([]byte, 9)
5153
n, err := io.ReadFull(rr.data, buf)
5254
assert.Equal(t, 0, n)
53-
assert.ErrorIs(t, err, io.EOF)
55+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
5456
})
5557

5658
t.Run("replay buffer is full", func(t *testing.T) {
@@ -82,7 +84,7 @@ func TestReplayableRequest(t *testing.T) {
8284
buf := make([]byte, 9)
8385
n, err := io.ReadFull(rr.data, buf)
8486
assert.Equal(t, 0, n)
85-
assert.ErrorIs(t, err, io.EOF)
87+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
8688
})
8789

8890
t.Run("replay buffer is full", func(t *testing.T) {
@@ -191,7 +193,7 @@ func TestReplayableRequest(t *testing.T) {
191193
buf := make([]byte, partial)
192194
n, err := io.ReadFull(rr.data, buf)
193195
assert.Equal(t, 0, n)
194-
assert.ErrorIs(t, err, io.EOF)
196+
assert.Truef(t, errors.Is(err, io.EOF) || errors.Is(err, http.ErrBodyReadAfterClose), "unexpected error: %v", err)
195197
})
196198

197199
t.Run("replay buffer is full", func(t *testing.T) {
@@ -215,3 +217,29 @@ func TestReplayableRequest(t *testing.T) {
215217
})
216218
})
217219
}
220+
221+
// readerCloser is a io.Reader that can be closed. Once the stream is closed, reading from it returns an error.
222+
type readerCloser struct {
223+
r io.Reader
224+
closed bool
225+
}
226+
227+
func newReaderCloser(r io.Reader) *readerCloser {
228+
return &readerCloser{
229+
r: r,
230+
closed: false,
231+
}
232+
}
233+
234+
func (b *readerCloser) Read(p []byte) (n int, err error) {
235+
if b.closed {
236+
// Use http.ErrBodyReadAfterClose which is the error returned by http.Response.Body
237+
return 0, http.ErrBodyReadAfterClose
238+
}
239+
return b.r.Read(p)
240+
}
241+
242+
func (b *readerCloser) Close() error {
243+
b.closed = true
244+
return nil
245+
}

‎tests/integration/suite/daprd/serviceinvocation/http/basic.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,7 @@ func (b *basic) Run(t *testing.T, ctx context.Context) {
130130
{url: fmt.Sprintf("http://localhost:%d/v1.0/invoke/%s/method/foo", b.daprd2.HTTPPort(), b.daprd1.AppID())},
131131
{url: fmt.Sprintf("http://localhost:%d/v1.0////invoke/%s/method/foo", b.daprd2.HTTPPort(), b.daprd1.AppID())},
132132
{url: fmt.Sprintf("http://localhost:%d/v1.0//invoke//%s/method//foo", b.daprd1.HTTPPort(), b.daprd2.AppID())},
133-
// We cannot use `///foo` here because the test app uses the standard Go mux which responds with a 301 status code if the invocation is for `///foo`
134-
// This makes Dapr retry with a GET request in all cases, as per specs
135-
// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/301
136-
{url: fmt.Sprintf("http://localhost:%d/foo", b.daprd1.HTTPPort()), headers: map[string]string{
133+
{url: fmt.Sprintf("http://localhost:%d///foo", b.daprd1.HTTPPort()), headers: map[string]string{
137134
"foo": "bar",
138135
"dapr-app-id": b.daprd2.AppID(),
139136
}},

‎utils/streams/multireadercloser.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ limitations under the License.
1313

1414
package streams
1515

16-
import "io"
16+
import (
17+
"errors"
18+
"io"
19+
"net/http"
20+
)
1721

1822
// NewMultiReaderCloser returns a stream that is like io.MultiReader but that can be closed.
1923
// When the returned stream is closed, it closes the readable streams too, if they implement io.Closer.
@@ -41,7 +45,13 @@ func (mr *MultiReaderCloser) Read(p []byte) (n int, err error) {
4145
for len(mr.readers) > 0 {
4246
r := mr.readers[0]
4347
n, err = r.Read(p)
44-
if err == io.EOF {
48+
49+
// When reading from a http.Response Body, we may get ErrBodyReadAfterClose if we already read it all
50+
// We consider that the same as io.EOF
51+
if errors.Is(err, http.ErrBodyReadAfterClose) {
52+
err = io.EOF
53+
mr.readers = mr.readers[1:]
54+
} else if err == io.EOF {
4555
if rc, ok := r.(io.Closer); ok {
4656
_ = rc.Close()
4757
}

0 commit comments

Comments
 (0)
Please sign in to comment.