diff --git a/src/detector/detector.py b/src/detector/detector.py index 687a9cd..3b7de11 100644 --- a/src/detector/detector.py +++ b/src/detector/detector.py @@ -8,6 +8,7 @@ import numpy as np import math import requests +from numpy import median sys.path.append(os.getcwd()) from src.base.utils import setup_config @@ -280,13 +281,12 @@ def detect(self) -> None: # pragma: no cover def send_warning(self) -> None: logger.info("Store alert to file.") if len(self.warnings) > 0: - overall_score = 0 - for warning in self.warnings: - overall_score += warning["probability"] - overall_score = overall_score / len(self.warnings) + overall_score = median( + [warning["probability"] for warning in self.warnings] + ) alert = {"overall_score": overall_score, "result": self.warnings} logger.info(f"Add alert: {alert}") - with open(os.path.join(tempfile.gettempdir(), "warnings.json"), "a") as f: + with open(os.path.join(tempfile.gettempdir(), "warnings.json"), "a+") as f: json.dump(alert, f) f.write("\n") else: diff --git a/tests/test_detector.py b/tests/test_detector.py index f9dbe91..b3fb72b 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -1,6 +1,8 @@ +import os +import tempfile import unittest from datetime import datetime, timedelta -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, mock_open from requests import HTTPError @@ -217,6 +219,98 @@ def test_get_data_while_busy(self, mock_kafka_consume_handler, mock_logger): self.assertEqual([{"test": "test_message_2"}], sut.messages) +class TestSendWarning(unittest.TestCase): + def setUp(self): + patcher = patch("src.detector.detector.logger") + self.mock_logger = patcher.start() + self.addCleanup(patcher.stop) + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_warning(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [ + { + "request": "google.de", + "probability": 0.8765, + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + }, + { + "request": "request.de", + "probability": 0.12388, + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + }, + ] + open_mock = mock_open() + with patch("src.detector.detector.open", open_mock, create=True): + sut.send_warning() + + open_mock.assert_called_with( + os.path.join(tempfile.gettempdir(), "warnings.json"), "a+" + ) + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_empty_warning(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [] + open_mock = mock_open() + with patch("src.detector.detector.open", open_mock, create=True): + sut.send_warning() + + open_mock.assert_not_called() + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_warning_error(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [ + { + "request": "request.de", + "probability": "INVALID", + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + } + ] + with self.assertRaises(Exception): + sut.send_warning() + + class TestClearData(unittest.TestCase): def setUp(self): patcher = patch("src.detector.detector.logger")