Skip to content

Commit 963eba0

Browse files
committed
feat(ttstream): implement whole stream timeout
1 parent c0f8a10 commit 963eba0

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

pkg/remote/trans/ttstream/client_provider.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package ttstream
1818

1919
import (
2020
"context"
21+
"strconv"
22+
"time"
2123

2224
"github.com/bytedance/gopkg/cloud/metainfo"
2325
"github.com/cloudwego/gopkg/protocol/ttheader"
@@ -88,6 +90,12 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre
8890
if strHeader == nil {
8991
strHeader = map[string]string{}
9092
}
93+
// retrieve deadline fron context as the whole stream timeout
94+
if ddl, ok := ctx.Deadline(); ok {
95+
tm := time.Until(ddl)
96+
intHeader[ttheader.RPCTimeout] = strconv.Itoa(int(tm.Milliseconds()))
97+
}
98+
9199
strHeader[ttheader.HeaderIDLServiceName] = invocation.ServiceName()
92100
metainfo.SaveMetaInfoToMap(ctx, strHeader)
93101

pkg/remote/trans/ttstream/server_provider.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ import (
2121
"fmt"
2222
"net"
2323
"strconv"
24+
"time"
2425

2526
"github.com/bytedance/gopkg/cloud/metainfo"
2627
"github.com/cloudwego/gopkg/protocol/thrift"
2728
"github.com/cloudwego/gopkg/protocol/ttheader"
2829
"github.com/cloudwego/netpoll"
2930

3031
"github.com/cloudwego/kitex/pkg/kerrors"
32+
"github.com/cloudwego/kitex/pkg/klog"
3133
"github.com/cloudwego/kitex/pkg/remote"
3234
"github.com/cloudwego/kitex/pkg/remote/trans/ttstream/ktx"
3335
"github.com/cloudwego/kitex/pkg/streaming"
@@ -113,6 +115,18 @@ func (s *serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.C
113115
// cancel ctx when OnStreamFinish
114116
ctx, cancelFunc := ktx.WithCancel(ctx)
115117
ctx = context.WithValue(ctx, serverStreamCancelCtxKey{}, cancelFunc)
118+
// process whole stream timeout
119+
var cancel context.CancelFunc
120+
if tmStr, ok := st.meta[ttheader.RPCTimeout]; ok {
121+
tm, err := strconv.Atoi(tmStr)
122+
if err == nil {
123+
ctx, cancel = context.WithTimeout(ctx, time.Duration(tm)*time.Millisecond)
124+
st.cancelFunc = cancel
125+
} else {
126+
klog.CtxErrorf(ctx, "ttstream decode RPCTimeout failed, err: %v", err)
127+
}
128+
}
129+
116130
return ctx, ss, nil
117131
}
118132

pkg/remote/trans/ttstream/stream.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ type stream struct {
102102
recvTimeout time.Duration
103103
metaFrameHandler MetaFrameHandler
104104
closeCallback []func(error)
105+
cancelFunc context.CancelFunc // only valid in server stream
105106
}
106107

107108
func (s *stream) Service() string {
@@ -126,8 +127,7 @@ func (s *stream) TransportProtocol() ktransport.Protocol {
126127
func (s *stream) SendMsg(ctx context.Context, msg any) (err error) {
127128
if state := atomic.LoadInt32(&s.state); state == streamStateHalfCloseLocal || state == streamStateInactive {
128129
if ex := s.clientStreamException.Load(); ex == nil {
129-
// 这个错误设计是有问题的,不能这么玩
130-
// 而且这里可以预定义,不需要每次都生成
130+
// todo: predefine error here
131131
return errIllegalOperation.WithCause(errors.New("stream is closed send"))
132132
} else {
133133
return ex.(error)
@@ -155,9 +155,11 @@ func (s *stream) RecvMsg(ctx context.Context, data any) error {
155155
ctx, cancel = context.WithTimeout(ctx, s.recvTimeout)
156156
defer cancel()
157157
}
158-
// 在这个环境检测 ctx 是否被 cancel
158+
// todo: format the error returned by output
159+
// like gRPC ContextError
159160
payload, err := s.reader.output(ctx)
160161
if err != nil {
162+
s.close(err, true, clientTransport)
161163
return err
162164
}
163165
err = DecodePayload(context.Background(), payload, data)
@@ -195,6 +197,10 @@ func (s *stream) close(exception error, rst bool, kind int32) error {
195197
// stream has been closed
196198
return nil
197199
}
200+
// todo: think about the cancel logic location
201+
if s.cancelFunc != nil {
202+
s.cancelFunc()
203+
}
198204
select {
199205
case s.headerSig <- streamSigInactive:
200206
default:
@@ -204,7 +210,6 @@ func (s *stream) close(exception error, rst bool, kind int32) error {
204210
default:
205211
}
206212
s.reader.close(exception)
207-
// 是否需要区分 kind?是需要判断的,因为 server stream 不需要这个东西
208213
if kind == clientTransport && exception != nil {
209214
s.clientStreamException.Store(exception)
210215
}
@@ -310,7 +315,6 @@ func (s *stream) sendTrailer(exception error) (err error) {
310315
return s.writeFrame(trailerFrameType, nil, wtrailer, payload)
311316
}
312317

313-
// todo: 处理写入错误,需要感知到连接断开这个现象
314318
func (s *stream) sendRstFrame(exception error) (err error) {
315319
klog.Debugf("stream[%d] send rst frame: err=%v", s.sid, exception)
316320

0 commit comments

Comments
 (0)