Skip to content

Commit 19aa18b

Browse files
committed
kite: cancel Context when client disconnects
1 parent 827d5a3 commit 19aa18b

File tree

4 files changed

+106
-10
lines changed

4 files changed

+106
-10
lines changed

client.go

+29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package kite
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
@@ -126,6 +127,11 @@ type Client struct {
126127
session sockjs.Session
127128
send chan *message
128129

130+
// ctx and cancel keeps track of session lifetime
131+
ctxMu sync.Mutex
132+
ctx context.Context
133+
cancel func()
134+
129135
// muReconnect protects Reconnect
130136
muReconnect sync.Mutex
131137

@@ -206,8 +212,13 @@ func (k *Kite) NewClient(remoteURL string) *Client {
206212
Concurrent: true,
207213
send: make(chan *message),
208214
interrupt: make(chan error, 1),
215+
ctx: context.TODO(),
216+
cancel: func() {},
209217
}
210218

219+
c.OnConnect(c.setContext)
220+
c.OnDisconnect(c.closeContext)
221+
211222
k.OnRegister(c.updateAuth)
212223

213224
return c
@@ -262,6 +273,24 @@ func (c *Client) updateAuth(reg *protocol.RegisterResult) {
262273
}
263274
}
264275

276+
func (c *Client) setContext() {
277+
c.ctxMu.Lock()
278+
c.ctx, c.cancel = context.WithCancel(context.Background())
279+
c.ctxMu.Unlock()
280+
}
281+
282+
func (c *Client) closeContext() {
283+
c.ctxMu.Lock()
284+
c.cancel()
285+
c.ctxMu.Unlock()
286+
}
287+
288+
func (c *Client) context() context.Context {
289+
c.ctxMu.Lock()
290+
defer c.ctxMu.Unlock()
291+
return c.ctx
292+
}
293+
265294
func (c *Client) authCopy() *Auth {
266295
c.authMu.Lock()
267296
defer c.authMu.Unlock()

kite.go

+1
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ func (k *Kite) sockjsHandler(session sockjs.Session) {
283283
go c.sendHub()
284284

285285
k.callOnConnectHandlers(c)
286+
c.callOnConnectHandlers()
286287

287288
// Run after methods are registered and delegate is set
288289
c.readLoop()

kite_test.go

+75-9
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package kite
22

33
import (
44
"errors"
5+
"flag"
56
"fmt"
67
"math"
78
"math/rand"
89
"os"
10+
"reflect"
911
"strconv"
1012
"sync"
1113
"testing"
@@ -20,6 +22,8 @@ import (
2022
"github.com/igm/sockjs-go/sockjs"
2123
)
2224

25+
var timeout = flag.Duration("telltime", 4*time.Second, "Timeout for kite calls.")
26+
2327
func init() {
2428
rand.Seed(time.Now().Unix() + int64(os.Getpid()))
2529
}
@@ -32,6 +36,76 @@ func panicRegisterHandler(*protocol.RegisterResult) {
3236
panic("this panic should be ignored")
3337
}
3438

39+
func transportFromEnv() config.Transport {
40+
env := os.Getenv("KITE_TRANSPORT")
41+
tr, ok := config.Transports[env]
42+
if env != "" && !ok {
43+
panic(fmt.Errorf("transport %q doesn't exists", env))
44+
}
45+
return tr
46+
}
47+
48+
func TestContext(t *testing.T) {
49+
flag.Parse()
50+
51+
ch := make(chan int, 4) // checkpoints, to ensure flor of control
52+
53+
k := New("server", "0.0.1")
54+
k.Config.DisableAuthentication = true
55+
k.Config.Port = 3333
56+
k.Config.Transport = transportFromEnv()
57+
k.HandleFunc("longrunning", func(r *Request) (interface{}, error) {
58+
ch <- 2
59+
60+
go func() {
61+
<-r.Context.Done()
62+
ch <- 4
63+
close(ch)
64+
}()
65+
return nil, nil
66+
})
67+
go k.Run()
68+
<-k.ServerReadyNotify()
69+
defer k.Close()
70+
71+
c := New("client", "0.0.1").NewClient("http://127.0.0.1:3333/kite")
72+
if err := c.Dial(); err != nil {
73+
t.Fatalf("Dial()=%s", err)
74+
}
75+
76+
ch <- 1
77+
78+
if _, err := c.TellWithTimeout("longrunning", *timeout); err != nil {
79+
t.Fatalf("TellWithTimeout()=%s", err)
80+
}
81+
82+
ch <- 3
83+
84+
c.Close()
85+
86+
var got []int
87+
timeout := time.After(2 * time.Second)
88+
89+
collect:
90+
for {
91+
select {
92+
case i, ok := <-ch:
93+
if !ok {
94+
break collect
95+
}
96+
got = append(got, i)
97+
case <-timeout:
98+
t.Fatal("timed out collecting checkpoints")
99+
}
100+
}
101+
102+
want := []int{1, 2, 3, 4}
103+
104+
if !reflect.DeepEqual(got, want) {
105+
t.Fatalf("got %v, want %v")
106+
}
107+
}
108+
35109
func TestMultiple(t *testing.T) {
36110
testDuration := time.Second * 10
37111

@@ -44,15 +118,7 @@ func TestMultiple(t *testing.T) {
44118
// ports are starting from 6000 up to 6000 + kiteNumber
45119
port := 6000
46120

47-
var transport config.Transport
48-
if transportName := os.Getenv("KITE_TRANSPORT"); transportName != "" {
49-
tr, ok := config.Transports[transportName]
50-
if !ok {
51-
t.Fatalf("transport '%s' doesn't exists", transportName)
52-
}
53-
54-
transport = tr
55-
}
121+
transport := transportFromEnv()
56122

57123
for i := 0; i < kiteNumber; i++ {
58124
m := New("mathworker"+strconv.Itoa(i), "0.1."+strconv.Itoa(i))

request.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (c *Client) newRequest(method string, args *dnode.Partial) (*Request, func(
154154
LocalKite: c.LocalKite,
155155
Client: c,
156156
Auth: options.Auth,
157-
Context: context.TODO(),
157+
Context: c.context(),
158158
}
159159

160160
// Call response callback function, send back our response

0 commit comments

Comments
 (0)