Skip to content

Commit 3d7d14e

Browse files
Marina-SakaiHeyJavaBean
authored andcommitted
fix(generic): fix codec to be updated even if there is an idl update (cloudwego#1666)
1 parent 3ecb2ae commit 3d7d14e

15 files changed

+167
-92
lines changed

client/genericclient/generic_stream_service.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ func StreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
2626
}
2727

2828
func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
29-
readerWriter := g.MessageReaderWriter()
30-
if readerWriter == nil {
29+
if g.PayloadCodec() != nil {
3130
// TODO: support grpc binary generic
3231
panic("binary generic streaming is not supported")
3332
}
@@ -37,12 +36,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
3736
nil,
3837
func() interface{} {
3938
args := &generic.Args{}
40-
args.SetCodec(readerWriter)
39+
args.SetCodec(g.MessageReaderWriter())
4140
return args
4241
},
4342
func() interface{} {
4443
result := &generic.Result{}
45-
result.SetCodec(readerWriter)
44+
result.SetCodec(g.MessageReaderWriter())
4645
return result
4746
},
4847
false,
@@ -52,12 +51,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
5251
nil,
5352
func() interface{} {
5453
args := &generic.Args{}
55-
args.SetCodec(readerWriter)
54+
args.SetCodec(g.MessageReaderWriter())
5655
return args
5756
},
5857
func() interface{} {
5958
result := &generic.Result{}
60-
result.SetCodec(readerWriter)
59+
result.SetCodec(g.MessageReaderWriter())
6160
return result
6261
},
6362
false,
@@ -67,12 +66,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
6766
nil,
6867
func() interface{} {
6968
args := &generic.Args{}
70-
args.SetCodec(readerWriter)
69+
args.SetCodec(g.MessageReaderWriter())
7170
return args
7271
},
7372
func() interface{} {
7473
result := &generic.Result{}
75-
result.SetCodec(readerWriter)
74+
result.SetCodec(g.MessageReaderWriter())
7675
return result
7776
},
7877
false,
@@ -82,12 +81,12 @@ func newClientStreamingServiceInfo(g generic.Generic) *serviceinfo.ServiceInfo {
8281
nil,
8382
func() interface{} {
8483
args := &generic.Args{}
85-
args.SetCodec(readerWriter)
84+
args.SetCodec(g.MessageReaderWriter())
8685
return args
8786
},
8887
func() interface{} {
8988
result := &generic.Result{}
90-
result.SetCodec(readerWriter)
89+
result.SetCodec(g.MessageReaderWriter())
9190
return result
9291
},
9392
false,

pkg/generic/generic.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ func SetBinaryWithBase64(g Generic, enable bool) error {
142142
c.codec.convOpts.NoBase64Binary = !enable
143143
c.codec.convOptsWithThriftBase.NoBase64Binary = !enable
144144
}
145+
return c.codec.updateMessageReaderWriter()
145146
case *jsonThriftGeneric:
146147
if c.codec == nil {
147148
return fmt.Errorf("empty codec for %#v", c)
@@ -152,15 +153,16 @@ func SetBinaryWithBase64(g Generic, enable bool) error {
152153
c.codec.convOptsWithThriftBase.NoBase64Binary = !enable
153154
c.codec.convOptsWithException.NoBase64Binary = !enable
154155
}
156+
return c.codec.updateMessageReaderWriter()
155157
case *mapThriftGeneric:
156158
if c.codec == nil {
157159
return fmt.Errorf("empty codec for %#v", c)
158160
}
159161
c.codec.binaryWithBase64 = enable
162+
return c.codec.updateMessageReaderWriter()
160163
default:
161164
return fmt.Errorf("Base64Binary is unavailable for %#v", g)
162165
}
163-
return nil
164166
}
165167

166168
// SetBinaryWithByteSlice enable/disable returning []byte for binary field.
@@ -171,10 +173,10 @@ func SetBinaryWithByteSlice(g Generic, enable bool) error {
171173
return fmt.Errorf("empty codec for %#v", c)
172174
}
173175
c.codec.binaryWithByteSlice = enable
176+
return c.codec.updateMessageReaderWriter()
174177
default:
175178
return fmt.Errorf("returning []byte for binary fields is unavailable for %#v", g)
176179
}
177-
return nil
178180
}
179181

180182
// SetFieldsForEmptyStructMode is a enum for EnableSetFieldsForEmptyStruct()
@@ -202,10 +204,10 @@ func EnableSetFieldsForEmptyStruct(g Generic, mode SetFieldsForEmptyStructMode)
202204
return fmt.Errorf("empty codec for %#v", c)
203205
}
204206
c.codec.setFieldsForEmptyStruct = uint8(mode)
207+
return c.codec.updateMessageReaderWriter()
205208
default:
206209
return fmt.Errorf("SetFieldsForEmptyStruct only supports map-generic at present")
207210
}
208-
return nil
209211
}
210212

211213
var thriftCodec = thrift.NewThriftCodec()

pkg/generic/generic_service.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Service interface {
3232
// ServiceInfoWithGeneric create a generic ServiceInfo
3333
func ServiceInfoWithGeneric(g Generic) *serviceinfo.ServiceInfo {
3434
isCombinedServices := getIsCombinedServices(g)
35-
return newServiceInfo(g.PayloadCodecType(), g.MessageReaderWriter(), g.IDLServiceName(), isCombinedServices)
35+
return newServiceInfo(g, isCombinedServices)
3636
}
3737

3838
func getIsCombinedServices(g Generic) bool {
@@ -44,16 +44,16 @@ func getIsCombinedServices(g Generic) bool {
4444
return false
4545
}
4646

47-
func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interface{}, serviceName string, isCombinedServices bool) *serviceinfo.ServiceInfo {
47+
func newServiceInfo(g Generic, isCombinedServices bool) *serviceinfo.ServiceInfo {
4848
handlerType := (*Service)(nil)
4949

50-
methods, svcName := GetMethodInfo(messageReaderWriter, serviceName)
50+
methods, svcName := getMethodInfo(g, g.IDLServiceName())
5151

5252
svcInfo := &serviceinfo.ServiceInfo{
5353
ServiceName: svcName,
5454
HandlerType: handlerType,
5555
Methods: methods,
56-
PayloadCodec: pcType,
56+
PayloadCodec: g.PayloadCodecType(),
5757
Extra: make(map[string]interface{}),
5858
}
5959
svcInfo.Extra["generic"] = true
@@ -63,7 +63,37 @@ func newServiceInfo(pcType serviceinfo.PayloadCodec, messageReaderWriter interfa
6363
return svcInfo
6464
}
6565

66-
// GetMethodInfo is only used in kitex, please DON'T USE IT. This method may be removed in the future
66+
func getMethodInfo(g Generic, serviceName string) (methods map[string]serviceinfo.MethodInfo, svcName string) {
67+
if g.PayloadCodec() != nil {
68+
// note: binary generic cannot be used with multi-service feature
69+
svcName = serviceinfo.GenericService
70+
methods = map[string]serviceinfo.MethodInfo{
71+
serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(callHandler, newGenericServiceCallArgs, newGenericServiceCallResult, false),
72+
}
73+
} else {
74+
svcName = serviceName
75+
methods = map[string]serviceinfo.MethodInfo{
76+
serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(
77+
callHandler,
78+
func() interface{} {
79+
args := &Args{}
80+
args.SetCodec(g.MessageReaderWriter())
81+
return args
82+
},
83+
func() interface{} {
84+
result := &Result{}
85+
result.SetCodec(g.MessageReaderWriter())
86+
return result
87+
},
88+
false,
89+
),
90+
}
91+
}
92+
return
93+
}
94+
95+
// GetMethodInfo is only used in kitex, please DON'T USE IT.
96+
// DEPRECATED: this method is no longer used. This method will be removed in the future
6797
func GetMethodInfo(messageReaderWriter interface{}, serviceName string) (methods map[string]serviceinfo.MethodInfo, svcName string) {
6898
if messageReaderWriter == nil {
6999
// note: binary generic cannot be used with multi-service feature

pkg/generic/httppbthrift_codec.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package generic
1919
import (
2020
"context"
2121
"errors"
22+
"fmt"
2223
"io"
2324
"net/http"
2425
"strings"
@@ -37,12 +38,13 @@ import (
3738
var _ Closer = &httpPbThriftCodec{}
3839

3940
type httpPbThriftCodec struct {
40-
svcDsc atomic.Value // *idl
41-
pbSvcDsc atomic.Value // *pbIdl
42-
provider DescriptorProvider
43-
pbProvider PbDescriptorProvider
44-
svcName string
45-
extra map[string]string
41+
svcDsc atomic.Value // *idl
42+
pbSvcDsc atomic.Value // *pbIdl
43+
provider DescriptorProvider
44+
pbProvider PbDescriptorProvider
45+
svcName string
46+
extra map[string]string
47+
readerWriter atomic.Value // *thrift.HTTPPbReaderWriter
4648
}
4749

4850
func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpPbThriftCodec {
@@ -57,6 +59,7 @@ func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpP
5759
c.setCombinedServices(svc.IsCombinedServices)
5860
c.svcDsc.Store(svc)
5961
c.pbSvcDsc.Store(pbSvc)
62+
c.readerWriter.Store(thrift.NewHTTPPbReaderWriter(svc, pbSvc))
6063
go c.update()
6164
return c
6265
}
@@ -77,6 +80,7 @@ func (c *httpPbThriftCodec) update() {
7780
c.setCombinedServices(svc.IsCombinedServices)
7881
c.svcDsc.Store(svc)
7982
c.pbSvcDsc.Store(pbSvc)
83+
c.readerWriter.Store(thrift.NewHTTPPbReaderWriter(svc, pbSvc))
8084
}
8185
}
8286

@@ -105,16 +109,12 @@ func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) {
105109
}
106110

107111
func (c *httpPbThriftCodec) getMessageReaderWriter() interface{} {
108-
svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
109-
if !ok {
110-
return errors.New("get parser ServiceDescriptor failed")
111-
}
112-
pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor)
113-
if !ok {
114-
return errors.New("get parser PbServiceDescriptor failed")
112+
v := c.readerWriter.Load()
113+
if rw, ok := v.(*thrift.HTTPPbReaderWriter); !ok {
114+
panic(fmt.Sprintf("get readerWriter failed: expected *thrift.HTTPPbReaderWriter, got %T", v))
115+
} else {
116+
return rw
115117
}
116-
117-
return thrift.NewHTTPPbReaderWriter(svcDsc, pbSvcDsc)
118118
}
119119

120120
func (c *httpPbThriftCodec) Name() string {

pkg/generic/httppbthrift_codec_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ func TestHTTPPbThriftCodec(t *testing.T) {
7171
test.Assert(t, htc.extra[CombineServiceKey] == "false")
7272

7373
rw := htc.getMessageReaderWriter()
74-
_, ok := rw.(thrift.MessageWriter)
75-
test.Assert(t, ok)
76-
_, ok = rw.(thrift.MessageReader)
74+
_, ok := rw.(*thrift.HTTPPbReaderWriter)
7775
test.Assert(t, ok)
7876
}

pkg/generic/httpthrift_codec.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package generic
1919
import (
2020
"context"
2121
"errors"
22+
"fmt"
2223
"io"
2324
"net/http"
2425
"sync/atomic"
@@ -50,6 +51,7 @@ type httpThriftCodec struct {
5051
useRawBodyForHTTPResp bool
5152
svcName string
5253
extra map[string]string
54+
readerWriter atomic.Value // *thrift.HTTPReaderWriter
5355
}
5456

5557
func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
@@ -73,6 +75,7 @@ func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
7375
}
7476
c.setCombinedServices(svc.IsCombinedServices)
7577
c.svcDsc.Store(svc)
78+
c.configureMessageReaderWriter(svc)
7679
go c.update()
7780
return c
7881
}
@@ -86,9 +89,26 @@ func (c *httpThriftCodec) update() {
8689
c.svcName = svc.Name
8790
c.setCombinedServices(svc.IsCombinedServices)
8891
c.svcDsc.Store(svc)
92+
c.configureMessageReaderWriter(svc)
8993
}
9094
}
9195

96+
func (c *httpThriftCodec) updateMessageReaderWriter() (err error) {
97+
svc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
98+
if !ok {
99+
return errors.New("get parser ServiceDescriptor failed")
100+
}
101+
c.configureMessageReaderWriter(svc)
102+
return nil
103+
}
104+
105+
func (c *httpThriftCodec) configureMessageReaderWriter(svc *descriptor.ServiceDescriptor) {
106+
rw := thrift.NewHTTPReaderWriter(svc)
107+
c.configureHTTPRequestWriter(rw.WriteHTTPRequest)
108+
c.configureHTTPResponseReader(rw.ReadHTTPResponse)
109+
c.readerWriter.Store(rw)
110+
}
111+
92112
func (c *httpThriftCodec) setCombinedServices(isCombinedServices bool) {
93113
if isCombinedServices {
94114
c.extra[CombineServiceKey] = "true"
@@ -98,14 +118,12 @@ func (c *httpThriftCodec) setCombinedServices(isCombinedServices bool) {
98118
}
99119

100120
func (c *httpThriftCodec) getMessageReaderWriter() interface{} {
101-
svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
102-
if !ok {
103-
return errors.New("get parser ServiceDescriptor failed")
121+
v := c.readerWriter.Load()
122+
if rw, ok := v.(*thrift.HTTPReaderWriter); !ok {
123+
panic(fmt.Sprintf("get readerWriter failed: expected *thrift.HTTPReaderWriter, got %T", v))
124+
} else {
125+
return rw
104126
}
105-
rw := thrift.NewHTTPReaderWriter(svcDsc)
106-
c.configureHTTPRequestWriter(rw.WriteHTTPRequest)
107-
c.configureHTTPResponseReader(rw.ReadHTTPResponse)
108-
return rw
109127
}
110128

111129
func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPRequest) {

pkg/generic/httpthrift_codec_test.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ func TestHttpThriftCodec(t *testing.T) {
7272
test.Assert(t, !ok)
7373

7474
rw = htc.getMessageReaderWriter()
75-
_, ok = rw.(thrift.MessageWriter)
76-
test.Assert(t, ok)
77-
_, ok = rw.(thrift.MessageReader)
75+
_, ok = rw.(*thrift.HTTPReaderWriter)
7876
test.Assert(t, ok)
7977
}
8078

@@ -105,9 +103,7 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) {
105103
test.Assert(t, htc.extra[CombineServiceKey] == "false")
106104

107105
rw := htc.getMessageReaderWriter()
108-
_, ok := rw.(thrift.MessageWriter)
109-
test.Assert(t, ok)
110-
_, ok = rw.(thrift.MessageReader)
106+
_, ok := rw.(*thrift.HTTPReaderWriter)
111107
test.Assert(t, ok)
112108
}
113109

0 commit comments

Comments
 (0)