11package wstest_test
22
33import (
4- "context"
54 "fmt"
6- "net"
75 "net/http"
8- "strings"
96 "testing"
107 "time"
118
129 "github.com/gorilla/websocket"
13-
1410 "github.com/posener/wstest"
11+ "github.com/stretchr/testify/assert"
12+ "github.com/stretchr/testify/require"
1513)
1614
1715// TestClient demonstrate the usage of wstest package
1816func TestClient (t * testing.T ) {
1917 t .Parallel ()
2018 var (
21- s = & handler {Upgraded : make (chan struct {})}
22- d = wstest .NewDialer (s , t .Log )
19+ s = & handler {Upgraded : make (chan struct {})}
20+ d = wstest .NewDialer (s )
21+ done = make (chan struct {})
2322 )
2423
2524 c , resp , err := d .Dial ("ws://example.org/ws" , nil )
26- if err != nil {
27- t .Fatalf ("Failed connecting to s: %s" , err )
28- }
25+ require .Nil (t , err )
2926
3027 <- s .Upgraded
3128
32- if got , want := resp .StatusCode , http .StatusSwitchingProtocols ; got != want {
33- t .Errorf ("resp.StatusCode = %q, want %q" , got , want )
34- }
29+ assert .Equal (t , http .StatusSwitchingProtocols , resp .StatusCode )
3530
3631 for i := 0 ; i < 3 ; i ++ {
3732 msg := fmt .Sprintf ("hello, world! %d" , i )
3833
39- err := c .WriteMessage (websocket .TextMessage , []byte (msg ))
40- if err != nil {
41- t .Fatal (err )
42- }
34+ go func () {
35+ err := c .WriteMessage (websocket .TextMessage , []byte (msg ))
36+ require .Nil (t , err )
37+ done <- struct {}{}
38+ }()
4339
4440 mT , m , err := s .ReadMessage ()
45- if err != nil {
46- t .Fatal (err )
47- }
41+ require .Nil (t , err )
4842
49- if want , got := msg , string (m ); want != got {
50- t .Errorf ("dialer got %q, want %q" , got , want )
51- }
52- if want , got := websocket .TextMessage , mT ; want != got {
53- t .Errorf ("message type = %q, want %q" , got , want )
54- }
43+ assert .Equal (t , msg , string (m ))
44+ assert .Equal (t , websocket .TextMessage , mT )
45+ <- done
5546
56- s .WriteMessage (websocket .TextMessage , []byte (msg ))
57- if err != nil {
58- t .Fatal (err )
59- }
47+ go func () {
48+ err := s .WriteMessage (websocket .TextMessage , []byte (msg ))
49+ require .Nil (t , err )
50+ done <- struct {}{}
51+ }()
6052
6153 mT , m , err = c .ReadMessage ()
62- if err != nil {
63- t .Fatal (err )
64- }
54+ require .Nil (t , err )
6555
66- if want , got := msg , string (m ); want != got {
67- t .Errorf ("client got %q, want %q" , got , want )
68- }
69- if want , got := websocket .TextMessage , mT ; want != got {
70- t .Errorf ("message type = %q , want %q" , got , want )
71- }
56+ assert .Equal (t , msg , string (m ))
57+ assert .Equal (t , websocket .TextMessage , mT )
58+ <- done
7259 }
7360
7461 err = c .Close ()
75- if err != nil {
76- t .Fatal (err )
77- }
62+ require .Nil (t , err )
7863
7964 err = s .Close ()
80- if err != nil {
81- t .Fatal (err )
82- }
65+ require .Nil (t , err )
8366}
8467
8568// TestConcurrent tests concurrent reads and writes from a connection
8669func TestConcurrent (t * testing.T ) {
8770 t .Parallel ()
8871 var (
8972 s = & handler {Upgraded : make (chan struct {})}
90- d = wstest .NewDialer (s , nil )
73+ d = wstest .NewDialer (s )
9174 count = 20
9275 )
9376
9477 c , _ , err := d .Dial ("ws://example.org/ws" , nil )
95- if err != nil {
96- t .Fatalf ("Failed connecting to s: %s" , err )
97- }
78+ require .Nil (t , err )
9879
9980 <- s .Upgraded
10081
10182 for _ , pair := range []struct { src , dst * websocket.Conn }{{s .Conn , c }, {c , s .Conn }} {
10283 go func () {
10384 for i := 0 ; i < count ; i ++ {
104- pair .src .WriteJSON (i )
85+ err := pair .src .WriteJSON (i )
86+ require .Nil (t , err )
10587 }
10688 }()
10789
10890 received := make ([]bool , count )
10991
11092 for i := 0 ; i < count ; i ++ {
11193 var j int
112- pair .dst .ReadJSON (& j )
94+ err := pair .dst .ReadJSON (& j )
95+ require .Nil (t , err )
11396
11497 received [j ] = true
11598 }
11699
117- missing := []int {}
100+ var missing []int
118101
119102 for i := range received {
120103 if ! received [i ] {
121104 missing = append (missing , i )
122105 }
123106 }
124- if len (missing ) > 0 {
125- t .Errorf ("%q -> %q: Did not received: %q" , pair .src .LocalAddr (), pair .dst .LocalAddr (), missing )
126- }
107+ assert .Equal (t , 0 , len (missing ), "%q -> %q: Did not received: %q" , pair .src .LocalAddr (), pair .dst .LocalAddr (), missing )
127108 }
128109
129110 err = c .Close ()
130- if err != nil {
131- t .Fatal (err )
132- }
111+ require .Nil (t , err )
133112
134113 err = s .Close ()
135- if err != nil {
136- t .Fatal (err )
137- }
114+ require .Nil (t , err )
138115}
139116
140117func TestBadAddress (t * testing.T ) {
@@ -157,38 +134,30 @@ func TestBadAddress(t *testing.T) {
157134 for _ , tt := range tests {
158135 t .Run (tt .url , func (t * testing.T ) {
159136 s := & handler {Upgraded : make (chan struct {})}
160- d := wstest .NewDialer (s , nil )
137+ d := wstest .NewDialer (s )
161138 c , resp , err := d .Dial (tt .url , nil )
162- if c != nil {
163- t .Errorf ("d = %T, want nil" , c )
164- }
165- if err == nil {
166- t .Error ("opError is nil" )
167- }
139+ assert .Nil (t , c )
140+ assert .NotNil (t , err )
168141 if tt .code != 0 {
169- if got , want := resp .StatusCode , tt .code ; got != want {
170- t .Errorf ("resp.StatusCode = %q, want %q" , got , want )
171- }
142+ assert .Equal (t , tt .code , resp .StatusCode )
172143 }
173144
174145 err = s .Close ()
175- if err != nil {
176- t .Fatal (err )
177- }
146+ require .Nil (t , err )
178147 })
179148 }
180149}
181150
151+ const deadlineExceeded = "deadline exceeded"
152+
182153// TestConnectDeadline tests connection deadlines
183154func TestDeadlines (t * testing.T ) {
184155 t .Parallel ()
185156 h := & handler {Upgraded : make (chan struct {})}
186- d := wstest .NewDialer (h , nil )
157+ d := wstest .NewDialer (h )
187158
188159 c , _ , err := d .Dial ("ws://example.org/ws" , nil )
189- if err != nil {
190- t .Fatalf ("Failed connecting to h: %q" , err )
191- }
160+ require .Nil (t , err )
192161
193162 <- h .Upgraded
194163
@@ -198,27 +167,20 @@ func TestDeadlines(t *testing.T) {
198167
199168 // set the deadline to now, and test for timeout
200169 pair .dst .SetReadDeadline (time .Now ())
201- err = pair .dst .ReadJSON (i )
202- if got , want := err .Error (), context .DeadlineExceeded .Error (); ! strings .Contains (got , want ) {
203- t .Errorf ("err = %q, not conains %q" , got , want )
204- }
205- err = pair .dst .ReadJSON (i )
206- if got , want := err .Error (), context .DeadlineExceeded .Error (); ! strings .Contains (got , want ) {
207- t .Errorf ("err = %q, not conains %q" , got , want )
208- }
170+ err = pair .dst .ReadJSON (& i )
171+ assert .Contains (t , err .Error (), deadlineExceeded )
209172
210- pair .src .WriteJSON (1 )
211- err = pair .dst .ReadJSON (i )
212- if got , want := err .Error (), context .DeadlineExceeded .Error (); ! strings .Contains (got , want ) {
213- t .Errorf ("err = %q, not conains %q" , got , want )
214- }
173+ err = pair .dst .ReadJSON (& i )
174+ assert .Contains (t , err .Error (), deadlineExceeded )
175+
176+ go pair .src .WriteJSON (1 )
177+ err = pair .dst .ReadJSON (& i )
178+ assert .Contains (t , err .Error (), deadlineExceeded )
215179
216180 // even after updating the deadline, should get an error
217181 pair .dst .SetReadDeadline (time .Now ().Add (time .Second ))
218- err = pair .dst .ReadJSON (i )
219- if got , want := err .Error (), context .DeadlineExceeded .Error (); ! strings .Contains (got , want ) {
220- t .Errorf ("err = %q, not conains %q" , got , want )
221- }
182+ err = pair .dst .ReadJSON (& i )
183+ assert .Contains (t , err .Error (), deadlineExceeded )
222184 }
223185}
224186
@@ -229,42 +191,35 @@ func TestConnectDeadline(t *testing.T) {
229191 tests := []struct {
230192 path string
231193 timeout time.Duration
232- err error
194+ wantErr bool
233195 }{
234196 {
235- "/ws/delay" ,
236- time .Millisecond ,
237- context . DeadlineExceeded ,
197+ path : "/ws/delay" ,
198+ timeout : time .Millisecond ,
199+ wantErr : true ,
238200 },
239201 {
240- "/ws" ,
241- time .Second ,
242- nil ,
202+ path : "/ws" ,
203+ timeout : time .Second ,
243204 },
244205 }
245206
246207 for _ , tt := range tests {
247208 t .Run (fmt .Sprintf ("%s/%s" , tt .path , tt .timeout ), func (t * testing.T ) {
248209 s := & handler {Upgraded : make (chan struct {})}
249- d := wstest .NewDialer (s , nil )
210+ d := wstest .NewDialer (s )
250211 d .HandshakeTimeout = tt .timeout
251212 _ , _ , err := d .Dial ("ws://example.org" + tt .path , nil )
252- if tt .err == nil {
253- if err != nil {
254- t .Errorf ("err = %q, want nil" , err )
255- }
256- } else {
257- if got , want := err .(* net.OpError ).Err , tt .err ; got != want {
258- t .Errorf ("err = %q, want %q" , got , want )
259- }
213+ if tt .wantErr {
214+ assert .NotNil (t , err )
215+ return
260216 }
261217
262- if tt .err == nil {
263- select {
264- case <- s .Upgraded :
265- case <- time .After (time .Second ):
266- t .Fatal ("connection was not upgraded after 1s" )
267- }
218+ assert .Nil (t , err )
219+ select {
220+ case <- s .Upgraded :
221+ case <- time .After (time .Second ):
222+ t .Fatal ("connection was not upgraded after 1s" )
268223 }
269224 })
270225 }
@@ -294,12 +249,12 @@ func (s *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
294249}
295250
296251func (s * handler ) connect (w http.ResponseWriter , r * http.Request ) {
252+ defer close (s .Upgraded )
297253 var err error
298254 s .Conn , err = s .upgrader .Upgrade (w , r , nil )
299255 if err != nil {
300256 return
301257 }
302- close (s .Upgraded )
303258}
304259
305260func (s * handler ) Close () error {
0 commit comments