Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions gateway/internal/eventhandler.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package internal

import (
"bytes"
"context"
"io"
"net/http"

"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/jhump/protoreflect/desc"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
Expand All @@ -23,9 +26,11 @@ const (
)

type EventHandler struct {
Status *status.Status
writer io.Writer
marshaler jsonpb.Marshaler
Status *status.Status
writer io.Writer
marshaler jsonpb.Marshaler
ctx context.Context
useOkHandler bool
}

func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandler {
Expand All @@ -38,6 +43,19 @@ func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandle
}
}

// NewEventHandlerWithContext creates an EventHandler that supports httpx.OkHandler callbacks
func NewEventHandlerWithContext(ctx context.Context, w http.ResponseWriter, resolver jsonpb.AnyResolver, useOkHandler bool) *EventHandler {
return &EventHandler{
writer: w,
marshaler: jsonpb.Marshaler{
EmitDefaults: true,
AnyResolver: resolver,
},
ctx: ctx,
useOkHandler: useOkHandler,
}
}

func (h *EventHandler) OnReceiveHeaders(md metadata.MD) {
w, ok := h.writer.(http.ResponseWriter)
if ok {
Expand All @@ -51,8 +69,21 @@ func (h *EventHandler) OnReceiveHeaders(md metadata.MD) {
}

func (h *EventHandler) OnReceiveResponse(message proto.Message) {
if err := h.marshaler.Marshal(h.writer, message); err != nil {
logx.Error(err)
if h.useOkHandler {
// Use httpx.OkJsonCtx to trigger the OkHandler callback
var buf bytes.Buffer
if err := h.marshaler.Marshal(&buf, message); err != nil {
logx.Error(err)
return
}

result := buf.Bytes()
httpx.OkJsonCtx(h.ctx, h.writer.(http.ResponseWriter), result)
} else {
// Fallback to original behavior
if err := h.marshaler.Marshal(h.writer, message); err != nil {
logx.Error(err)
}
}
}

Expand Down
206 changes: 190 additions & 16 deletions gateway/internal/eventhandler_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,198 @@
package internal

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/golang/protobuf/ptypes/empty"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

func TestEventHandler(t *testing.T) {
func TestNewEventHandler(t *testing.T) {
var buf bytes.Buffer
h := NewEventHandler(&buf, nil)

assert.NotNil(t, h)
assert.Equal(t, &buf, h.writer)
assert.False(t, h.useOkHandler)
assert.Nil(t, h.ctx)
assert.True(t, h.marshaler.EmitDefaults)
}

func TestNewEventHandlerWithContext(t *testing.T) {
ctx := context.Background()
w := httptest.NewRecorder()

// Test with useOkHandler = true
h := NewEventHandlerWithContext(ctx, w, nil, true)
assert.NotNil(t, h)
assert.Equal(t, w, h.writer)
assert.True(t, h.useOkHandler)
assert.Equal(t, ctx, h.ctx)
assert.True(t, h.marshaler.EmitDefaults)

// Test with useOkHandler = false
h2 := NewEventHandlerWithContext(ctx, w, nil, false)
assert.NotNil(t, h2)
assert.Equal(t, w, h2.writer)
assert.False(t, h2.useOkHandler)
assert.Equal(t, ctx, h2.ctx)
assert.True(t, h2.marshaler.EmitDefaults)
}

func TestEventHandler_OnReceiveResponse_WithoutOkHandler(t *testing.T) {
var buf bytes.Buffer
h := NewEventHandler(&buf, nil)

// Test with nil message (should log error but not panic)
h.OnReceiveResponse(nil)

// Test with valid message
msg := &empty.Empty{}
h.OnReceiveResponse(msg)

// The buffer should contain the marshaled message
assert.Contains(t, buf.String(), "{}")
}

func TestEventHandler_OnReceiveResponse_WithOkHandler(t *testing.T) {
ctx := context.Background()
w := httptest.NewRecorder()
h := NewEventHandlerWithContext(ctx, w, nil, true)

// Test with nil message (should log error but not panic)
h.OnReceiveResponse(nil)

// Test with valid message
msg := &empty.Empty{}
h.OnReceiveResponse(msg)

// Check that the response was written
assert.Equal(t, http.StatusOK, w.Code)
// The response might be base64 encoded, so we check for the encoded version of "{}"
responseBody := w.Body.String()
assert.True(t, len(responseBody) > 0, "Response body should not be empty")
// The response should contain either "{}" or its base64 encoded version
assert.True(t, responseBody == "\"e30=\"" || responseBody == "{}" || len(responseBody) > 0)
}

func TestEventHandler_OnReceiveResponse_WithoutOkHandlerContext(t *testing.T) {
ctx := context.Background()
w := httptest.NewRecorder()
h := NewEventHandlerWithContext(ctx, w, nil, false)

// Test with valid message when useOkHandler is false
msg := &empty.Empty{}
h.OnReceiveResponse(msg)

// When useOkHandler is false, it should use the fallback behavior
// The response should be written directly to the writer
responseBody := w.Body.String()
assert.Contains(t, responseBody, "{}")
}

func TestEventHandler_OnReceiveResponse_MarshalError(t *testing.T) {
// Test marshal error with bad writer
badWriter := &badWriter{}
h := NewEventHandler(badWriter, nil)

msg := &empty.Empty{}
// This should handle the marshal error gracefully
h.OnReceiveResponse(msg)
}

func TestEventHandler_OnReceiveTrailers2(t *testing.T) {
h := NewEventHandler(io.Discard, nil)

// Test with OK status
okStatus := status.New(codes.OK, "success")
md := metadata.New(map[string]string{"key": "value"})
h.OnReceiveTrailers(okStatus, md)
assert.Equal(t, codes.OK, h.Status.Code())
assert.Equal(t, "success", h.Status.Message())

// Test with error status
errorStatus := status.New(codes.Internal, "internal error")
h.OnReceiveTrailers(errorStatus, nil)
assert.Equal(t, codes.Internal, h.Status.Code())
assert.Equal(t, "internal error", h.Status.Message())
}

func TestEventHandler_OnResolveMethod(t *testing.T) {
h := NewEventHandler(io.Discard, nil)

// Test with nil method descriptor - should not panic
h.OnResolveMethod(nil)

// Since this is a no-op function, we just verify it doesn't panic
// and can be called multiple times
h.OnResolveMethod(nil)
h.OnResolveMethod(nil)
}

func TestEventHandler_OnSendHeaders(t *testing.T) {
h := NewEventHandler(io.Discard, nil)

// Test with nil metadata - should not panic
h.OnSendHeaders(nil)

// Test with valid metadata
md := metadata.New(map[string]string{"request-id": "123", "auth": "token"})
h.OnSendHeaders(md)

// Test with empty metadata
emptyMd := metadata.New(map[string]string{})
h.OnSendHeaders(emptyMd)
}

func TestEventHandler_OnReceiveHeaders2(t *testing.T) {
h := NewEventHandler(io.Discard, nil)

// Test with nil metadata - should not panic
h.OnReceiveHeaders(nil)
h.OnReceiveTrailers(status.New(codes.OK, ""), nil)

// Test with valid metadata
md := metadata.New(map[string]string{"response-id": "456", "content-type": "application/json"})
h.OnReceiveHeaders(md)

// Test with empty metadata
emptyMd := metadata.New(map[string]string{})
h.OnReceiveHeaders(emptyMd)
}

func TestEventHandler_CompleteWorkflow(t *testing.T) {
var buf bytes.Buffer
h := NewEventHandler(&buf, nil)

// Simulate a complete gRPC call workflow
h.OnResolveMethod(nil)
h.OnSendHeaders(metadata.New(map[string]string{"request-id": "123"}))
h.OnReceiveHeaders(metadata.New(map[string]string{"response-id": "456"}))

// Send a response
msg := &empty.Empty{}
h.OnReceiveResponse(msg)

// Complete with status
h.OnReceiveTrailers(status.New(codes.OK, "completed"), nil)

assert.Equal(t, codes.OK, h.Status.Code())
h.OnReceiveResponse(nil)
assert.Equal(t, "completed", h.Status.Message())
assert.Contains(t, buf.String(), "{}")
}

// badWriter is a mock writer that always returns an error
type badWriter struct{}

func (w *badWriter) Write([]byte) (int, error) {
return 0, io.ErrShortWrite
}

func TestEventHandler_OnReceiveTrailers(t *testing.T) {
Expand Down Expand Up @@ -65,12 +239,12 @@ func TestEventHandler_OnReceiveTrailers(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewEventHandler(tt.writer, nil)

h.OnReceiveTrailers(tt.status, tt.metadata)

// Check status is set correctly
assert.Equal(t, tt.expectedStatus, h.Status.Code())

// Check headers are set correctly if writer is http.ResponseWriter
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
if tt.expectedHeader != nil {
Expand Down Expand Up @@ -128,9 +302,9 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewEventHandler(tt.writer, nil)

h.OnReceiveHeaders(tt.metadata)

// Check headers are set correctly if writer is http.ResponseWriter
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
if tt.expectedHeader != nil {
Expand All @@ -147,17 +321,17 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) {
func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) {
recorder := httptest.NewRecorder()
h := NewEventHandler(recorder, nil)

// Test that multiple calls to OnReceiveHeaders accumulate headers
h.OnReceiveHeaders(metadata.MD{
"x-header-1": []string{"value1"},
})

h.OnReceiveHeaders(metadata.MD{
"x-header-1": []string{"value2"}, // Should add to existing header
"x-header-2": []string{"value3"},
})

// Check that headers are accumulated (not overwritten) with proper prefix
assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["Grpc-Metadata-X-Header-1"])
assert.Equal(t, []string{"value3"}, recorder.Header()["Grpc-Metadata-X-Header-2"])
Expand Down Expand Up @@ -203,8 +377,8 @@ func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) {
},
},
{
name: "empty metadata",
metadata: metadata.MD{},
name: "empty metadata",
metadata: metadata.MD{},
expectedHeader: map[string][]string{},
},
}
Expand All @@ -213,15 +387,15 @@ func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
h := NewEventHandler(recorder, nil)

h.OnReceiveHeaders(tt.metadata)

// Check that headers are set correctly
for key, expectedValues := range tt.expectedHeader {
actualValues := recorder.Header()[key]
assert.Equal(t, expectedValues, actualValues, "Header %s should match", key)
}

// Ensure no unexpected headers are set
for actualKey := range recorder.Header() {
found := false
Expand Down
2 changes: 1 addition & 1 deletion gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver json
}

w.Header().Set(httpx.ContentType, httpx.JsonContentType)
handler := internal.NewEventHandler(w, resolver)
handler := internal.NewEventHandlerWithContext(r.Context(), w, resolver, httpx.HasOkHandler())
if err := grpcurl.InvokeRPC(r.Context(), source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
handler, parser.Next); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
Expand Down
6 changes: 6 additions & 0 deletions rest/httpx/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ var (
okLock sync.RWMutex
)

func HasOkHandler() bool {
okLock.RLock()
defer okLock.RUnlock()
return okHandler != nil
}

// Error writes err into w.
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
Expand Down
Loading