Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(generic): fix codec to be updated even if there is an idl update #1666

Merged
19 changes: 9 additions & 10 deletions client/genericclient/generic_stream_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ func StreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
}

func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
readerWriter := g.MessageReaderWriter()
if readerWriter == nil {
if g.PayloadCodec() != nil {
// TODO: support grpc binary generic
panic("binary generic streaming is not supported")
}
Expand All @@ -37,12 +36,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
nil,
func() interface{} {
args := &generic.Args{}
args.SetCodec(readerWriter)
args.SetCodec(g.MessageReaderWriter())
return args
},
func() interface{} {
result := &generic.Result{}
result.SetCodec(readerWriter)
result.SetCodec(g.MessageReaderWriter())
return result
},
false,
Expand All @@ -52,12 +51,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
nil,
func() interface{} {
args := &generic.Args{}
args.SetCodec(readerWriter)
args.SetCodec(g.MessageReaderWriter())
return args
},
func() interface{} {
result := &generic.Result{}
result.SetCodec(readerWriter)
result.SetCodec(g.MessageReaderWriter())
return result
},
false,
Expand All @@ -67,12 +66,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
nil,
func() interface{} {
args := &generic.Args{}
args.SetCodec(readerWriter)
args.SetCodec(g.MessageReaderWriter())
return args
},
func() interface{} {
result := &generic.Result{}
result.SetCodec(readerWriter)
result.SetCodec(g.MessageReaderWriter())
return result
},
false,
Expand All @@ -82,12 +81,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
nil,
func() interface{} {
args := &generic.Args{}
args.SetCodec(readerWriter)
args.SetCodec(g.MessageReaderWriter())
return args
},
func() interface{} {
result := &generic.Result{}
result.SetCodec(readerWriter)
result.SetCodec(g.MessageReaderWriter())
return result
},
false,
Expand Down
8 changes: 5 additions & 3 deletions pkg/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func SetBinaryWithBase64(g Generic, enable bool) error {
c.codec.convOpts.NoBase64Binary = !enable
c.codec.convOptsWithThriftBase.NoBase64Binary = !enable
}
return c.codec.updateMessageReaderWriter()
case *jsonThriftGeneric:
if c.codec == nil {
return fmt.Errorf("empty codec for %#v", c)
Expand All @@ -152,15 +153,16 @@ func SetBinaryWithBase64(g Generic, enable bool) error {
c.codec.convOptsWithThriftBase.NoBase64Binary = !enable
c.codec.convOptsWithException.NoBase64Binary = !enable
}
return c.codec.updateMessageReaderWriter()
case *mapThriftGeneric:
if c.codec == nil {
return fmt.Errorf("empty codec for %#v", c)
}
c.codec.binaryWithBase64 = enable
return c.codec.updateMessageReaderWriter()
default:
return fmt.Errorf("Base64Binary is unavailable for %#v", g)
}
return nil
}

// SetBinaryWithByteSlice enable/disable returning []byte for binary field.
Expand All @@ -171,10 +173,10 @@ func SetBinaryWithByteSlice(g Generic, enable bool) error {
return fmt.Errorf("empty codec for %#v", c)
}
c.codec.binaryWithByteSlice = enable
return c.codec.updateMessageReaderWriter()
default:
return fmt.Errorf("returning []byte for binary fields is unavailable for %#v", g)
}
return nil
}

// SetFieldsForEmptyStructMode is a enum for EnableSetFieldsForEmptyStruct()
Expand Down Expand Up @@ -202,10 +204,10 @@ func EnableSetFieldsForEmptyStruct(g Generic, mode SetFieldsForEmptyStructMode)
return fmt.Errorf("empty codec for %#v", c)
}
c.codec.setFieldsForEmptyStruct = uint8(mode)
return c.codec.updateMessageReaderWriter()
default:
return fmt.Errorf("SetFieldsForEmptyStruct only supports map-generic at present")
}
return nil
}

var thriftCodec = thrift.NewThriftCodec()
Expand Down
40 changes: 35 additions & 5 deletions pkg/generic/generic_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Service interface {
// ServiceInfoWithGeneric create a generic ServiceInfo
func ServiceInfoWithGeneric(g Generic) *serviceinfo.ServiceInfo {
isCombinedServices := getIsCombinedServices(g)
return newServiceInfo(g.PayloadCodecType(), g.MessageReaderWriter(), g.IDLServiceName(), isCombinedServices)
return newServiceInfo(g, isCombinedServices)
}

func getIsCombinedServices(g Generic) bool {
Expand All @@ -44,16 +44,16 @@ func getIsCombinedServices(g Generic) bool {
return false
}

func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interface{}, serviceName string, isCombinedServices bool) *serviceinfo.ServiceInfo {
func newServiceInfo(g Generic, isCombinedServices bool) *serviceinfo.ServiceInfo {
handlerType := (*Service)(nil)

methods, svcName := GetMethodInfo(messageReaderWriter, serviceName)
methods, svcName := getMethodInfo(g, g.IDLServiceName())

svcInfo := &serviceinfo.ServiceInfo{
ServiceName: svcName,
HandlerType: handlerType,
Methods: methods,
PayloadCodec: pcType,
PayloadCodec: g.PayloadCodecType(),
Extra: make(map[string]interface{}),
}
svcInfo.Extra["generic"] = true
Expand All @@ -63,7 +63,37 @@ func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interfa
return svcInfo
}

// GetMethodInfo is only used in kitex, please DON'T USE IT. This method may be removed in the future
func getMethodInfo(g Generic, serviceName string) (methods map[string]serviceinfo.MethodInfo, svcName string) {
if g.PayloadCodec() != nil {
// note: binary generic cannot be used with multi-service feature
svcName = serviceinfo.GenericService
methods = map[string]serviceinfo.MethodInfo{
serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(callHandler, newGenericServiceCallArgs, newGenericServiceCallResult, false),
}
} else {
svcName = serviceName
methods = map[string]serviceinfo.MethodInfo{
serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(
callHandler,
func() interface{} {
args := &Args{}
args.SetCodec(g.MessageReaderWriter())
return args
},
func() interface{} {
result := &Result{}
result.SetCodec(g.MessageReaderWriter())
return result
},
false,
),
}
}
return
}

// GetMethodInfo is only used in kitex, please DON'T USE IT.
// DEPRECATED: this method is no longer used. This method will be removed in the future
func GetMethodInfo(messageReaderWriter interface{}, serviceName string) (methods map[string]serviceinfo.MethodInfo, svcName string) {
if messageReaderWriter == nil {
// note: binary generic cannot be used with multi-service feature
Expand Down
30 changes: 15 additions & 15 deletions pkg/generic/httppbthrift_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package generic
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
Expand All @@ -37,12 +38,13 @@ import (
var _ Closer = &httpPbThriftCodec{}

type httpPbThriftCodec struct {
svcDsc atomic.Value // *idl
pbSvcDsc atomic.Value // *pbIdl
provider DescriptorProvider
pbProvider PbDescriptorProvider
svcName string
extra map[string]string
svcDsc atomic.Value // *idl
pbSvcDsc atomic.Value // *pbIdl
provider DescriptorProvider
pbProvider PbDescriptorProvider
svcName string
extra map[string]string
readerWriter atomic.Value // *thrift.HTTPPbReaderWriter
}

func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpPbThriftCodec {
Expand All @@ -57,6 +59,7 @@ func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpP
c.setCombinedServices(svc.IsCombinedServices)
c.svcDsc.Store(svc)
c.pbSvcDsc.Store(pbSvc)
c.readerWriter.Store(thrift.NewHTTPPbReaderWriter(svc, pbSvc))
go c.update()
return c
}
Expand All @@ -77,6 +80,7 @@ func (c *httpPbThriftCodec) update() {
c.setCombinedServices(svc.IsCombinedServices)
c.svcDsc.Store(svc)
c.pbSvcDsc.Store(pbSvc)
c.readerWriter.Store(thrift.NewHTTPPbReaderWriter(svc, pbSvc))
}
}

Expand Down Expand Up @@ -105,16 +109,12 @@ func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) {
}

func (c *httpPbThriftCodec) getMessageReaderWriter() interface{} {
svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
if !ok {
return errors.New("get parser ServiceDescriptor failed")
}
pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor)
if !ok {
return errors.New("get parser PbServiceDescriptor failed")
v := c.readerWriter.Load()
if rw, ok := v.(*thrift.HTTPPbReaderWriter); !ok {
panic(fmt.Sprintf("get readerWriter failed: expected *thrift.HTTPPbReaderWriter, got %T", v))
} else {
return rw
}

return thrift.NewHTTPPbReaderWriter(svcDsc, pbSvcDsc)
}

func (c *httpPbThriftCodec) Name() string {
Expand Down
4 changes: 1 addition & 3 deletions pkg/generic/httppbthrift_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ func TestHTTPPbThriftCodec(t *testing.T) {
test.Assert(t, htc.extra[CombineServiceKey] == "false")

rw := htc.getMessageReaderWriter()
_, ok := rw.(thrift.MessageWriter)
test.Assert(t, ok)
_, ok = rw.(thrift.MessageReader)
_, ok := rw.(*thrift.HTTPPbReaderWriter)
test.Assert(t, ok)
}
32 changes: 25 additions & 7 deletions pkg/generic/httpthrift_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package generic
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"sync/atomic"
Expand Down Expand Up @@ -50,6 +51,7 @@ type httpThriftCodec struct {
useRawBodyForHTTPResp bool
svcName string
extra map[string]string
readerWriter atomic.Value // *thrift.HTTPReaderWriter
}

func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
Expand All @@ -73,6 +75,7 @@ func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
}
c.setCombinedServices(svc.IsCombinedServices)
c.svcDsc.Store(svc)
c.configureMessageReaderWriter(svc)
go c.update()
return c
}
Expand All @@ -86,9 +89,26 @@ func (c *httpThriftCodec) update() {
c.svcName = svc.Name
c.setCombinedServices(svc.IsCombinedServices)
c.svcDsc.Store(svc)
c.configureMessageReaderWriter(svc)
}
}

func (c *httpThriftCodec) updateMessageReaderWriter() (err error) {
svc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
if !ok {
return errors.New("get parser ServiceDescriptor failed")
}
c.configureMessageReaderWriter(svc)
return nil
}

func (c *httpThriftCodec) configureMessageReaderWriter(svc *descriptor.ServiceDescriptor) {
rw := thrift.NewHTTPReaderWriter(svc)
c.configureHTTPRequestWriter(rw.WriteHTTPRequest)
c.configureHTTPResponseReader(rw.ReadHTTPResponse)
c.readerWriter.Store(rw)
}

func (c *httpThriftCodec) setCombinedServices(isCombinedServices bool) {
if isCombinedServices {
c.extra[CombineServiceKey] = "true"
Expand All @@ -98,14 +118,12 @@ func (c *httpThriftCodec) setCombinedServices(isCombinedServices bool) {
}

func (c *httpThriftCodec) getMessageReaderWriter() interface{} {
svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
if !ok {
return errors.New("get parser ServiceDescriptor failed")
v := c.readerWriter.Load()
if rw, ok := v.(*thrift.HTTPReaderWriter); !ok {
panic(fmt.Sprintf("get readerWriter failed: expected *thrift.HTTPReaderWriter, got %T", v))
} else {
return rw
}
rw := thrift.NewHTTPReaderWriter(svcDsc)
c.configureHTTPRequestWriter(rw.WriteHTTPRequest)
c.configureHTTPResponseReader(rw.ReadHTTPResponse)
return rw
}

func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPRequest) {
Expand Down
8 changes: 2 additions & 6 deletions pkg/generic/httpthrift_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ func TestHttpThriftCodec(t *testing.T) {
test.Assert(t, !ok)

rw = htc.getMessageReaderWriter()
_, ok = rw.(thrift.MessageWriter)
test.Assert(t, ok)
_, ok = rw.(thrift.MessageReader)
_, ok = rw.(*thrift.HTTPReaderWriter)
test.Assert(t, ok)
}

Expand Down Expand Up @@ -105,9 +103,7 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) {
test.Assert(t, htc.extra[CombineServiceKey] == "false")

rw := htc.getMessageReaderWriter()
_, ok := rw.(thrift.MessageWriter)
test.Assert(t, ok)
_, ok = rw.(thrift.MessageReader)
_, ok := rw.(*thrift.HTTPReaderWriter)
test.Assert(t, ok)
}

Expand Down
Loading