Skip to content

Commit 2097002

Browse files
authored
fix: Improve server stop handling with graceful shutdowns (go-kratos#3525)
* Improve server stop handling with graceful shutdowns - Remove default stop timeout - Add context handling for server stop - Implement graceful stop for gRPC server - Enhance HTTP server shutdown logic - Use shared context for server operations * Remove unnecessary error logging in gRPC server test
1 parent 54f8e11 commit 2097002

File tree

6 files changed

+297
-16
lines changed

6 files changed

+297
-16
lines changed

app.go

+8-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ func New(opts ...Option) *App {
4141
ctx: context.Background(),
4242
sigs: []os.Signal{syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGINT},
4343
registrarTimeout: 10 * time.Second,
44-
stopTimeout: 10 * time.Second,
4544
}
4645
if id, err := uuid.NewUUID(); err == nil {
4746
o.id = id.String()
@@ -98,18 +97,23 @@ func (a *App) Run() error {
9897
return err
9998
}
10099
}
100+
octx := NewContext(a.opts.ctx, a)
101101
for _, srv := range a.opts.servers {
102102
server := srv
103103
eg.Go(func() error {
104104
<-ctx.Done() // wait for stop signal
105-
stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a), a.opts.stopTimeout)
106-
defer cancel()
105+
stopCtx := octx
106+
if a.opts.stopTimeout > 0 {
107+
var cancel context.CancelFunc
108+
stopCtx, cancel = context.WithTimeout(stopCtx, a.opts.stopTimeout)
109+
defer cancel()
110+
}
107111
return server.Stop(stopCtx)
108112
})
109113
wg.Add(1)
110114
eg.Go(func() error {
111115
wg.Done() // here is to ensure server start has begun running before register, so defer is not needed
112-
return server.Start(NewContext(a.opts.ctx, a))
116+
return server.Start(octx)
113117
})
114118
}
115119
wg.Wait()

transport/grpc/server.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,25 @@ func (s *Server) Start(ctx context.Context) error {
223223
}
224224

225225
// Stop stop the gRPC server.
226-
func (s *Server) Stop(_ context.Context) error {
226+
func (s *Server) Stop(ctx context.Context) error {
227227
if s.adminClean != nil {
228228
s.adminClean()
229229
}
230230
s.health.Shutdown()
231-
s.GracefulStop()
232-
log.Info("[gRPC] server stopping")
231+
232+
done := make(chan struct{})
233+
go func() {
234+
defer close(done)
235+
log.Info("[gRPC] server stopping")
236+
s.Server.GracefulStop()
237+
}()
238+
239+
select {
240+
case <-done:
241+
case <-ctx.Done():
242+
log.Warn("[gRPC] server couldn't stop gracefully in time, doing force stop")
243+
s.Server.Stop()
244+
}
233245
return nil
234246
}
235247

transport/grpc/server_test.go

+127
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package grpc
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/tls"
67
"fmt"
78
"net"
89
"net/url"
910
"reflect"
1011
"strings"
12+
"sync"
1113
"testing"
1214
"time"
1315

@@ -17,6 +19,7 @@ import (
1719
"github.com/go-kratos/kratos/v2/errors"
1820
"github.com/go-kratos/kratos/v2/internal/matcher"
1921
pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld"
22+
"github.com/go-kratos/kratos/v2/log"
2023
"github.com/go-kratos/kratos/v2/middleware"
2124
"github.com/go-kratos/kratos/v2/transport"
2225
)
@@ -371,3 +374,127 @@ func TestListener(t *testing.T) {
371374
t.Errorf("expect not empty")
372375
}
373376
}
377+
378+
func TestStop(t *testing.T) {
379+
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
380+
defer cancel()
381+
382+
tests := []struct {
383+
name string
384+
ctx context.Context
385+
cancel context.CancelFunc
386+
wantForceStop bool
387+
}{
388+
{
389+
name: "normal",
390+
ctx: context.Background(),
391+
cancel: func() {},
392+
wantForceStop: false,
393+
},
394+
{
395+
name: "timeout",
396+
ctx: timeoutCtx,
397+
cancel: cancel,
398+
wantForceStop: true,
399+
},
400+
}
401+
402+
for _, tt := range tests {
403+
t.Run(tt.name, func(t *testing.T) {
404+
l, err := net.Listen("tcp", ":0")
405+
if err != nil {
406+
t.Fatal(err)
407+
}
408+
defer l.Close()
409+
410+
old := log.GetLogger()
411+
defer log.SetLogger(old)
412+
413+
// Create a logger to capture logs
414+
var logs safeBytesBuffer
415+
log.SetLogger(log.NewStdLogger(&logs))
416+
417+
s := NewServer(Listener(l))
418+
pb.RegisterGreeterServer(s, &server{})
419+
420+
go func() {
421+
err := s.Start(context.Background()) //nolint
422+
if err != nil {
423+
log.Fatal(err)
424+
}
425+
}()
426+
427+
time.Sleep(100 * time.Millisecond)
428+
429+
conn, err := DialInsecure(
430+
context.Background(),
431+
WithEndpoint(l.Addr().String()),
432+
WithOptions(grpc.WithBlock()),
433+
)
434+
if err != nil {
435+
t.Fatal(err)
436+
}
437+
defer conn.Close()
438+
439+
go func() {
440+
client := pb.NewGreeterClient(conn)
441+
if tt.wantForceStop {
442+
// Simulate a long-running request
443+
s, err := client.SayHelloStream(context.Background()) //nolint
444+
if err != nil {
445+
log.Fatal(err)
446+
}
447+
// Keep the stream open
448+
for {
449+
// Intentionally do not send messages, only receive messages
450+
_, err := s.Recv()
451+
if err != nil {
452+
break
453+
}
454+
}
455+
} else {
456+
_, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "test"}) //nolint
457+
if err != nil {
458+
log.Error(err)
459+
}
460+
}
461+
}()
462+
463+
time.Sleep(100 * time.Millisecond)
464+
465+
err = s.Stop(tt.ctx)
466+
if err != nil {
467+
t.Errorf("Expected no error, got %v", err)
468+
return
469+
}
470+
471+
// Check if the stop was forced or graceful
472+
if tt.wantForceStop {
473+
if !strings.Contains(logs.String(), "force stop") {
474+
t.Errorf("Expected force stop\n%s", logs.String())
475+
}
476+
} else {
477+
if strings.Contains(logs.String(), "force stop") {
478+
t.Errorf("Expected graceful stop\n%s", logs.String())
479+
}
480+
}
481+
})
482+
}
483+
}
484+
485+
type safeBytesBuffer struct {
486+
mu sync.Mutex
487+
buf bytes.Buffer
488+
}
489+
490+
func (b *safeBytesBuffer) Write(p []byte) (n int, err error) {
491+
b.mu.Lock()
492+
defer b.mu.Unlock()
493+
return b.buf.Write(p)
494+
}
495+
496+
func (b *safeBytesBuffer) String() string {
497+
b.mu.Lock()
498+
defer b.mu.Unlock()
499+
return b.buf.String()
500+
}

transport/http/resolver_test.go

+36-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"reflect"
88
"strconv"
9+
"sync"
910
"testing"
1011
"time"
1112

@@ -86,6 +87,8 @@ type mockWatch struct {
8687

8788
nextErr bool
8889
stopErr bool
90+
91+
lock sync.Mutex
8992
}
9093

9194
func (m *mockWatch) Next() ([]*registry.ServiceInstance, error) {
@@ -94,6 +97,8 @@ func (m *mockWatch) Next() ([]*registry.ServiceInstance, error) {
9497
return nil, m.ctx.Err()
9598
default:
9699
}
100+
m.lock.Lock()
101+
defer m.lock.Unlock()
97102
if m.nextErr {
98103
return nil, errors.New("mock test error")
99104
}
@@ -115,6 +120,8 @@ func (m *mockWatch) Next() ([]*registry.ServiceInstance, error) {
115120
}
116121

117122
func (m *mockWatch) Stop() error {
123+
m.lock.Lock()
124+
defer m.lock.Unlock()
118125
if m.stopErr {
119126
return errors.New("mock test error")
120127
}
@@ -130,46 +137,67 @@ func TestResolver(t *testing.T) {
130137
return
131138
}
132139

140+
cancelCtx, cancel := context.WithCancel(context.Background())
141+
defer cancel()
142+
133143
// 异步 无需报错
134-
_, err = newResolver(context.Background(), &mockDiscoveries{true, false, false}, ta, &mockRebalancer{}, false, false, 25)
144+
r, err := newResolver(cancelCtx, &mockDiscoveries{true, false, false}, ta, &mockRebalancer{}, false, false, 25)
135145
if err != nil {
136146
t.Errorf("expect %v, got %v", nil, err)
137147
}
148+
if r != nil {
149+
_ = r.Close()
150+
}
138151

139152
// 同步 一切正常运行
140-
_, err = newResolver(context.Background(), &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, true, true, 25)
153+
r, err = newResolver(cancelCtx, &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, true, true, 25)
141154
if err != nil {
142155
t.Errorf("expect %v, got %v", nil, err)
143156
}
157+
if r != nil {
158+
_ = r.Close()
159+
}
144160

145161
// 同步 但是 next 出错 以及 stop 出错
146-
_, err = newResolver(context.Background(), &mockDiscoveries{false, true, true}, ta, &mockRebalancer{}, true, true, 25)
162+
r, err = newResolver(cancelCtx, &mockDiscoveries{false, true, true}, ta, &mockRebalancer{}, true, true, 25)
147163
if err == nil {
148164
t.Errorf("expect err, got nil")
149165
}
166+
if r != nil {
167+
_ = r.Close()
168+
}
150169

151170
// 同步 service name watch 失败
152-
_, err = newResolver(context.Background(), &mockDiscoveries{false, true, true}, &Target{
171+
r, err = newResolver(cancelCtx, &mockDiscoveries{false, true, true}, &Target{
153172
Scheme: "discovery",
154173
Endpoint: errServiceName,
155174
}, &mockRebalancer{}, true, true, 25)
156175
if err == nil {
157176
t.Errorf("expect err, got nil")
158177
}
178+
if r != nil {
179+
_ = r.Close()
180+
}
159181

160-
cancelCtx, cancel := context.WithCancel(context.Background())
161182
cancel()
162183

163184
// 此处应该打印出来 context.Canceled
164-
r, err := newResolver(cancelCtx, &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, false, false, 25)
185+
r, err = newResolver(cancelCtx, &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, false, false, 25)
165186
if err != nil {
166187
t.Errorf("expect %v, got %v", nil, err)
167188
}
168-
_ = r.Close()
189+
if r != nil {
190+
_ = r.Close()
191+
}
169192

170193
// 同步 但是服务取消,此时需要报错
171-
_, err = newResolver(cancelCtx, &mockDiscoveries{false, false, true}, ta, &mockRebalancer{}, true, true, 25)
194+
r, err = newResolver(cancelCtx, &mockDiscoveries{false, false, true}, ta, &mockRebalancer{}, true, true, 25)
172195
if err == nil {
173196
t.Errorf("expect ctx cancel err, got nil")
174197
}
198+
if r != nil {
199+
_ = r.Close()
200+
}
201+
202+
time.Sleep(100 * time.Millisecond)
175203
}

transport/http/server.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,14 @@ func (s *Server) Start(ctx context.Context) error {
343343
// Stop stop the HTTP server.
344344
func (s *Server) Stop(ctx context.Context) error {
345345
log.Info("[HTTP] server stopping")
346-
return s.Shutdown(ctx)
346+
err := s.Shutdown(ctx)
347+
if err != nil {
348+
if ctx.Err() != nil {
349+
log.Warn("[HTTP] server couldn't stop gracefully in time, doing force stop")
350+
err = s.Server.Close()
351+
}
352+
}
353+
return err
347354
}
348355

349356
func (s *Server) listenAndEndpoint() error {

0 commit comments

Comments
 (0)