Skip to content

Commit 81dcf13

Browse files
committed
its over chris, you have thrown the test in mt doom
1 parent 974930f commit 81dcf13

File tree

2 files changed

+154
-98
lines changed

2 files changed

+154
-98
lines changed

pkg/controlsvc/controlsvc.go

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"io/fs"
1313
"net"
1414
"os"
15+
"reflect"
1516
"runtime"
1617
"strings"
1718
"sync"
@@ -362,6 +363,50 @@ func (s *Server) RunControlSession(conn net.Conn) {
362363
}
363364
}
364365

366+
func (s *Server) ConnectionListener(ctx context.Context, listener net.Listener) {
367+
for {
368+
if ctx.Err() != nil {
369+
return
370+
}
371+
conn, err := listener.Accept()
372+
if err != nil {
373+
if !strings.HasSuffix(err.Error(), "normal close") {
374+
s.nc.GetLogger().Error("Error accepting connection: %s\n", err)
375+
}
376+
377+
continue
378+
}
379+
go s.SetupConnection(conn)
380+
}
381+
}
382+
383+
func (s *Server) SetupConnection(conn net.Conn) {
384+
defer conn.Close()
385+
tlsConn, ok := conn.(*tls.Conn)
386+
if ok {
387+
// Explicitly run server TLS handshake so we can deal with timeout and errors here
388+
err := conn.SetDeadline(time.Now().Add(10 * time.Second))
389+
if err != nil {
390+
s.nc.GetLogger().Error("Error setting timeout: %s. Closing socket.\n", err)
391+
392+
return
393+
}
394+
err = tlsConn.Handshake()
395+
if err != nil {
396+
s.nc.GetLogger().Error("TLS handshake error: %s. Closing socket.\n", err)
397+
398+
return
399+
}
400+
err = conn.SetDeadline(time.Time{})
401+
if err != nil {
402+
s.nc.GetLogger().Error("Error clearing timeout: %s. Closing socket.\n", err)
403+
404+
return
405+
}
406+
}
407+
s.RunControlSession(conn)
408+
}
409+
365410
// RunControlSvc runs the main accept loop of the control service.
366411
func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.Config,
367412
unixSocket string, unixSocketPermissions os.FileMode, tcpListen string, tcptls *tls.Config,
@@ -395,7 +440,7 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.
395440
} else {
396441
tli = nil
397442
}
398-
var li net.Listener
443+
var li *netceptor.Listener
399444
if service != "" {
400445
li, err = s.nc.ListenAndAdvertise(service, tlscfg, map[string]string{
401446
"type": "Control Service",
@@ -424,48 +469,11 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.
424469
}
425470
}()
426471
for _, listener := range []net.Listener{uli, tli, li} {
472+
if reflect.ValueOf(listener).IsNil() {
473+
continue
474+
}
427475
if listener != nil {
428-
go func(listener net.Listener) {
429-
for {
430-
if ctx.Err() != nil {
431-
return
432-
}
433-
conn, err := listener.Accept()
434-
if err != nil {
435-
if !strings.HasSuffix(err.Error(), "normal close") {
436-
s.nc.GetLogger().Error("Error accepting connection: %s\n", err)
437-
}
438-
439-
continue
440-
}
441-
go func() {
442-
defer conn.Close()
443-
tlsConn, ok := conn.(*tls.Conn)
444-
if ok {
445-
// Explicitly run server TLS handshake so we can deal with timeout and errors here
446-
err = conn.SetDeadline(time.Now().Add(10 * time.Second))
447-
if err != nil {
448-
s.nc.GetLogger().Error("Error setting timeout: %s. Closing socket.\n", err)
449-
450-
return
451-
}
452-
err = tlsConn.Handshake()
453-
if err != nil {
454-
s.nc.GetLogger().Error("TLS handshake error: %s. Closing socket.\n", err)
455-
456-
return
457-
}
458-
err = conn.SetDeadline(time.Time{})
459-
if err != nil {
460-
s.nc.GetLogger().Error("Error clearing timeout: %s. Closing socket.\n", err)
461-
462-
return
463-
}
464-
}
465-
s.RunControlSession(conn)
466-
}()
467-
}
468-
}(listener)
476+
go s.ConnectionListener(ctx, listener)
469477
}
470478
}
471479

pkg/controlsvc/controlsvc_test.go

Lines changed: 104 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,106 @@ import (
88
"net"
99
"os"
1010
"testing"
11-
"time"
1211

1312
"github.com/ansible/receptor/pkg/controlsvc"
1413
"github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc"
1514
"github.com/ansible/receptor/pkg/logger"
1615
"github.com/golang/mock/gomock"
1716
)
1817

18+
func TestConnectionListener(t *testing.T) {
19+
ctrl := gomock.NewController(t)
20+
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
21+
mockListener := mock_controlsvc.NewMockListener(ctrl)
22+
logger := logger.NewReceptorLogger("")
23+
24+
connectionListenerTestCases := []struct {
25+
name string
26+
expectedError bool
27+
expectedCalls func(context.CancelFunc)
28+
}{
29+
{
30+
name: "return from context error",
31+
expectedError: true,
32+
expectedCalls: func(ctx context.CancelFunc) {},
33+
},
34+
{
35+
name: "error accepting connection",
36+
expectedError: false,
37+
expectedCalls: func(ctxCancel context.CancelFunc) {
38+
mockListener.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
39+
ctxCancel()
40+
return nil, errors.New("terminated")
41+
})
42+
mockNetceptor.EXPECT().GetLogger().Return(logger)
43+
},
44+
},
45+
}
46+
47+
for _, testCase := range connectionListenerTestCases {
48+
t.Run(testCase.name, func(t *testing.T) {
49+
ctx, ctxCancel := context.WithCancel(context.Background())
50+
defer ctxCancel()
51+
52+
testCase.expectedCalls(ctxCancel)
53+
s := controlsvc.New(false, mockNetceptor)
54+
55+
if testCase.expectedError {
56+
ctxCancel()
57+
}
58+
59+
s.ConnectionListener(ctx, mockListener)
60+
})
61+
}
62+
63+
}
64+
65+
func TestSetupConnection(t *testing.T) {
66+
ctrl := gomock.NewController(t)
67+
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
68+
mockConn := mock_controlsvc.NewMockConn(ctrl)
69+
logger := logger.NewReceptorLogger("")
70+
71+
setupConnectionTestCases := []struct {
72+
name string
73+
expectedError bool
74+
expectedCalls func()
75+
}{
76+
{
77+
name: "log error - setting timeout",
78+
expectedError: true,
79+
expectedCalls: func() {
80+
mockConn.EXPECT().SetDeadline(gomock.Any()).Return(errors.New("terminated"))
81+
mockNetceptor.EXPECT().GetLogger().Return(logger)
82+
mockConn.EXPECT().Close()
83+
},
84+
},
85+
{
86+
name: "log error - tls handshake",
87+
expectedError: true,
88+
expectedCalls: func() {
89+
mockConn.EXPECT().SetDeadline(gomock.Any()).Return(nil)
90+
mockNetceptor.EXPECT().GetLogger().Return(logger)
91+
mockConn.EXPECT().Close().AnyTimes()
92+
},
93+
},
94+
}
95+
96+
for _, testCase := range setupConnectionTestCases {
97+
t.Run(testCase.name, func(t *testing.T) {
98+
testCase.expectedCalls()
99+
s := controlsvc.New(false, mockNetceptor)
100+
tlsConn := tls.Client(mockConn, &tls.Config{})
101+
s.SetupConnection(tlsConn)
102+
})
103+
}
104+
}
105+
19106
func TestRunControlSvc(t *testing.T) {
20107
ctrl := gomock.NewController(t)
21108
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
22109
mockUnix := mock_controlsvc.NewMockUtiler(ctrl)
23110
mockNet := mock_controlsvc.NewMockNeter(ctrl)
24-
// mockListener := mock_controlsvc.NewMockListener(ctrl)
25-
// logger := logger.NewReceptorLogger("")
26111

27112
runControlSvcTestCases := []struct {
28113
name string
@@ -77,19 +162,6 @@ func TestRunControlSvc(t *testing.T) {
77162
"tcpListen": "",
78163
},
79164
},
80-
// {
81-
// name: "idk",
82-
// expectedError: "",
83-
// expectedCalls: func() {
84-
// mockNet.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(mockListener, nil)
85-
// mockNetceptor.EXPECT().GetLogger().Return(logger)
86-
// },
87-
// listeners: map[string]string{
88-
// "service": "",
89-
// "unixSocket": "",
90-
// "tcpListen": "tcp listener",
91-
// },
92-
// },
93165
}
94166

95167
for _, testCase := range runControlSvcTestCases {
@@ -102,40 +174,12 @@ func TestRunControlSvc(t *testing.T) {
102174
err := s.RunControlSvc(context.Background(), testCase.listeners["service"], &tls.Config{}, testCase.listeners["unixSocket"], os.FileMode(0o600), testCase.listeners["tcpListen"], &tls.Config{})
103175

104176
if err == nil || err.Error() != testCase.expectedError {
105-
t.Errorf("expected error %s, got %s", testCase.expectedError, err.Error())
177+
t.Errorf("expected error %s, got %v", testCase.expectedError, err)
106178
}
107179
})
108180
}
109181
}
110182

111-
func TestRunControlSvcOld(t *testing.T) {
112-
// ctrl := gomock.NewController(t)
113-
// defer ctrl.Finish()
114-
115-
// mock_netceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
116-
// s := controlsvc.New(false, mock_netceptor)
117-
// mock_unix := mock_controlsvc.NewMockUtiler(ctrl)
118-
// s.SetServerUtils(mock_unix)
119-
120-
// mock_net_listener := mock_controlsvc.NewMockListener(ctrl)
121-
// mock_unix.EXPECT().UnixSocketListen(gomock.Any(), gomock.Any()).Return(mock_net_listener, nil, nil)
122-
123-
// newCtx, ctxCancel := context.WithTimeout(context.Background(), time.Millisecond*1)
124-
// defer ctxCancel()
125-
126-
// logger := logger.NewReceptorLogger("test")
127-
// mock_net_listener.EXPECT().Accept().Return(nil, errors.New("blargh"))
128-
// // mock_net_listener.EXPECT().Close()
129-
// mock_netceptor.EXPECT().GetLogger().Return(logger)
130-
// err := s.RunControlSvc(newCtx, "", &tls.Config{}, "unixSocket", os.FileMode(0o600), "", &tls.Config{})
131-
// errorString := "Error accepting connection: blargh"
132-
// fmt.Println(err, errorString)
133-
// if err == nil || err.Error() != errorString {
134-
// t.Errorf("expected error: %+v, got: %+v", errorString, err.Error())
135-
// }
136-
137-
}
138-
139183
func TestSockControlRemoteAddr(t *testing.T) {
140184
ctrl := gomock.NewController(t)
141185

@@ -223,7 +267,7 @@ func TestSockControlBridgeConn(t *testing.T) {
223267
name: "with message and error",
224268
message: "message",
225269
expectedCalls: func() {
226-
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("blargh"))
270+
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("terminated"))
227271
},
228272
},
229273
}
@@ -236,7 +280,7 @@ func TestSockControlBridgeConn(t *testing.T) {
236280
if testCase.message == "" && err != nil {
237281
t.Errorf("should be nil")
238282
}
239-
if testCase.message != "" && err.Error() != "blargh" {
283+
if testCase.message != "" && err.Error() != "terminated" {
240284
t.Errorf("stuff %v", err)
241285
}
242286
})
@@ -322,7 +366,7 @@ func TestSockControlWriteToConn(t *testing.T) {
322366
errorMessage string
323367
}{
324368
{
325-
name: "without message and error",
369+
name: "without message and with error",
326370
message: "",
327371
expectedCalls: func() {
328372
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("write to conn chan error"))
@@ -331,7 +375,7 @@ func TestSockControlWriteToConn(t *testing.T) {
331375
errorMessage: "write to conn chan error",
332376
},
333377
{
334-
name: "with message and error",
378+
name: "with message and with error",
335379
message: "message",
336380
expectedCalls: func() {
337381
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("write to conn write message error"))
@@ -354,15 +398,11 @@ func TestSockControlWriteToConn(t *testing.T) {
354398
t.Run(testCase.name, func(t *testing.T) {
355399
testCase.expectedCalls()
356400
c := make(chan []byte)
357-
go func() {
401+
go func(c chan []byte) {
358402
c <- []byte{7}
359-
}()
360-
if !testCase.expectedError {
403+
defer close(c)
404+
}(c)
361405

362-
time.AfterFunc(time.Millisecond*100, func() {
363-
close(c)
364-
})
365-
}
366406
err := sockControl.WriteToConn(testCase.message, c)
367407

368408
if testCase.expectedError {
@@ -487,6 +527,14 @@ func TestRunControlSession(t *testing.T) {
487527
mockCon.EXPECT().Close()
488528
},
489529
},
530+
{
531+
name: "logger warning - could not read in control service",
532+
expectedCalls: func() {
533+
mockCon.EXPECT().Write(gomock.Any()).Return(0, nil)
534+
mockCon.EXPECT().Read(make([]byte, 1)).Return(0, errors.New("terminated"))
535+
mockCon.EXPECT().Close()
536+
},
537+
},
490538
}
491539

492540
for _, testCase := range runControlSessionTestCases {

0 commit comments

Comments
 (0)