From a2c5f89177678b00d0587c54056a2d041c30f5d5 Mon Sep 17 00:00:00 2001 From: resoluteCoder Date: Tue, 19 Sep 2023 17:25:10 -0500 Subject: [PATCH] Add unit tests for controlsvc --- pkg/controlsvc/connect.go | 2 +- pkg/controlsvc/controlsvc.go | 34 +++--- pkg/controlsvc/controlsvc_test.go | 169 ++++++++++++------------------ pkg/controlsvc/interfaces.go | 4 +- pkg/workceptor/controlsvc.go | 2 +- 5 files changed, 83 insertions(+), 128 deletions(-) diff --git a/pkg/controlsvc/connect.go b/pkg/controlsvc/connect.go index 13f599c1e..db3a9912f 100644 --- a/pkg/controlsvc/connect.go +++ b/pkg/controlsvc/connect.go @@ -83,7 +83,7 @@ func (c *connectCommand) ControlFunc(_ context.Context, nc NetceptorForControlCo if err != nil { return nil, err } - err = cfo.BridgeConn("Connecting\n", rc, "connected service", nc.GetLogger()) + err = cfo.BridgeConn("Connecting\n", rc, "connected service", nc.GetLogger(), &Util{}) if err != nil { return nil, err } diff --git a/pkg/controlsvc/controlsvc.go b/pkg/controlsvc/controlsvc.go index 18e591bbc..a3814eca8 100644 --- a/pkg/controlsvc/controlsvc.go +++ b/pkg/controlsvc/controlsvc.go @@ -81,16 +81,13 @@ func (t *Tls) NewListener(inner net.Listener, config *tls.Config) net.Listener { // SockControl implements the ControlFuncOperations interface that is passed back to control functions. type SockControl struct { - conn net.Conn - utils Utiler - io Copier + conn net.Conn } -func NewSockControl(conn net.Conn, utils Utiler, copier Copier) *SockControl { +// func NewSockControl(conn net.Conn, utils Utiler, copier Copier) *SockControl { +func NewSockControl(conn net.Conn) *SockControl { return &SockControl{ - conn: conn, - utils: utils, - io: copier, + conn: conn, } } @@ -110,21 +107,21 @@ func (s *SockControl) WriteMessage(message string) error { } // BridgeConn bridges the socket to another socket. -func (s *SockControl) BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger) error { +func (s *SockControl) BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger, utils Utiler) error { if err := s.WriteMessage(message); err != nil { return err } - s.utils.BridgeConns(s.conn, "control service", bc, bcName, logger) + utils.BridgeConns(s.conn, "control service", bc, bcName, logger) return nil } // ReadFromConn copies from the socket to an io.Writer, until EOF. -func (s *SockControl) ReadFromConn(message string, out io.Writer) error { +func (s *SockControl) ReadFromConn(message string, out io.Writer, io Copier) error { if err := s.WriteMessage(message); err != nil { return err } - if _, err := s.io.Copy(out, s.conn); err != nil { + if _, err := io.Copy(out, s.conn); err != nil { return err } @@ -155,10 +152,9 @@ type Server struct { nc NetceptorForControlsvc controlFuncLock sync.RWMutex controlTypes map[string]ControlCommandType - // new stuff - serverUtils Utiler - serverNet Neter - serverTls Tlser + serverUtils Utiler + serverNet Neter + serverTls Tlser } // New returns a new instance of a control service. @@ -318,7 +314,7 @@ func (s *Server) RunControlSession(conn net.Conn) { } s.controlFuncLock.RUnlock() if ct != nil { - cfo := NewSockControl(conn, &Util{}, &SocketConnIO{}) + cfo := NewSockControl(conn) var cfr map[string]interface{} var cc ControlCommand @@ -469,12 +465,10 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls. } }() for _, listener := range []net.Listener{uli, tli, li} { - if reflect.ValueOf(listener).IsNil() { + if listener == nil || reflect.ValueOf(listener).IsNil() { continue } - if listener != nil { - go s.ConnectionListener(ctx, listener) - } + go s.ConnectionListener(ctx, listener) } return nil diff --git a/pkg/controlsvc/controlsvc_test.go b/pkg/controlsvc/controlsvc_test.go index 81df8b787..106d41c60 100644 --- a/pkg/controlsvc/controlsvc_test.go +++ b/pkg/controlsvc/controlsvc_test.go @@ -19,7 +19,7 @@ const ( writeToConnError = "write to conn write message err" ) -func printExpectedError(t *testing.T, err error) { +func printErrorMessage(t *testing.T, err error) { t.Errorf("expected error %s", err) } @@ -34,10 +34,6 @@ func TestConnectionListener(t *testing.T) { expectedError bool expectedCalls func(context.CancelFunc) }{ - { - name: "return from context error", - expectedError: true, - }, { name: "error accepting connection", expectedError: false, @@ -56,19 +52,11 @@ func TestConnectionListener(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() - if testCase.expectedCalls != nil { - testCase.expectedCalls(ctxCancel) - } + testCase.expectedCalls(ctxCancel) s := controlsvc.New(false, mockNetceptor) - - if testCase.expectedError { - ctxCancel() - } - s.ConnectionListener(ctx, mockListener) }) } - } func TestSetupConnection(t *testing.T) { @@ -79,12 +67,10 @@ func TestSetupConnection(t *testing.T) { setupConnectionTestCases := []struct { name string - expectedError bool expectedCalls func() }{ { - name: "log error - setting timeout", - expectedError: true, + name: "log error - setting timeout", expectedCalls: func() { mockConn.EXPECT().SetDeadline(gomock.Any()).Return(errors.New("terminated")) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -92,8 +78,7 @@ func TestSetupConnection(t *testing.T) { }, }, { - name: "log error - tls handshake", - expectedError: true, + name: "log error - tls handshake", expectedCalls: func() { mockConn.EXPECT().SetDeadline(gomock.Any()).Return(nil) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -163,6 +148,9 @@ func TestRunControlSvc(t *testing.T) { { name: "no listeners error", expectedError: "no listeners specified", + expectedCalls: func() { + // empty func for testing + }, listeners: map[string]string{ "service": "", "unixSocket": "", @@ -173,16 +161,14 @@ func TestRunControlSvc(t *testing.T) { for _, testCase := range runControlSvcTestCases { t.Run(testCase.name, func(t *testing.T) { - if testCase.expectedCalls != nil { - testCase.expectedCalls() - } + testCase.expectedCalls() s := controlsvc.New(false, mockNetceptor) s.SetServerUtils(mockUnix) s.SetServerNet(mockNet) err := s.RunControlSvc(context.Background(), testCase.listeners["service"], &tls.Config{}, testCase.listeners["unixSocket"], os.FileMode(0o600), testCase.listeners["tcpListen"], &tls.Config{}) - if err == nil || err.Error() != testCase.expectedError { + if err.Error() != testCase.expectedError { t.Errorf("expected error %s, got %v", testCase.expectedError, err) } }) @@ -194,9 +180,7 @@ func TestSockControlRemoteAddr(t *testing.T) { mockCon := mock_controlsvc.NewMockConn(ctrl) mockAddr := mock_controlsvc.NewMockAddr(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) localhost := "127.0.0.1" @@ -212,41 +196,51 @@ func TestSockControlRemoteAddr(t *testing.T) { func TestSockControlWriteMessage(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) writeMessageTestCases := []struct { name string message string + expectedError bool expectedCalls func() }{ { - name: "without message", - message: "", + name: "pass without message", + message: "", + expectedError: false, + expectedCalls: func() { + // empty func for testing + }, }, { - name: "with message", - message: "message", + name: "fail with message", + message: "message", + expectedError: true, expectedCalls: func() { mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("cannot write message")) }, }, + { + name: "pass with message", + message: "message", + expectedError: false, + expectedCalls: func() { + mockCon.EXPECT().Write(gomock.Any()).Return(0, nil) + }, + }, } for _, testCase := range writeMessageTestCases { t.Run(testCase.name, func(t *testing.T) { - if testCase.expectedCalls != nil { - testCase.expectedCalls() - } + testCase.expectedCalls() err := sockControl.WriteMessage(testCase.message) - if testCase.message == "" && err != nil { - t.Errorf("should be nil") + if !testCase.expectedError && err != nil { + t.Errorf("write message ran unsuccessfully %s", err) } - if testCase.message != "" && err.Error() != "cannot write message" { - t.Errorf("%s %s", testCase.name, err) + + if testCase.expectedError && err.Error() != "cannot write message" { + printErrorMessage(t, err) } }) } @@ -256,10 +250,9 @@ func TestSockControlBridgeConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) logger := logger.NewReceptorLogger("") - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -285,13 +278,13 @@ func TestSockControlBridgeConn(t *testing.T) { for _, testCase := range bridgeConnTestCases { t.Run(testCase.name, func(t *testing.T) { testCase.expectedCalls() - err := sockControl.BridgeConn(testCase.message, mockCon, "test", logger) + err := sockControl.BridgeConn(testCase.message, mockCon, "test", logger, mockUtil) if testCase.message == "" && err != nil { - t.Errorf("should be nil") + t.Errorf("bridge conn ran unsuccessfully") } if testCase.message != "" && err.Error() != "terminated" { - t.Errorf("stuff %v", err) + t.Errorf("write message error for bridge conn %v", err) } }) } @@ -300,10 +293,9 @@ func TestSockControlBridgeConn(t *testing.T) { func TestSockControlReadFromConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) mockCopier := mock_controlsvc.NewMockCopier(ctrl) - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -344,16 +336,14 @@ func TestSockControlReadFromConn(t *testing.T) { for _, testCase := range bridgeConnTestCases { t.Run(testCase.name, func(t *testing.T) { testCase.expectedCalls() - err := sockControl.ReadFromConn(testCase.message, mockCon) + err := sockControl.ReadFromConn(testCase.message, mockCon, mockCopier) - if testCase.expectedError { - if err == nil && err.Error() != testCase.errorMessage { - printExpectedError(t, err) - } - } else { - if err != nil { - printExpectedError(t, err) - } + if testCase.expectedError && err.Error() != testCase.errorMessage { + printErrorMessage(t, err) + } + + if !testCase.expectedError && err != nil { + printErrorMessage(t, err) } }) } @@ -362,10 +352,7 @@ func TestSockControlReadFromConn(t *testing.T) { func TestSockControlWriteToConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -415,14 +402,12 @@ func TestSockControlWriteToConn(t *testing.T) { err := sockControl.WriteToConn(testCase.message, c) - if testCase.expectedError { - if err == nil && err.Error() != testCase.errorMessage { - printExpectedError(t, err) - } - } else { - if err != nil { - printExpectedError(t, err) - } + if testCase.expectedError && err.Error() != testCase.errorMessage { + printErrorMessage(t, err) + } + + if !testCase.expectedError && err != nil { + printErrorMessage(t, err) } }) } @@ -431,10 +416,7 @@ func TestSockControlWriteToConn(t *testing.T) { func TestSockControlClose(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) errorMessage := "cannot close connection" @@ -442,7 +424,7 @@ func TestSockControlClose(t *testing.T) { err := sockControl.Close() if err == nil && err.Error() != errorMessage { - printExpectedError(t, err) + printErrorMessage(t, err) } } @@ -451,17 +433,15 @@ func TestAddControlFunc(t *testing.T) { mockCtrlCmd := mock_controlsvc.NewMockControlCommandType(ctrl) mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) controlFuncTestsCases := []struct { - name string - input string - expectedError bool - errorMessage string - testCase func(msg string, err error) + name string + input string + errorMessage string + testCase func(msg string, err error) }{ { - name: "ping command", - input: "ping", - expectedError: true, - errorMessage: "control function named ping already exists", + name: "ping command", + input: "ping", + errorMessage: "control function named ping already exists", testCase: func(msg string, err error) { if msg != err.Error() { t.Errorf("expected error: %s, received: %s", msg, err) @@ -469,9 +449,8 @@ func TestAddControlFunc(t *testing.T) { }, }, { - name: "obliterate command", - input: "obliterate", - expectedError: false, + name: "obliterate command", + input: "obliterate", testCase: func(msg string, err error) { if err != nil { t.Errorf("error should be nil. received %s", err) @@ -505,17 +484,12 @@ func TestRunControlSession(t *testing.T) { runControlSessionTestCases := []struct { name string - message string - input chan []byte expectedCalls func() - expectedError bool - errorMessage string }{ { name: "logger warning - could not close connection", expectedCalls: func() { mockCon.EXPECT().Write(gomock.Any()).Return(0, nil) - // meh mockCon.EXPECT().Read(make([]byte, 1)).Return(0, io.EOF) mockCon.EXPECT().Close().Return(errors.New("test")) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -527,7 +501,6 @@ func TestRunControlSession(t *testing.T) { mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("test")) mockCon.EXPECT().Close() }, - errorMessage: "Could not write in control service: test", }, { name: "logger debug - control service closed", @@ -559,24 +532,18 @@ func TestRunControlSession(t *testing.T) { func TestRunControlSessionTwo(t *testing.T) { ctrl := gomock.NewController(t) - mockCon := mock_controlsvc.NewMockConn(ctrl) mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) logger := logger.NewReceptorLogger("") runControlSessionTestCases := []struct { name string - message string - input chan []byte expectedCalls func() - expectedError bool - errorMessage string commandByte []byte }{ { name: "command must be a string", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{\"command\": 0}"), @@ -585,7 +552,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "JSON did not contain a command", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{}"), @@ -594,7 +560,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "command must be a string", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{\"command\": \"echo\"}"), @@ -603,7 +568,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "tokens", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("a b"), @@ -612,7 +576,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "control types - reload", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(6) }, commandByte: []byte("{\"command\": \"reload\"}"), @@ -621,7 +584,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "control types - no ping target", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(5) }, commandByte: []byte("{\"command\": \"ping\"}"), @@ -637,7 +599,6 @@ func TestRunControlSessionTwo(t *testing.T) { go func() { pipeA.Write(testCase.commandByte) pipeA.Close() - }() go func() { io.ReadAll(pipeA) diff --git a/pkg/controlsvc/interfaces.go b/pkg/controlsvc/interfaces.go index 80539a39e..4f4c1bbde 100644 --- a/pkg/controlsvc/interfaces.go +++ b/pkg/controlsvc/interfaces.go @@ -36,8 +36,8 @@ type ControlCommand interface { // ControlFuncOperations provides callbacks for control services to take actions. type ControlFuncOperations interface { - BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger) error - ReadFromConn(message string, out io.Writer) error + BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger, utils Utiler) error + ReadFromConn(message string, out io.Writer, io Copier) error WriteToConn(message string, in chan []byte) error Close() error RemoteAddr() net.Addr diff --git a/pkg/workceptor/controlsvc.go b/pkg/workceptor/controlsvc.go index bfcf5b705..452d4804f 100644 --- a/pkg/workceptor/controlsvc.go +++ b/pkg/workceptor/controlsvc.go @@ -297,7 +297,7 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce return nil, err } worker.UpdateBasicStatus(WorkStatePending, "Waiting for Input Data", 0) - err = cfo.ReadFromConn(fmt.Sprintf("Work unit created with ID %s. Send stdin data and EOF.\n", worker.ID()), stdin) + err = cfo.ReadFromConn(fmt.Sprintf("Work unit created with ID %s. Send stdin data and EOF.\n", worker.ID()), stdin, &controlsvc.SocketConnIO{}) if err != nil { worker.UpdateBasicStatus(WorkStateFailed, fmt.Sprintf("Error reading input data: %s", err), 0)