Skip to content

Commit 04b166c

Browse files
author
Eyal Posener
authored
Merge pull request #26 from posener/go1.10
Use go1.10 new features
2 parents a9f917d + 3ecf1c6 commit 04b166c

File tree

16 files changed

+78
-747
lines changed

16 files changed

+78
-747
lines changed

.travis.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
language: go
22
sudo: false
33
go:
4-
- 1.7
5-
- 1.8
6-
- tip
4+
- master
75

86
before_install:
97
- go get -u -t ./...
108

119
script:
12-
- ./go.test.sh
10+
- go test -count 20 -v -race -coverprofile=coverage.txt -covermode=atomic ./...
1311

1412
after_success:
1513
- bash <(curl -s https://codecov.io/bash)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func TestHandler(t *testing.T) {
3939
- s := httptest.NewServer(h)
4040
- defer s.Close()
4141
- d := websocket.Dialer{}
42-
+ d := wstest.NewDialer(h, nil) // or t.Log instead of nil
42+
+ d := wstest.NewDialer(h)
4343

4444
- c, resp, err := d.Dial("ws://" + s.Listener.Addr().String() + "/ws", nil)
4545
+ c, resp, err := d.Dial("ws://" + "whatever" + "/ws", nil)

dialer.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"net/http/httptest"
1313

1414
"github.com/gorilla/websocket"
15-
"github.com/posener/wstest/pipe"
1615
)
1716

1817
// NewDialer creates a wstest recorder to an http.Handler which accepts websocket upgrades.
@@ -21,11 +20,9 @@ import (
2120
// client running on the current program flow
2221
//
2322
// h is an http.Handler that handles websocket connections.
24-
// debugLog is a function for a log.Println-like function for printing everything that
25-
// is passed over the connection. Can be set to nil if no logs are needed.
2623
// It returns a *websocket.Dial struct, which can then be used to dial to the handler.
27-
func NewDialer(h http.Handler, debugLog pipe.Println) *websocket.Dialer {
28-
client, server := pipe.New(debugLog)
24+
func NewDialer(h http.Handler) *websocket.Dialer {
25+
client, server := net.Pipe()
2926
conn := &recorder{server: server}
3027

3128
// run the runServer in a goroutine, so when the Dial send the request to

dialer_test.go

Lines changed: 72 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,117 @@
11
package wstest_test
22

33
import (
4-
"context"
54
"fmt"
6-
"net"
75
"net/http"
8-
"strings"
96
"testing"
107
"time"
118

129
"github.com/gorilla/websocket"
13-
1410
"github.com/posener/wstest"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
1513
)
1614

1715
// TestClient demonstrate the usage of wstest package
1816
func TestClient(t *testing.T) {
1917
t.Parallel()
2018
var (
21-
s = &handler{Upgraded: make(chan struct{})}
22-
d = wstest.NewDialer(s, t.Log)
19+
s = &handler{Upgraded: make(chan struct{})}
20+
d = wstest.NewDialer(s)
21+
done = make(chan struct{})
2322
)
2423

2524
c, resp, err := d.Dial("ws://example.org/ws", nil)
26-
if err != nil {
27-
t.Fatalf("Failed connecting to s: %s", err)
28-
}
25+
require.Nil(t, err)
2926

3027
<-s.Upgraded
3128

32-
if got, want := resp.StatusCode, http.StatusSwitchingProtocols; got != want {
33-
t.Errorf("resp.StatusCode = %q, want %q", got, want)
34-
}
29+
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
3530

3631
for i := 0; i < 3; i++ {
3732
msg := fmt.Sprintf("hello, world! %d", i)
3833

39-
err := c.WriteMessage(websocket.TextMessage, []byte(msg))
40-
if err != nil {
41-
t.Fatal(err)
42-
}
34+
go func() {
35+
err := c.WriteMessage(websocket.TextMessage, []byte(msg))
36+
require.Nil(t, err)
37+
done <- struct{}{}
38+
}()
4339

4440
mT, m, err := s.ReadMessage()
45-
if err != nil {
46-
t.Fatal(err)
47-
}
41+
require.Nil(t, err)
4842

49-
if want, got := msg, string(m); want != got {
50-
t.Errorf("dialer got %q, want %q", got, want)
51-
}
52-
if want, got := websocket.TextMessage, mT; want != got {
53-
t.Errorf("message type = %q, want %q", got, want)
54-
}
43+
assert.Equal(t, msg, string(m))
44+
assert.Equal(t, websocket.TextMessage, mT)
45+
<-done
5546

56-
s.WriteMessage(websocket.TextMessage, []byte(msg))
57-
if err != nil {
58-
t.Fatal(err)
59-
}
47+
go func() {
48+
err := s.WriteMessage(websocket.TextMessage, []byte(msg))
49+
require.Nil(t, err)
50+
done <- struct{}{}
51+
}()
6052

6153
mT, m, err = c.ReadMessage()
62-
if err != nil {
63-
t.Fatal(err)
64-
}
54+
require.Nil(t, err)
6555

66-
if want, got := msg, string(m); want != got {
67-
t.Errorf("client got %q, want %q", got, want)
68-
}
69-
if want, got := websocket.TextMessage, mT; want != got {
70-
t.Errorf("message type = %q , want %q", got, want)
71-
}
56+
assert.Equal(t, msg, string(m))
57+
assert.Equal(t, websocket.TextMessage, mT)
58+
<-done
7259
}
7360

7461
err = c.Close()
75-
if err != nil {
76-
t.Fatal(err)
77-
}
62+
require.Nil(t, err)
7863

7964
err = s.Close()
80-
if err != nil {
81-
t.Fatal(err)
82-
}
65+
require.Nil(t, err)
8366
}
8467

8568
// TestConcurrent tests concurrent reads and writes from a connection
8669
func TestConcurrent(t *testing.T) {
8770
t.Parallel()
8871
var (
8972
s = &handler{Upgraded: make(chan struct{})}
90-
d = wstest.NewDialer(s, nil)
73+
d = wstest.NewDialer(s)
9174
count = 20
9275
)
9376

9477
c, _, err := d.Dial("ws://example.org/ws", nil)
95-
if err != nil {
96-
t.Fatalf("Failed connecting to s: %s", err)
97-
}
78+
require.Nil(t, err)
9879

9980
<-s.Upgraded
10081

10182
for _, pair := range []struct{ src, dst *websocket.Conn }{{s.Conn, c}, {c, s.Conn}} {
10283
go func() {
10384
for i := 0; i < count; i++ {
104-
pair.src.WriteJSON(i)
85+
err := pair.src.WriteJSON(i)
86+
require.Nil(t, err)
10587
}
10688
}()
10789

10890
received := make([]bool, count)
10991

11092
for i := 0; i < count; i++ {
11193
var j int
112-
pair.dst.ReadJSON(&j)
94+
err := pair.dst.ReadJSON(&j)
95+
require.Nil(t, err)
11396

11497
received[j] = true
11598
}
11699

117-
missing := []int{}
100+
var missing []int
118101

119102
for i := range received {
120103
if !received[i] {
121104
missing = append(missing, i)
122105
}
123106
}
124-
if len(missing) > 0 {
125-
t.Errorf("%q -> %q: Did not received: %q", pair.src.LocalAddr(), pair.dst.LocalAddr(), missing)
126-
}
107+
assert.Equal(t, 0, len(missing), "%q -> %q: Did not received: %q", pair.src.LocalAddr(), pair.dst.LocalAddr(), missing)
127108
}
128109

129110
err = c.Close()
130-
if err != nil {
131-
t.Fatal(err)
132-
}
111+
require.Nil(t, err)
133112

134113
err = s.Close()
135-
if err != nil {
136-
t.Fatal(err)
137-
}
114+
require.Nil(t, err)
138115
}
139116

140117
func TestBadAddress(t *testing.T) {
@@ -157,38 +134,30 @@ func TestBadAddress(t *testing.T) {
157134
for _, tt := range tests {
158135
t.Run(tt.url, func(t *testing.T) {
159136
s := &handler{Upgraded: make(chan struct{})}
160-
d := wstest.NewDialer(s, nil)
137+
d := wstest.NewDialer(s)
161138
c, resp, err := d.Dial(tt.url, nil)
162-
if c != nil {
163-
t.Errorf("d = %T, want nil", c)
164-
}
165-
if err == nil {
166-
t.Error("opError is nil")
167-
}
139+
assert.Nil(t, c)
140+
assert.NotNil(t, err)
168141
if tt.code != 0 {
169-
if got, want := resp.StatusCode, tt.code; got != want {
170-
t.Errorf("resp.StatusCode = %q, want %q", got, want)
171-
}
142+
assert.Equal(t, tt.code, resp.StatusCode)
172143
}
173144

174145
err = s.Close()
175-
if err != nil {
176-
t.Fatal(err)
177-
}
146+
require.Nil(t, err)
178147
})
179148
}
180149
}
181150

151+
const deadlineExceeded = "deadline exceeded"
152+
182153
// TestConnectDeadline tests connection deadlines
183154
func TestDeadlines(t *testing.T) {
184155
t.Parallel()
185156
h := &handler{Upgraded: make(chan struct{})}
186-
d := wstest.NewDialer(h, nil)
157+
d := wstest.NewDialer(h)
187158

188159
c, _, err := d.Dial("ws://example.org/ws", nil)
189-
if err != nil {
190-
t.Fatalf("Failed connecting to h: %q", err)
191-
}
160+
require.Nil(t, err)
192161

193162
<-h.Upgraded
194163

@@ -198,27 +167,20 @@ func TestDeadlines(t *testing.T) {
198167

199168
// set the deadline to now, and test for timeout
200169
pair.dst.SetReadDeadline(time.Now())
201-
err = pair.dst.ReadJSON(i)
202-
if got, want := err.Error(), context.DeadlineExceeded.Error(); !strings.Contains(got, want) {
203-
t.Errorf("err = %q, not conains %q", got, want)
204-
}
205-
err = pair.dst.ReadJSON(i)
206-
if got, want := err.Error(), context.DeadlineExceeded.Error(); !strings.Contains(got, want) {
207-
t.Errorf("err = %q, not conains %q", got, want)
208-
}
170+
err = pair.dst.ReadJSON(&i)
171+
assert.Contains(t, err.Error(), deadlineExceeded)
209172

210-
pair.src.WriteJSON(1)
211-
err = pair.dst.ReadJSON(i)
212-
if got, want := err.Error(), context.DeadlineExceeded.Error(); !strings.Contains(got, want) {
213-
t.Errorf("err = %q, not conains %q", got, want)
214-
}
173+
err = pair.dst.ReadJSON(&i)
174+
assert.Contains(t, err.Error(), deadlineExceeded)
175+
176+
go pair.src.WriteJSON(1)
177+
err = pair.dst.ReadJSON(&i)
178+
assert.Contains(t, err.Error(), deadlineExceeded)
215179

216180
// even after updating the deadline, should get an error
217181
pair.dst.SetReadDeadline(time.Now().Add(time.Second))
218-
err = pair.dst.ReadJSON(i)
219-
if got, want := err.Error(), context.DeadlineExceeded.Error(); !strings.Contains(got, want) {
220-
t.Errorf("err = %q, not conains %q", got, want)
221-
}
182+
err = pair.dst.ReadJSON(&i)
183+
assert.Contains(t, err.Error(), deadlineExceeded)
222184
}
223185
}
224186

@@ -229,42 +191,35 @@ func TestConnectDeadline(t *testing.T) {
229191
tests := []struct {
230192
path string
231193
timeout time.Duration
232-
err error
194+
wantErr bool
233195
}{
234196
{
235-
"/ws/delay",
236-
time.Millisecond,
237-
context.DeadlineExceeded,
197+
path: "/ws/delay",
198+
timeout: time.Millisecond,
199+
wantErr: true,
238200
},
239201
{
240-
"/ws",
241-
time.Second,
242-
nil,
202+
path: "/ws",
203+
timeout: time.Second,
243204
},
244205
}
245206

246207
for _, tt := range tests {
247208
t.Run(fmt.Sprintf("%s/%s", tt.path, tt.timeout), func(t *testing.T) {
248209
s := &handler{Upgraded: make(chan struct{})}
249-
d := wstest.NewDialer(s, nil)
210+
d := wstest.NewDialer(s)
250211
d.HandshakeTimeout = tt.timeout
251212
_, _, err := d.Dial("ws://example.org"+tt.path, nil)
252-
if tt.err == nil {
253-
if err != nil {
254-
t.Errorf("err = %q, want nil", err)
255-
}
256-
} else {
257-
if got, want := err.(*net.OpError).Err, tt.err; got != want {
258-
t.Errorf("err = %q, want %q", got, want)
259-
}
213+
if tt.wantErr {
214+
assert.NotNil(t, err)
215+
return
260216
}
261217

262-
if tt.err == nil {
263-
select {
264-
case <-s.Upgraded:
265-
case <-time.After(time.Second):
266-
t.Fatal("connection was not upgraded after 1s")
267-
}
218+
assert.Nil(t, err)
219+
select {
220+
case <-s.Upgraded:
221+
case <-time.After(time.Second):
222+
t.Fatal("connection was not upgraded after 1s")
268223
}
269224
})
270225
}
@@ -294,12 +249,12 @@ func (s *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
294249
}
295250

296251
func (s *handler) connect(w http.ResponseWriter, r *http.Request) {
252+
defer close(s.Upgraded)
297253
var err error
298254
s.Conn, err = s.upgrader.Upgrade(w, r, nil)
299255
if err != nil {
300256
return
301257
}
302-
close(s.Upgraded)
303258
}
304259

305260
func (s *handler) Close() error {

example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func ExampleNewDialer() {
1919
// it uses the gorilla's websocket.Dial function, over a fake net.Conn struct.
2020
// it runs the handler's ServeHTTP function in a goroutine, so the handler can
2121
// communicate with a client running on the current program flow
22-
d = wstest.NewDialer(s, nil)
22+
d = wstest.NewDialer(s)
2323

2424
resp string
2525
)

0 commit comments

Comments
 (0)