diff --git a/pkg/netceptor/ping_test.go b/pkg/netceptor/ping_test.go index 6536d7243..c2636598a 100644 --- a/pkg/netceptor/ping_test.go +++ b/pkg/netceptor/ping_test.go @@ -28,9 +28,20 @@ func setupTest(t *testing.T) (*gomock.Controller, *mock_netceptor.MockNetcForPin return ctrl, mockNetceptor, mockPacketConn, ctx } +func teardownTest(t *testing.T, mockNetceptor *mock_netceptor.MockNetcForPing, mockPacketConn *mock_netceptor.MockPacketConner) { + mockPacketConn.EXPECT().SetHopsToLive(gomock.Any()).Times(0) + mockPacketConn.EXPECT().Close().Times(0) + mockNetceptor.EXPECT().NewAddr(gomock.Any(), gomock.Any()).Times(0) + mockNetceptor.EXPECT().ListenPacket(gomock.Any()).Times(0) + mockPacketConn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Times(0) + mockNetceptor.EXPECT().Context().Times(0) + mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Times(0) + mockPacketConn.EXPECT().SubscribeUnreachable(gomock.Any()).Times(0) +} + func TestListenSubscribeUnreachableErr(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() mockUnreachableMessage := netceptor.UnreachableMessage{ @@ -61,11 +72,14 @@ func TestListenSubscribeUnreachableErr(t *testing.T) { if subscribeUnreachableError == nil { t.Fatal("SubscribeUnreachable expected to return error but returned nil") } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestCreatePing(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -91,11 +105,14 @@ func TestCreatePing(t *testing.T) { if nodeID == expectedNodeID { t.Errorf("expected node ID %s, got %s", expectedNodeID, nodeID) } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestListenPacketErr(t *testing.T) { - ctrl, mockNetceptor, _, ctx := setupTest(t) - defer ctrl.Finish() + ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -104,11 +121,14 @@ func TestListenPacketErr(t *testing.T) { if listenPacketError == nil { t.Fatal("ListenPacker expected to return error but returned nil") } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestReadFromErr(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -122,11 +142,14 @@ func TestReadFromErr(t *testing.T) { if readFromError == nil { t.Fatal("ReadFrom expected to return error but returned nil") } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestWriteToErr(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -141,11 +164,14 @@ func TestWriteToErr(t *testing.T) { if writeToError == nil { t.Fatal("ReadFrom expected to return error but returned nil") } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestTimeOutErr(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -161,11 +187,14 @@ func TestTimeOutErr(t *testing.T) { if err.Error() != "timeout" { t.Fatalf("Expected error to be 'timeout' but got %v", err) } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestUserCancel(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours @@ -174,7 +203,7 @@ func TestUserCancel(t *testing.T) { mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Return(0, nil, nil).Times(1) mockPacketConn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Return(0, nil) mockNetceptor.EXPECT().Context().DoAndReturn(func() context.Context { - time.Sleep(time.Second * 1) + time.Sleep(time.Second * 2) return context.Background() }).Times(2) @@ -187,29 +216,36 @@ func TestUserCancel(t *testing.T) { if err.Error() != "user cancelled" { t.Fatalf("Expected error to be 'user cancelled' but got %v", err) } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestNetceptorShutdown(t *testing.T) { ctrl, mockNetceptor, mockPacketConn, ctx := setupTest(t) - defer ctrl.Finish() + // defer ctrl.Finish() defer ctx.Done() // Set up the mock behaviours mockNetceptor.EXPECT().ListenPacket(gomock.Any()).Return(mockPacketConn, nil) mockPacketConn.EXPECT().SubscribeUnreachable(gomock.Any()).Return(make(chan netceptor.UnreachableNotification)) mockPacketConn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Return(0, nil) - // mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Return(0, nil, nil).MaxTimes(1) + mockPacketConn.EXPECT().ReadFrom(gomock.Any()).Return(0, nil, nil).MaxTimes(1) mockNetceptor.EXPECT().Context().DoAndReturn(func() context.Context { newCtx, ctxCancel := context.WithCancel(context.Background()) ctxCancel() return newCtx - }).Times(1) + }).Times(2) + time.Sleep(time.Second * 1) _, _, err := netceptor.CreatePing(ctx, mockNetceptor, "target", 1) if err.Error() != "netceptor shutdown" { t.Fatalf("Expected error to be 'netceptor shutdown' but got %v", err) } + + teardownTest(t, mockNetceptor, mockPacketConn) + ctrl.Finish() } func TestCreateTraceroute(t *testing.T) { @@ -217,6 +253,7 @@ func TestCreateTraceroute(t *testing.T) { mockNetceptor := mock_netceptor.NewMockNetcForTraceroute(ctrl) ctx := context.Background() + defer ctx.Done() mockNetceptor.EXPECT().Context().Return(context.Background()) mockNetceptor.EXPECT().MaxForwardingHops().Return(byte(1)) @@ -228,4 +265,27 @@ func TestCreateTraceroute(t *testing.T) { t.Fatalf("Unexpected error %v", res.Err) } } + + ctrl.Finish() +} + +func TestCreateTracerouteError(t *testing.T) { + ctrl := gomock.NewController(t) + + mockNetceptor := mock_netceptor.NewMockNetcForTraceroute(ctrl) + ctx := context.Background() + defer ctx.Done() + + mockNetceptor.EXPECT().Context().Return(context.Background()) + mockNetceptor.EXPECT().MaxForwardingHops().Return(byte(1)) + mockNetceptor.EXPECT().Ping(ctx, "target", byte(0)).Return(time.Since(time.Now()), "target", errors.New("traceroute error")) + + result := netceptor.CreateTraceroute(ctx, mockNetceptor, "target") + for res := range result { + if res.Err.Error() != "traceroute error" { + t.Fatalf("Expected error to be 'traceroute error' but got: %v", res.Err) + } + } + + ctrl.Finish() }