diff --git a/README.md b/README.md index 755ab440..e65d9e75 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ The full list of configuration parameters is available at the [documentation](ht | `pipeline.data_inspection.inspector.models` | List of models to use for data inspection (e.g., anomaly detection). | Array of model definitions (e.g., `{"model": "ZScoreDetector", "module": "streamad.model", "model_args": {"is_global": false}}`)| | `pipeline.data_inspection.inspector.anomaly_threshold` | Threshold for classifying an observation as an anomaly. | `0.01` | | `pipeline.data_analysis.detector.model` | Model to use for data analysis (e.g., DGA detection). | `rf` (Random Forest) option: `XGBoost` | -| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06` | +| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc` | | `pipeline.data_analysis.detector.base_url` | Base URL for downloading the model if not present locally. | `https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/` |
diff --git a/config.yaml b/config.yaml index f3e06afb..87f5ee5a 100644 --- a/config.yaml +++ b/config.yaml @@ -59,7 +59,7 @@ pipeline: data_analysis: detector: model: rf # XGBoost - checksum: ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06 + checksum: 021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ threshold: 0.5 diff --git a/requirements/requirements.inspector.txt b/requirements/requirements.inspector.txt index b1c97b21..371b8d74 100644 --- a/requirements/requirements.inspector.txt +++ b/requirements/requirements.inspector.txt @@ -1,6 +1,7 @@ PyYAML~=6.0.1 confluent-kafka~=2.4.0 colorlog~=6.8.2 +scipy~=1.12.0 streamad~=0.3.1 numpy~=1.26.4 marshmallow_dataclass~=8.7.1 diff --git a/requirements/requirements.train.txt b/requirements/requirements.train.txt index 2018a86e..550eda81 100644 --- a/requirements/requirements.train.txt +++ b/requirements/requirements.train.txt @@ -1,7 +1,7 @@ numpy xgboost scikit-learn~=1.5.2 -scipy +scipy~=1.12.0 torch pyarrow polars diff --git a/src/detector/detector.py b/src/detector/detector.py index 84393bfc..3ff324fe 100644 --- a/src/detector/detector.py +++ b/src/detector/detector.py @@ -56,12 +56,15 @@ def __init__(self) -> None: self.begin_timestamp = None self.end_timestamp = None self.model_path = os.path.join( - tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}.pickle" + tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_model.pickle" + ) + self.scaler_path = os.path.join( + tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_scaler.pickle" ) self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(CONSUME_TOPIC) - self.model = self._get_model() + self.model, self.scaler = self._get_model() # databases self.suspicious_batch_timestamps = ClickHouseKafkaSender( @@ -159,14 +162,28 @@ def _get_model(self): logger.info(f"Get model: {MODEL} with checksum {CHECKSUM}") if not os.path.isfile(self.model_path): response = requests.get( - f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1" + f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1" + ) + logger.info( + f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1" ) - logger.info(f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1") response.raise_for_status() with open(self.model_path, "wb") as f: f.write(response.content) + if not os.path.isfile(self.scaler_path): + response = requests.get( + f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1" + ) + logger.info( + f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1" + ) + response.raise_for_status() + + with open(self.scaler_path, "wb") as f: + f.write(response.content) + # Check file sha256 local_checksum = self._sha256sum(self.model_path) @@ -181,7 +198,10 @@ def _get_model(self): with open(self.model_path, "rb") as input_file: clf = pickle.load(input_file) - return clf + with open(self.scaler_path, "rb") as input_file: + scaler = pickle.load(input_file) + + return clf, scaler def clear_data(self): """Clears the data in the internal data structures.""" @@ -237,12 +257,6 @@ def calculate_counts(level: str) -> np.ndarray: for level, level_value in levels.items() } - logger.debug("Get frequency standard deviation, median, variance, and mean.") - freq_std = np.std(freq) - freq_var = np.var(freq) - freq_median = np.median(freq) - freq_mean = np.mean(freq) - logger.debug( "Get standard deviation, median, variance, and mean for full, alpha, special, and numeric count." ) @@ -268,16 +282,9 @@ def calculate_entropy(s: str) -> float: # Final feature aggregation as a NumPy array basic_features = np.array([label_length, label_max, label_average]) - freq_features = np.array([freq_std, freq_var, freq_median, freq_mean]) # Flatten counts and stats for each level into arrays level_features = np.hstack([counts[level] for level in levels.keys()]) - stats_features = np.array( - [stats[f"{level}_std"] for level in levels.keys()] - + [stats[f"{level}_var"] for level in levels.keys()] - + [stats[f"{level}_median"] for level in levels.keys()] - + [stats[f"{level}_mean"] for level in levels.keys()] - ) # Entropy features entropy_features = np.array([entropy[level] for level in levels.keys()]) @@ -287,9 +294,9 @@ def calculate_entropy(s: str) -> float: [ basic_features, freq, - freq_features, + # freq_features, level_features, - stats_features, + # stats_features, entropy_features, ] ) @@ -304,7 +311,7 @@ def detect(self) -> None: # pragma: no cover for message in self.messages: # TODO predict all messages y_pred = self.model.predict_proba( - self._get_features(message["domain_name"]) + self.scaler.transform(self._get_features(message["domain_name"])) ) logger.info(f"Prediction: {y_pred}") if np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > THRESHOLD: diff --git a/tests/detector/test_detector.py b/tests/detector/test_detector.py index 57963f35..c240ccb7 100644 --- a/tests/detector/test_detector.py +++ b/tests/detector/test_detector.py @@ -10,6 +10,17 @@ from src.base.data_classes.batch import Batch from src.detector.detector import Detector, WrongChecksum +DEFAULT_DATA = { + "client_ip": "192.168.0.167", + "dns_ip": "10.10.0.10", + "response_ip": "252.79.173.222", + "timestamp": "", + "status": "NXDOMAIN", + "domain_name": "IF356gEnJHPdRxnkDId4RDUSgtqxx9I+pZ5n1V53MdghOGQncZWAQgAPRx3kswi.750jnH6iSqmiAAeyDUMX0W6SHGpVsVsKSX8ZkKYDs0GFh/9qU5N9cwl00XSD8ID.NNhBdHZIb7nc0hDQXFPlABDLbRwkJS38LZ8RMX4yUmR2Mb6YqTTJBn+nUcB9P+v.jBQdwdS53XV9W2p1BHjh.16.f.1.6037.tunnel.example.org", + "record_type": "A", + "size": "100b", +} + class TestSha256Sum(unittest.TestCase): @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") @@ -45,7 +56,7 @@ def setUp(self): @patch( "src.detector.detector.CHECKSUM", - "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", ) @patch("src.detector.detector.MODEL", "rf") @patch( @@ -70,7 +81,7 @@ def setUp(self): @patch( "src.detector.detector.CHECKSUM", - "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", ) @patch("src.detector.detector.MODEL", "rf") @patch( @@ -246,7 +257,7 @@ def setUp(self): @patch( "src.detector.detector.CHECKSUM", - "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", ) @patch("src.detector.detector.MODEL", "rf") @patch( @@ -265,13 +276,13 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler): "request": "google.de", "probability": 0.8765, "model": "rf", - "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", }, { "request": "request.de", "probability": 0.12388, "model": "rf", - "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", }, ] sut.messages = [{"logline_id": "test_id"}] @@ -285,7 +296,27 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler): @patch( "src.detector.detector.CHECKSUM", - "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") + @patch("src.detector.detector.ClickHouseKafkaSender") + def test_prediction(self, mock_clickhouse, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.messages = [DEFAULT_DATA] + sut.detect() + self.assertNotEqual([], sut.warnings) + + @patch( + "src.detector.detector.CHECKSUM", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", ) @patch("src.detector.detector.MODEL", "rf") @patch( @@ -309,7 +340,7 @@ def test_save_empty_warning(self, mock_clickhouse, mock_kafka_consume_handler): @patch( "src.detector.detector.CHECKSUM", - "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", ) @patch("src.detector.detector.MODEL", "rf") @patch( @@ -328,7 +359,7 @@ def test_save_warning_error(self, mock_clickhouse, mock_kafka_consume_handler): "request": "request.de", "probability": "INVALID", "model": "rf", - "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + "sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", } ] with self.assertRaises(Exception):