Skip to content

Commit

Permalink
its over chris, you have thrown the test in mt doom
Browse files Browse the repository at this point in the history
  • Loading branch information
resoluteCoder committed Sep 19, 2023
1 parent 974930f commit 81dcf13
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 98 deletions.
92 changes: 50 additions & 42 deletions pkg/controlsvc/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io/fs"
"net"
"os"
"reflect"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -362,6 +363,50 @@ func (s *Server) RunControlSession(conn net.Conn) {
}
}

func (s *Server) ConnectionListener(ctx context.Context, listener net.Listener) {
for {
if ctx.Err() != nil {
return
}
conn, err := listener.Accept()
if err != nil {
if !strings.HasSuffix(err.Error(), "normal close") {
s.nc.GetLogger().Error("Error accepting connection: %s\n", err)
}

continue
}
go s.SetupConnection(conn)
}
}

func (s *Server) SetupConnection(conn net.Conn) {
defer conn.Close()
tlsConn, ok := conn.(*tls.Conn)
if ok {
// Explicitly run server TLS handshake so we can deal with timeout and errors here
err := conn.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
s.nc.GetLogger().Error("Error setting timeout: %s. Closing socket.\n", err)

return
}
err = tlsConn.Handshake()
if err != nil {
s.nc.GetLogger().Error("TLS handshake error: %s. Closing socket.\n", err)

return
}
err = conn.SetDeadline(time.Time{})
if err != nil {
s.nc.GetLogger().Error("Error clearing timeout: %s. Closing socket.\n", err)

return
}
}
s.RunControlSession(conn)
}

// RunControlSvc runs the main accept loop of the control service.
func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.Config,
unixSocket string, unixSocketPermissions os.FileMode, tcpListen string, tcptls *tls.Config,
Expand Down Expand Up @@ -395,7 +440,7 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.
} else {
tli = nil
}
var li net.Listener
var li *netceptor.Listener
if service != "" {
li, err = s.nc.ListenAndAdvertise(service, tlscfg, map[string]string{
"type": "Control Service",
Expand Down Expand Up @@ -424,48 +469,11 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls.
}
}()
for _, listener := range []net.Listener{uli, tli, li} {
if reflect.ValueOf(listener).IsNil() {
continue
}
if listener != nil {
go func(listener net.Listener) {
for {
if ctx.Err() != nil {
return
}
conn, err := listener.Accept()
if err != nil {
if !strings.HasSuffix(err.Error(), "normal close") {
s.nc.GetLogger().Error("Error accepting connection: %s\n", err)
}

continue
}
go func() {
defer conn.Close()
tlsConn, ok := conn.(*tls.Conn)
if ok {
// Explicitly run server TLS handshake so we can deal with timeout and errors here
err = conn.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
s.nc.GetLogger().Error("Error setting timeout: %s. Closing socket.\n", err)

return
}
err = tlsConn.Handshake()
if err != nil {
s.nc.GetLogger().Error("TLS handshake error: %s. Closing socket.\n", err)

return
}
err = conn.SetDeadline(time.Time{})
if err != nil {
s.nc.GetLogger().Error("Error clearing timeout: %s. Closing socket.\n", err)

return
}
}
s.RunControlSession(conn)
}()
}
}(listener)
go s.ConnectionListener(ctx, listener)
}
}

Expand Down
160 changes: 104 additions & 56 deletions pkg/controlsvc/controlsvc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,106 @@ import (
"net"
"os"
"testing"
"time"

"github.com/ansible/receptor/pkg/controlsvc"
"github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc"
"github.com/ansible/receptor/pkg/logger"
"github.com/golang/mock/gomock"
)

func TestConnectionListener(t *testing.T) {
ctrl := gomock.NewController(t)
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
mockListener := mock_controlsvc.NewMockListener(ctrl)
logger := logger.NewReceptorLogger("")

connectionListenerTestCases := []struct {
name string
expectedError bool
expectedCalls func(context.CancelFunc)
}{
{
name: "return from context error",
expectedError: true,
expectedCalls: func(ctx context.CancelFunc) {},
},
{
name: "error accepting connection",
expectedError: false,
expectedCalls: func(ctxCancel context.CancelFunc) {
mockListener.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
ctxCancel()
return nil, errors.New("terminated")
})
mockNetceptor.EXPECT().GetLogger().Return(logger)
},
},
}

for _, testCase := range connectionListenerTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, ctxCancel := context.WithCancel(context.Background())
defer ctxCancel()

testCase.expectedCalls(ctxCancel)
s := controlsvc.New(false, mockNetceptor)

if testCase.expectedError {
ctxCancel()
}

s.ConnectionListener(ctx, mockListener)
})
}

Check failure on line 62 in pkg/controlsvc/controlsvc_test.go

View workflow job for this annotation

GitHub Actions / lint-receptor

File is not `gofumpt`-ed (gofumpt)
}

func TestSetupConnection(t *testing.T) {
ctrl := gomock.NewController(t)
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
mockConn := mock_controlsvc.NewMockConn(ctrl)
logger := logger.NewReceptorLogger("")

setupConnectionTestCases := []struct {
name string
expectedError bool
expectedCalls func()
}{
{
name: "log error - setting timeout",
expectedError: true,
expectedCalls: func() {
mockConn.EXPECT().SetDeadline(gomock.Any()).Return(errors.New("terminated"))
mockNetceptor.EXPECT().GetLogger().Return(logger)
mockConn.EXPECT().Close()
},
},
{
name: "log error - tls handshake",
expectedError: true,
expectedCalls: func() {
mockConn.EXPECT().SetDeadline(gomock.Any()).Return(nil)
mockNetceptor.EXPECT().GetLogger().Return(logger)
mockConn.EXPECT().Close().AnyTimes()
},
},
}

for _, testCase := range setupConnectionTestCases {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedCalls()
s := controlsvc.New(false, mockNetceptor)
tlsConn := tls.Client(mockConn, &tls.Config{})
s.SetupConnection(tlsConn)
})
}
}

func TestRunControlSvc(t *testing.T) {
ctrl := gomock.NewController(t)
mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
mockUnix := mock_controlsvc.NewMockUtiler(ctrl)
mockNet := mock_controlsvc.NewMockNeter(ctrl)
// mockListener := mock_controlsvc.NewMockListener(ctrl)
// logger := logger.NewReceptorLogger("")

runControlSvcTestCases := []struct {
name string
Expand Down Expand Up @@ -77,19 +162,6 @@ func TestRunControlSvc(t *testing.T) {
"tcpListen": "",
},
},
// {
// name: "idk",
// expectedError: "",
// expectedCalls: func() {
// mockNet.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(mockListener, nil)
// mockNetceptor.EXPECT().GetLogger().Return(logger)
// },
// listeners: map[string]string{
// "service": "",
// "unixSocket": "",
// "tcpListen": "tcp listener",
// },
// },
}

for _, testCase := range runControlSvcTestCases {
Expand All @@ -102,40 +174,12 @@ func TestRunControlSvc(t *testing.T) {
err := s.RunControlSvc(context.Background(), testCase.listeners["service"], &tls.Config{}, testCase.listeners["unixSocket"], os.FileMode(0o600), testCase.listeners["tcpListen"], &tls.Config{})

if err == nil || err.Error() != testCase.expectedError {
t.Errorf("expected error %s, got %s", testCase.expectedError, err.Error())
t.Errorf("expected error %s, got %v", testCase.expectedError, err)
}
})
}
}

func TestRunControlSvcOld(t *testing.T) {
// ctrl := gomock.NewController(t)
// defer ctrl.Finish()

// mock_netceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl)
// s := controlsvc.New(false, mock_netceptor)
// mock_unix := mock_controlsvc.NewMockUtiler(ctrl)
// s.SetServerUtils(mock_unix)

// mock_net_listener := mock_controlsvc.NewMockListener(ctrl)
// mock_unix.EXPECT().UnixSocketListen(gomock.Any(), gomock.Any()).Return(mock_net_listener, nil, nil)

// newCtx, ctxCancel := context.WithTimeout(context.Background(), time.Millisecond*1)
// defer ctxCancel()

// logger := logger.NewReceptorLogger("test")
// mock_net_listener.EXPECT().Accept().Return(nil, errors.New("blargh"))
// // mock_net_listener.EXPECT().Close()
// mock_netceptor.EXPECT().GetLogger().Return(logger)
// err := s.RunControlSvc(newCtx, "", &tls.Config{}, "unixSocket", os.FileMode(0o600), "", &tls.Config{})
// errorString := "Error accepting connection: blargh"
// fmt.Println(err, errorString)
// if err == nil || err.Error() != errorString {
// t.Errorf("expected error: %+v, got: %+v", errorString, err.Error())
// }

}

func TestSockControlRemoteAddr(t *testing.T) {
ctrl := gomock.NewController(t)

Expand Down Expand Up @@ -223,7 +267,7 @@ func TestSockControlBridgeConn(t *testing.T) {
name: "with message and error",
message: "message",
expectedCalls: func() {
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("blargh"))
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("terminated"))
},
},
}
Expand All @@ -236,7 +280,7 @@ func TestSockControlBridgeConn(t *testing.T) {
if testCase.message == "" && err != nil {
t.Errorf("should be nil")
}
if testCase.message != "" && err.Error() != "blargh" {
if testCase.message != "" && err.Error() != "terminated" {
t.Errorf("stuff %v", err)
}
})
Expand Down Expand Up @@ -322,7 +366,7 @@ func TestSockControlWriteToConn(t *testing.T) {
errorMessage string
}{
{
name: "without message and error",
name: "without message and with error",
message: "",
expectedCalls: func() {
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("write to conn chan error"))
Expand All @@ -331,7 +375,7 @@ func TestSockControlWriteToConn(t *testing.T) {
errorMessage: "write to conn chan error",
},
{
name: "with message and error",
name: "with message and with error",
message: "message",
expectedCalls: func() {
mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("write to conn write message error"))
Expand All @@ -354,15 +398,11 @@ func TestSockControlWriteToConn(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedCalls()
c := make(chan []byte)
go func() {
go func(c chan []byte) {
c <- []byte{7}
}()
if !testCase.expectedError {
defer close(c)
}(c)

time.AfterFunc(time.Millisecond*100, func() {
close(c)
})
}
err := sockControl.WriteToConn(testCase.message, c)

if testCase.expectedError {
Expand Down Expand Up @@ -487,6 +527,14 @@ func TestRunControlSession(t *testing.T) {
mockCon.EXPECT().Close()
},
},
{
name: "logger warning - could not read in control service",
expectedCalls: func() {
mockCon.EXPECT().Write(gomock.Any()).Return(0, nil)
mockCon.EXPECT().Read(make([]byte, 1)).Return(0, errors.New("terminated"))
mockCon.EXPECT().Close()
},
},
}

for _, testCase := range runControlSessionTestCases {
Expand Down

0 comments on commit 81dcf13

Please sign in to comment.