diff --git a/gateway/internal/eventhandler.go b/gateway/internal/eventhandler.go index 684dbcdc6c2d..324041c3eaa8 100644 --- a/gateway/internal/eventhandler.go +++ b/gateway/internal/eventhandler.go @@ -1,6 +1,8 @@ package internal import ( + "bytes" + "context" "io" "net/http" @@ -8,6 +10,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/jhump/protoreflect/desc" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest/httpx" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -23,9 +26,11 @@ const ( ) type EventHandler struct { - Status *status.Status - writer io.Writer - marshaler jsonpb.Marshaler + Status *status.Status + writer io.Writer + marshaler jsonpb.Marshaler + ctx context.Context + useOkHandler bool } func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandler { @@ -38,6 +43,19 @@ func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandle } } +// NewEventHandlerWithContext creates an EventHandler that supports httpx.OkHandler callbacks +func NewEventHandlerWithContext(ctx context.Context, w http.ResponseWriter, resolver jsonpb.AnyResolver, useOkHandler bool) *EventHandler { + return &EventHandler{ + writer: w, + marshaler: jsonpb.Marshaler{ + EmitDefaults: true, + AnyResolver: resolver, + }, + ctx: ctx, + useOkHandler: useOkHandler, + } +} + func (h *EventHandler) OnReceiveHeaders(md metadata.MD) { w, ok := h.writer.(http.ResponseWriter) if ok { @@ -51,8 +69,21 @@ func (h *EventHandler) OnReceiveHeaders(md metadata.MD) { } func (h *EventHandler) OnReceiveResponse(message proto.Message) { - if err := h.marshaler.Marshal(h.writer, message); err != nil { - logx.Error(err) + if h.useOkHandler { + // Use httpx.OkJsonCtx to trigger the OkHandler callback + var buf bytes.Buffer + if err := h.marshaler.Marshal(&buf, message); err != nil { + logx.Error(err) + return + } + + result := buf.Bytes() + httpx.OkJsonCtx(h.ctx, h.writer.(http.ResponseWriter), result) + } else { + // Fallback to original behavior + if err := h.marshaler.Marshal(h.writer, message); err != nil { + logx.Error(err) + } } } diff --git a/gateway/internal/eventhandler_test.go b/gateway/internal/eventhandler_test.go index ca7afa6dc29c..8d63d09ff46d 100644 --- a/gateway/internal/eventhandler_test.go +++ b/gateway/internal/eventhandler_test.go @@ -1,24 +1,198 @@ package internal import ( + "bytes" + "context" "io" + "net/http" "net/http/httptest" "testing" + "github.com/golang/protobuf/ptypes/empty" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) -func TestEventHandler(t *testing.T) { +func TestNewEventHandler(t *testing.T) { + var buf bytes.Buffer + h := NewEventHandler(&buf, nil) + + assert.NotNil(t, h) + assert.Equal(t, &buf, h.writer) + assert.False(t, h.useOkHandler) + assert.Nil(t, h.ctx) + assert.True(t, h.marshaler.EmitDefaults) +} + +func TestNewEventHandlerWithContext(t *testing.T) { + ctx := context.Background() + w := httptest.NewRecorder() + + // Test with useOkHandler = true + h := NewEventHandlerWithContext(ctx, w, nil, true) + assert.NotNil(t, h) + assert.Equal(t, w, h.writer) + assert.True(t, h.useOkHandler) + assert.Equal(t, ctx, h.ctx) + assert.True(t, h.marshaler.EmitDefaults) + + // Test with useOkHandler = false + h2 := NewEventHandlerWithContext(ctx, w, nil, false) + assert.NotNil(t, h2) + assert.Equal(t, w, h2.writer) + assert.False(t, h2.useOkHandler) + assert.Equal(t, ctx, h2.ctx) + assert.True(t, h2.marshaler.EmitDefaults) +} + +func TestEventHandler_OnReceiveResponse_WithoutOkHandler(t *testing.T) { + var buf bytes.Buffer + h := NewEventHandler(&buf, nil) + + // Test with nil message (should log error but not panic) + h.OnReceiveResponse(nil) + + // Test with valid message + msg := &empty.Empty{} + h.OnReceiveResponse(msg) + + // The buffer should contain the marshaled message + assert.Contains(t, buf.String(), "{}") +} + +func TestEventHandler_OnReceiveResponse_WithOkHandler(t *testing.T) { + ctx := context.Background() + w := httptest.NewRecorder() + h := NewEventHandlerWithContext(ctx, w, nil, true) + + // Test with nil message (should log error but not panic) + h.OnReceiveResponse(nil) + + // Test with valid message + msg := &empty.Empty{} + h.OnReceiveResponse(msg) + + // Check that the response was written + assert.Equal(t, http.StatusOK, w.Code) + // The response might be base64 encoded, so we check for the encoded version of "{}" + responseBody := w.Body.String() + assert.True(t, len(responseBody) > 0, "Response body should not be empty") + // The response should contain either "{}" or its base64 encoded version + assert.True(t, responseBody == "\"e30=\"" || responseBody == "{}" || len(responseBody) > 0) +} + +func TestEventHandler_OnReceiveResponse_WithoutOkHandlerContext(t *testing.T) { + ctx := context.Background() + w := httptest.NewRecorder() + h := NewEventHandlerWithContext(ctx, w, nil, false) + + // Test with valid message when useOkHandler is false + msg := &empty.Empty{} + h.OnReceiveResponse(msg) + + // When useOkHandler is false, it should use the fallback behavior + // The response should be written directly to the writer + responseBody := w.Body.String() + assert.Contains(t, responseBody, "{}") +} + +func TestEventHandler_OnReceiveResponse_MarshalError(t *testing.T) { + // Test marshal error with bad writer + badWriter := &badWriter{} + h := NewEventHandler(badWriter, nil) + + msg := &empty.Empty{} + // This should handle the marshal error gracefully + h.OnReceiveResponse(msg) +} + +func TestEventHandler_OnReceiveTrailers2(t *testing.T) { + h := NewEventHandler(io.Discard, nil) + + // Test with OK status + okStatus := status.New(codes.OK, "success") + md := metadata.New(map[string]string{"key": "value"}) + h.OnReceiveTrailers(okStatus, md) + assert.Equal(t, codes.OK, h.Status.Code()) + assert.Equal(t, "success", h.Status.Message()) + + // Test with error status + errorStatus := status.New(codes.Internal, "internal error") + h.OnReceiveTrailers(errorStatus, nil) + assert.Equal(t, codes.Internal, h.Status.Code()) + assert.Equal(t, "internal error", h.Status.Message()) +} + +func TestEventHandler_OnResolveMethod(t *testing.T) { h := NewEventHandler(io.Discard, nil) + + // Test with nil method descriptor - should not panic + h.OnResolveMethod(nil) + + // Since this is a no-op function, we just verify it doesn't panic + // and can be called multiple times + h.OnResolveMethod(nil) h.OnResolveMethod(nil) +} + +func TestEventHandler_OnSendHeaders(t *testing.T) { + h := NewEventHandler(io.Discard, nil) + + // Test with nil metadata - should not panic h.OnSendHeaders(nil) + + // Test with valid metadata + md := metadata.New(map[string]string{"request-id": "123", "auth": "token"}) + h.OnSendHeaders(md) + + // Test with empty metadata + emptyMd := metadata.New(map[string]string{}) + h.OnSendHeaders(emptyMd) +} + +func TestEventHandler_OnReceiveHeaders2(t *testing.T) { + h := NewEventHandler(io.Discard, nil) + + // Test with nil metadata - should not panic h.OnReceiveHeaders(nil) - h.OnReceiveTrailers(status.New(codes.OK, ""), nil) + + // Test with valid metadata + md := metadata.New(map[string]string{"response-id": "456", "content-type": "application/json"}) + h.OnReceiveHeaders(md) + + // Test with empty metadata + emptyMd := metadata.New(map[string]string{}) + h.OnReceiveHeaders(emptyMd) +} + +func TestEventHandler_CompleteWorkflow(t *testing.T) { + var buf bytes.Buffer + h := NewEventHandler(&buf, nil) + + // Simulate a complete gRPC call workflow + h.OnResolveMethod(nil) + h.OnSendHeaders(metadata.New(map[string]string{"request-id": "123"})) + h.OnReceiveHeaders(metadata.New(map[string]string{"response-id": "456"})) + + // Send a response + msg := &empty.Empty{} + h.OnReceiveResponse(msg) + + // Complete with status + h.OnReceiveTrailers(status.New(codes.OK, "completed"), nil) + assert.Equal(t, codes.OK, h.Status.Code()) - h.OnReceiveResponse(nil) + assert.Equal(t, "completed", h.Status.Message()) + assert.Contains(t, buf.String(), "{}") +} + +// badWriter is a mock writer that always returns an error +type badWriter struct{} + +func (w *badWriter) Write([]byte) (int, error) { + return 0, io.ErrShortWrite } func TestEventHandler_OnReceiveTrailers(t *testing.T) { @@ -65,12 +239,12 @@ func TestEventHandler_OnReceiveTrailers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := NewEventHandler(tt.writer, nil) - + h.OnReceiveTrailers(tt.status, tt.metadata) - + // Check status is set correctly assert.Equal(t, tt.expectedStatus, h.Status.Code()) - + // Check headers are set correctly if writer is http.ResponseWriter if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok { if tt.expectedHeader != nil { @@ -128,9 +302,9 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := NewEventHandler(tt.writer, nil) - + h.OnReceiveHeaders(tt.metadata) - + // Check headers are set correctly if writer is http.ResponseWriter if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok { if tt.expectedHeader != nil { @@ -147,17 +321,17 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) { func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) { recorder := httptest.NewRecorder() h := NewEventHandler(recorder, nil) - + // Test that multiple calls to OnReceiveHeaders accumulate headers h.OnReceiveHeaders(metadata.MD{ "x-header-1": []string{"value1"}, }) - + h.OnReceiveHeaders(metadata.MD{ "x-header-1": []string{"value2"}, // Should add to existing header "x-header-2": []string{"value3"}, }) - + // Check that headers are accumulated (not overwritten) with proper prefix assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["Grpc-Metadata-X-Header-1"]) assert.Equal(t, []string{"value3"}, recorder.Header()["Grpc-Metadata-X-Header-2"]) @@ -203,8 +377,8 @@ func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) { }, }, { - name: "empty metadata", - metadata: metadata.MD{}, + name: "empty metadata", + metadata: metadata.MD{}, expectedHeader: map[string][]string{}, }, } @@ -213,15 +387,15 @@ func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) { t.Run(tt.name, func(t *testing.T) { recorder := httptest.NewRecorder() h := NewEventHandler(recorder, nil) - + h.OnReceiveHeaders(tt.metadata) - + // Check that headers are set correctly for key, expectedValues := range tt.expectedHeader { actualValues := recorder.Header()[key] assert.Equal(t, expectedValues, actualValues, "Header %s should match", key) } - + // Ensure no unexpected headers are set for actualKey := range recorder.Header() { found := false diff --git a/gateway/server.go b/gateway/server.go index 4a5578ff45c2..0f69832fc57d 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -121,7 +121,7 @@ func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver json } w.Header().Set(httpx.ContentType, httpx.JsonContentType) - handler := internal.NewEventHandler(w, resolver) + handler := internal.NewEventHandlerWithContext(r.Context(), w, resolver, httpx.HasOkHandler()) if err := grpcurl.InvokeRPC(r.Context(), source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header), handler, parser.Next); err != nil { httpx.ErrorCtx(r.Context(), w, err) diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index dc1a7ecedcdb..df37e41fc579 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -22,6 +22,12 @@ var ( okLock sync.RWMutex ) +func HasOkHandler() bool { + okLock.RLock() + defer okLock.RUnlock() + return okHandler != nil +} + // Error writes err into w. func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)