diff --git a/config.yaml b/config.yaml index a145b62..0bd64b7 100644 --- a/config.yaml +++ b/config.yaml @@ -59,8 +59,8 @@ pipeline: data_analysis: detector: - model: xg # XGBoost - checksum: 21d1f40c9e186a08e9d2b400cea607f4163b39d187a9f9eca3da502b21cf3b9b + model: rf # RandomForest + checksum: ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06 base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ threshold: 0.5 diff --git a/docker/docker-compose.external.yml b/docker/docker-compose.external.yml new file mode 100644 index 0000000..fb577b2 --- /dev/null +++ b/docker/docker-compose.external.yml @@ -0,0 +1,146 @@ +services: + + logcollector: + build: + context: .. + dockerfile: docker/dockerfiles/Dockerfile.logcollector + network: host + restart: "unless-stopped" + depends_on: + logserver: + condition: service_started + networks: + heidgaf: + ipv4_address: 172.27.0.7 + volumes: + - ./config.yaml:/usr/src/app/config.yaml + memswap_limit: 768m + deploy: + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + + logserver: + build: + context: .. + dockerfile: docker/dockerfiles/Dockerfile.logserver + network: host + restart: "unless-stopped" + ports: + - 9998:9998 + networks: + heidgaf: + ipv4_address: 172.27.0.8 + memswap_limit: 768m + deploy: + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + volumes: + - "${MOUNT_PATH:?MOUNT_PATH not set}:/opt/file.txt" + - ./config.yaml:/usr/src/app/config.yaml + + + inspector: + build: + context: .. + dockerfile: docker/dockerfiles/Dockerfile.inspector + network: host + restart: "unless-stopped" + depends_on: + logserver: + condition: service_started + prefilter: + condition: service_started + logcollector: + condition: service_started + networks: + heidgaf: + ipv4_address: 172.27.0.6 + volumes: + - ./config.yaml:/usr/src/app/config.yaml + deploy: + mode: "replicated" + replicas: 1 + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + + prefilter: + build: + context: .. + dockerfile: docker/dockerfiles/Dockerfile.prefilter + network: host + restart: "unless-stopped" + depends_on: + logcollector: + condition: service_started + logserver: + condition: service_started + networks: + heidgaf: + ipv4_address: 172.27.0.9 + volumes: + - ./config.yaml:/usr/src/app/config.yaml + deploy: + mode: "replicated" + replicas: 1 + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + + detector: + build: + context: ..
+ dockerfile: docker/dockerfiles/Dockerfile.detector + network: host + restart: "unless-stopped" + depends_on: + logcollector: + condition: service_started + logserver: + condition: service_started + networks: + heidgaf: + ipv4_address: 172.27.0.10 + volumes: + - ./config.yaml:/usr/src/app/config.yaml + deploy: + mode: "replicated" + replicas: 1 + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + devices: + - driver: nvidia + count: 1 # alternatively, use `count: all` for all GPUs + capabilities: [gpu] + +networks: + heidgaf: + driver: bridge + ipam: + driver: default + config: + - subnet: 172.27.0.0/16 + gateway: 172.27.0.1 diff --git a/docker/docker-compose.kafka.yml b/docker/docker-compose.kafka.yml index b76a621..8ec9c8a 100644 --- a/docker/docker-compose.kafka.yml +++ b/docker/docker-compose.kafka.yml @@ -2,6 +2,7 @@ services: zookeeper: image: confluentinc/cp-zookeeper:7.3.2 container_name: zookeeper + restart: "unless-stopped" networks: heidgaf: ipv4_address: 172.27.0.2 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 598070f..1ddd730 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -126,13 +126,40 @@ services: cpus: '1' memory: 256m - # detector: - # build: - # context: ./dockerfiles - # dockerfile: Dockerfile.detector - # deploy: - # mode: "replicated" - # replicas: 6 + detector: + build: + context: .. + dockerfile: docker/dockerfiles/Dockerfile.detector + network: host + restart: "unless-stopped" + depends_on: + kafka1: + condition: service_healthy + kafka2: + condition: service_healthy + kafka3: + condition: service_healthy + logcollector: + condition: service_started + logserver: + condition: service_started + networks: + heidgaf: + ipv4_address: 172.27.0.10 + deploy: + mode: "replicated" + replicas: 1 + resources: + limits: + cpus: '2' + memory: 512m + reservations: + cpus: '1' + memory: 256m + devices: + - driver: nvidia + count: 1 # alternatively, use `count: all` for all GPUs + capabilities: [gpu] networks: heidgaf: diff --git a/docker/dockerfiles/Dockerfile.detector b/docker/dockerfiles/Dockerfile.detector index df3e4fe..c0cbc51 100644 --- a/docker/dockerfiles/Dockerfile.detector +++ b/docker/dockerfiles/Dockerfile.detector @@ -1,12 +1,16 @@ -FROM python:3 +FROM python:3.11-slim-bookworm + +ENV PYTHONDONTWRITEBYTECODE=1 WORKDIR /usr/src/app COPY requirements/requirements.detector.txt ./ -RUN pip --disable-pip-version-check install --no-cache-dir --no-compile -r requirements.detector.txt +RUN pip --disable-pip-version-check install --no-cache-dir --no-compile -r requirements.detector.txt COPY src/base ./src/base COPY src/detector ./src/detector COPY config.yaml . 
+RUN rm -rf /root/.cache + CMD [ "python", "src/detector/detector.py"] diff --git a/docker/dockerfiles/Dockerfile.logcollector b/docker/dockerfiles/Dockerfile.logcollector index aa21abf..aaac14e 100644 --- a/docker/dockerfiles/Dockerfile.logcollector +++ b/docker/dockerfiles/Dockerfile.logcollector @@ -1,4 +1,4 @@ -FROM python:3-slim-bookworm +FROM python:3.11-slim-bookworm ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/docker/dockerfiles/Dockerfile.logserver b/docker/dockerfiles/Dockerfile.logserver index 10f9212..e85df13 100644 --- a/docker/dockerfiles/Dockerfile.logserver +++ b/docker/dockerfiles/Dockerfile.logserver @@ -1,4 +1,4 @@ -FROM python:3-slim-bookworm +FROM python:3.11-slim-bookworm ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/docker/dockerfiles/Dockerfile.prefilter b/docker/dockerfiles/Dockerfile.prefilter index 4bdb9e0..bb4646c 100644 --- a/docker/dockerfiles/Dockerfile.prefilter +++ b/docker/dockerfiles/Dockerfile.prefilter @@ -1,4 +1,4 @@ -FROM python:3-slim-bookworm +FROM python:3.11-slim-bookworm ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/docs/pipeline.rst b/docs/pipeline.rst index 406828f..bde168e 100644 --- a/docs/pipeline.rst +++ b/docs/pipeline.rst @@ -375,7 +375,7 @@ Overview The `Inspector` stage is responsible to run time-series based anomaly detection on prefiltered batches. This stage is essentiell to reduce the load on the `Detection` stage. -Otherwise, resource complexity would increase disproportionately. +Otherwise, resource complexity increases disproportionately. Main Class ---------- @@ -393,20 +393,17 @@ The :class:`Inspector` loads the StreamAD model to perform anomaly detection. It consumes batches on the topic ``inspect``, usually produced by the ``Prefilter``. For a new batch, it derives the timestamps ``begin_timestamp`` and ``end_timestamp``. Based on time type (e.g. ``s``, ``ms``) and time range (e.g. ``5``) the sliding non-overlapping window is created. -For univariate time-series, it counts the number of occurances, whereas for multivariate, it considers the packet size. :cite:`schuppen_fanci_2018` +For univariate time-series, it counts the number of occurrences, whereas for multivariate time-series, it considers both the number of occurrences and the packet size. :cite:`schuppen_fanci_2018` -.. note:: TODO Add mathematical explanation. - -:math:`y = x` - -An anomaly is noted when it is greater than a ``score_threshold``. In addition, we support a relative anomaly threshold. -So, if the anomaly threshold is ``0.01``, it sends anomalies for further detection, if the amount of anomlies divided by the total amount of requests in the batch is greater. +An anomaly is noted when its score is greater than the ``score_threshold``. +In addition, we support a relative anomaly threshold. +So, if the anomaly threshold is ``0.01``, anomalies are sent for further detection if the number of anomalies divided by the total number of requests in the batch is greater than ``0.01``. Configuration ------------- All StreamAD models are supported. This includes univariate, multivariate, and ensemble methods. -In case special arguments are desired for your environment, the ``model_args`` as a dictionary can be passed for each model. +If special arguments are required for your environment, ``model_args`` can be passed as a ``dict`` for each model. Univariate models in `streamad.model`: @@ -417,7 +414,7 @@ Univariate models in `streamad.model`: - :class:`OCSVMDetector` Multivariate models in `streamad.model`: -Currently, we rely on the packet size for multivariate processing.
+Currently, we rely on the packet size and the number of occurrences for multivariate processing. - :class:`xStreamDetector` - :class:`RShashDetector` @@ -439,11 +436,13 @@ Stage 5: Detection Overview -------- -The `Detector` resembles the heart of heiDGAF. It runs pre-trained machine learning models to get a probability outcome of DNS requests. +The `Detector` is the heart of heiDGAF. It runs pre-trained machine learning models to compute a probability score for the DNS requests. The pre-trained models are under the EUPL-1.2 license online available. In total, we rely on the following data sets for the pre-trained models we offer: - `CIC-Bell-DNS-2021 `_ +- `DGTA-BENCH - Domain Generation and Tunneling Algorithms for Benchmark `_ +- `DGArchive `_ Main Class ---------- @@ -456,7 +455,7 @@ Usage The :class:`Detector` consumes anomalous batches of requests. It calculates a probability score for each request, and at last, an overall score of the batch. -Such alerts are log to ``/tmp/warnings.json``. +Alerts are logged to ``/tmp/warnings.json``. Configuration ------------- diff --git a/docs/training.rst b/docs/training.rst index c272d42..c81a26d 100644 --- a/docs/training.rst +++ b/docs/training.rst @@ -3,3 +3,17 @@ Training Overview ======== + +In total, we support ``RandomForest`` and ``XGBoost``. +The :class:`DetectorTraining` class is the main entry point to fit any model. +After initialisation, it fits the selected model on the chosen data set. + +It supports various data sets: + +- ``all``: Includes all available data sets +- ``cic``: Train on the CICBellDNS2021 data set +- ``dgta``: Train on the DGTA Benchmarking data set +- ``dgarchive``: Train on the DGArchive data set + +For hyperparameter optimisation we use ``optuna``. +The training itself can run on a GPU to speed up the search for the best parameters. diff --git a/docs/usage.rst b/docs/usage.rst index cb8239a..8f03a7d 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -17,6 +17,14 @@ If you want to use heiDGAF, just use the provided ``docker-compose.yml`` to quic $ docker compose -f docker/docker-compose.yml up +Run containers individually: + + +.. code-block:: console + +   $ docker compose -f docker/docker-compose.kafka.yml up +   $ docker run ...
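+ +As an illustration (assuming the services are defined as in ``docker/docker-compose.yml``), single services such as ``logserver``, ``logcollector``, or ``detector`` can also be started selectively with Docker Compose; adjust the service list to your deployment: + +.. code-block:: console + +   $ docker compose -f docker/docker-compose.yml up logserver logcollector +   $ docker compose -f docker/docker-compose.yml up detector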
+ Installation ------------ diff --git a/requirements/requirements.detector.txt b/requirements/requirements.detector.txt index 919bb40..7487f70 100644 --- a/requirements/requirements.detector.txt +++ b/requirements/requirements.detector.txt @@ -1,6 +1,7 @@ -joblib xgboost -marshmallow_dataclass~=8.7.1 +scikit-learn~=1.5.2 requests -confluent-kafka~=2.4.0 colorlog~=6.8.2 +PyYAML~=6.0.1 +confluent-kafka~=2.4.0 +marshmallow_dataclass~=8.7.1 diff --git a/requirements/requirements.train.txt b/requirements/requirements.train.txt index 40ef6a2..9082dbf 100644 --- a/requirements/requirements.train.txt +++ b/requirements/requirements.train.txt @@ -1,5 +1,7 @@ numpy -polars -torch xgboost -scikit-learn +scikit-learn~=1.5.2 +scipy +torch +pyarrow +polars diff --git a/src/base/__init__.py b/src/base/__init__.py index 20738ea..d44faf7 100644 --- a/src/base/__init__.py +++ b/src/base/__init__.py @@ -1,6 +1,5 @@ -from typing import Optional, List, Dict +from typing import List from dataclasses import dataclass, field -import marshmallow_dataclass import marshmallow.validate import datetime diff --git a/src/base/kafka_handler.py b/src/base/kafka_handler.py index fb61692..fdbeacc 100644 --- a/src/base/kafka_handler.py +++ b/src/base/kafka_handler.py @@ -339,6 +339,9 @@ def consume_and_return_json_data(self) -> tuple[None | str, dict]: logger.error("Unknown data format.") raise ValueError + def _is_dicts(self, obj): + return isinstance(obj, list) and all(isinstance(item, dict) for item in obj) + def consume_and_return_object(self) -> tuple[None | str, Batch]: """ Calls the :meth:`consume()` method and waits for it to return data. Loads the data and converts it to a Batch @@ -359,7 +362,7 @@ def consume_and_return_object(self) -> tuple[None | str, Batch]: logger.debug("No data returned.") return None, {} except KafkaMessageFetchException as e: - logger.debug(e) + logger.warning(e) raise except KeyboardInterrupt: raise @@ -368,9 +371,16 @@ def consume_and_return_object(self) -> tuple[None | str, Batch]: json_from_message = json.loads(value) logger.debug(f"{json_from_message=}") - # TODO: Fix literal evaluation on data...
eval_data: dict = ast.literal_eval(value) - eval_data["data"] = [ast.literal_eval(item) for item in eval_data.get("data")] + logger.debug("Check if data is a list of dicts") + + if self._is_dicts(eval_data.get("data")): + eval_data["data"] = eval_data.get("data") + else: + eval_data["data"] = [ + ast.literal_eval(item) for item in eval_data.get("data") + ] + eval_data: Batch = self.batch_schema.load(eval_data) if isinstance(eval_data, Batch): diff --git a/src/detector/detector.py b/src/detector/detector.py index 46b75e7..3b7de11 100644 --- a/src/detector/detector.py +++ b/src/detector/detector.py @@ -1,11 +1,14 @@ import hashlib import json import os +import pickle import sys import tempfile -import joblib +import numpy as np +import math import requests +from numpy import median sys.path.append(os.getcwd()) from src.base.utils import setup_config @@ -40,12 +43,13 @@ class Detector: """ def __init__(self) -> None: - self.topic = "detector" self.messages = [] self.warnings = [] self.begin_timestamp = None self.end_timestamp = None - self.model_path = os.path.join(tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}.pkl") + self.model_path = os.path.join( + tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}.pickle" + ) logger.debug(f"Initializing Detector...") logger.debug(f"Calling KafkaConsumeHandler(topic='Detector')...") @@ -114,12 +118,12 @@ def _get_model(self) -> None: Downloads model from server. If model already exists, it returns the current model. In addition, it checks the sha256 sum in case a model has been updated. """ - + logger.info(f"Get model: {MODEL} with checksum {CHECKSUM}") if not os.path.isfile(self.model_path): response = requests.get( - f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pkl&dl=1" + f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1" ) - logger.info(f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pkl&dl=1") + logger.info(f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1") response.raise_for_status() with open(self.model_path, "wb") as f: @@ -136,7 +140,10 @@ def _get_model(self) -> None: f"Checksum {CHECKSUM} SHA256 is not equal with new checksum {local_checksum}!" ) - return joblib.load(self.model_path) + with open(self.model_path, "rb") as input_file: + clf = pickle.load(input_file) + + return clf def clear_data(self): """Clears the data in the internal data structures.""" @@ -145,16 +152,145 @@ def clear_data(self): self.end_timestamp = None self.warnings = [] + def _get_features(self, query: str): + """Transform a dataset with new features using numpy. + + Args: + query (str): A string to process. + + Returns: + dict: Preprocessed data with computed features. 
+ """ + + # Splitting by dots to calculate label length and max length + label_parts = query.split(".") + label_length = len(label_parts) + label_max = max(len(part) for part in label_parts) + label_average = len(query.strip(".")) + + logger.debug("Get letter frequency") + alc = "abcdefghijklmnopqrstuvwxyz" + freq = np.array( + [query.lower().count(i) / len(query) if len(query) > 0 else 0 for i in alc] + ) + + logger.debug("Get full, alpha, special, and numeric count.") + + def calculate_counts(level: str) -> np.ndarray: + if len(level) == 0: + return np.array([0, 0, 0, 0]) + + full_count = len(level) + alpha_count = sum(c.isalpha() for c in level) / full_count + numeric_count = sum(c.isdigit() for c in level) / full_count + special_count = ( + sum(not c.isalnum() and not c.isspace() for c in level) / full_count + ) + + return np.array([full_count, alpha_count, numeric_count, special_count]) + + levels = { + "fqdn": query, + "thirdleveldomain": label_parts[0] if len(label_parts) > 2 else "", + "secondleveldomain": label_parts[1] if len(label_parts) > 1 else "", + } + counts = { + level: calculate_counts(level_value) + for level, level_value in levels.items() + } + + logger.debug("Get frequency standard deviation, median, variance, and mean.") + freq_std = np.std(freq) + freq_var = np.var(freq) + freq_median = np.median(freq) + freq_mean = np.mean(freq) + + logger.debug( + "Get standard deviation, median, variance, and mean for full, alpha, special, and numeric count." + ) + stats = {} + for level, count_array in counts.items(): + stats[f"{level}_std"] = np.std(count_array) + stats[f"{level}_var"] = np.var(count_array) + stats[f"{level}_median"] = np.median(count_array) + stats[f"{level}_mean"] = np.mean(count_array) + + logger.debug("Start entropy calculation") + + def calculate_entropy(s: str) -> float: + if len(s) == 0: + return 0 + probabilities = [float(s.count(c)) / len(s) for c in dict.fromkeys(list(s))] + entropy = -sum(p * math.log(p, 2) for p in probabilities) + return entropy + + entropy = {level: calculate_entropy(value) for level, value in levels.items()} + + logger.debug("Finished entropy calculation") + + # Final feature aggregation as a NumPy array + basic_features = np.array([label_length, label_max, label_average]) + freq_features = np.array([freq_std, freq_var, freq_median, freq_mean]) + + # Flatten counts and stats for each level into arrays + level_features = np.hstack([counts[level] for level in levels.keys()]) + stats_features = np.array( + [stats[f"{level}_std"] for level in levels.keys()] + + [stats[f"{level}_var"] for level in levels.keys()] + + [stats[f"{level}_median"] for level in levels.keys()] + + [stats[f"{level}_mean"] for level in levels.keys()] + ) + + # Entropy features + entropy_features = np.array([entropy[level] for level in levels.keys()]) + + # Concatenate all features into a single numpy array + all_features = np.concatenate( + [ + basic_features, + freq, + freq_features, + level_features, + stats_features, + entropy_features, + ] + ) + + logger.debug("Finished data transformation") + + return all_features.reshape(1, -1) + def detect(self) -> None: # pragma: no cover + logger.info("Start detecting malicious requests.") for message in self.messages: - y_pred = self.model.predict(message["host_domain_name"]) - if y_pred > THRESHOLD: - self.warnings.append(message) + # TODO predict all messages + y_pred = self.model.predict_proba( + self._get_features(message["domain_name"]) + ) + logger.info(f"Prediction: {y_pred}") + if np.argmax(y_pred, axis=1) == 1 
and y_pred[0][1] > THRESHOLD: + logger.info("Append malicious request to warning.") + warning = { + "request": message, + "probability": float(y_pred[0][1]), + "model": MODEL, + "sha256": CHECKSUM, + } + self.warnings.append(warning) def send_warning(self) -> None: - with open(os.path.join(tempfile.gettempdir(), "warnings.json"), "a") as f: - json.dump(self.messages, f) - f.write("\n") + logger.info("Store alert to file.") + if len(self.warnings) > 0: + overall_score = median( + [warning["probability"] for warning in self.warnings] + ) + alert = {"overall_score": overall_score, "result": self.warnings} + logger.info(f"Add alert: {alert}") + with open(os.path.join(tempfile.gettempdir(), "warnings.json"), "a+") as f: + json.dump(alert, f) + f.write("\n") + else: + logger.info("No warning produced.") def main(one_iteration: bool = False): # pragma: no cover diff --git a/src/inspector/inspector.py b/src/inspector/inspector.py index fa2b1d3..a989743 100644 --- a/src/inspector/inspector.py +++ b/src/inspector/inspector.py @@ -64,7 +64,6 @@ class Inspector: def __init__(self) -> None: self.key = None - self.topic = "Collector" self.begin_timestamp = None self.end_timestamp = None self.messages = [] @@ -74,7 +73,7 @@ def __init__(self) -> None: logger.debug(f"Calling KafkaConsumeHandler(topic='Inspect')...") self.kafka_consume_handler = KafkaConsumeHandler(topic="Inspect") logger.debug(f"Calling KafkaProduceHandler(transactional_id='Inspect')...") - self.kafka_produce_handler = KafkaProduceHandler(transactional_id="Inspect") + self.kafka_produce_handler = KafkaProduceHandler(transactional_id="inspect") logger.debug(f"Initialized Inspector.") def get_and_fill_data(self) -> None: @@ -266,7 +265,7 @@ def _count_errors(self, messages: list, begin_timestamp, end_timestamp): def inspect(self): """Runs anomaly detection on given StreamAD Model on either univariate, multivariate data, or as an ensemble.""" - if len(MODELS) == 0: + if MODELS == None or len(MODELS) == 0: logger.warning("No model ist set!") raise NotImplementedError(f"No model is set!") match MODE: @@ -350,11 +349,14 @@ def _inspect_ensemble(self, models: str): module = importlib.import_module(model["module"]) module_model = getattr(module, model["model"]) self.model.append(module_model(**model["model_args"])) + for x in stream.iter_item(): + scores = [] # Fit all models in ensemble for models in self.model: - models.fit_score(x) - score = ensemble.ensemble([model for model in self.model]) + scores.append(models.fit_score(x)) + # TODO Calibrators are missing + score = ensemble.ensemble(scores) if score != None: self.anomalies.append(score) else: @@ -401,16 +403,26 @@ def send_data(self): if total_anomalies / len(self.X) > ANOMALY_THRESHOLD: logger.debug("Sending data to KafkaProduceHandler...") logger.info("Sending anomalies to detector for further analysation.") - data_to_send = { - "begin_timestamp": self.begin_timestamp.strftime(TIMESTAMP_FORMAT), - "end_timestamp": self.end_timestamp.strftime(TIMESTAMP_FORMAT), - "data": self.messages, - } - self.kafka_produce_handler.send( - topic=self.topic, - data=json.dumps(data_to_send), - key=self.key, - ) + buckets = {} + for message in self.messages: + if message["client_ip"] in buckets.keys(): + buckets[message["client_ip"]].append(message) + else: + buckets[message["client_ip"]] = [] + buckets.get(message["client_ip"]).append(message) + for key, value in buckets.items(): + logger.info(f"Sending anomalies to detector for {key}.") + logger.info(f"Sending anomalies to detector for {value}.") + 
data_to_send = { + "begin_timestamp": self.begin_timestamp.strftime(TIMESTAMP_FORMAT), + "end_timestamp": self.end_timestamp.strftime(TIMESTAMP_FORMAT), + "data": value, + } + self.kafka_produce_handler.send( + topic="Detector", + data=json.dumps(data_to_send), + key=key, + ) def main(one_iteration: bool = False): diff --git a/src/mock/generator.py b/src/mock/generator.py index 6e93f77..f14e361 100644 --- a/src/mock/generator.py +++ b/src/mock/generator.py @@ -2,21 +2,43 @@ import socket import sys import time +import polars as pl +import numpy as np + sys.path.append(os.getcwd()) -from src.base.log_config import setup_logging from src.mock.log_generator import generate_dns_log_line from src.base.log_config import get_logger +from src.train.dataset import Dataset, DatasetLoader logger = get_logger() if __name__ == "__main__": + data_base_path: str = "./data" + datasets = DatasetLoader(base_path=data_base_path, max_rows=10000) + dataset = Dataset( + data_path="", + data=pl.concat( + [ + datasets.dgta_dataset.data, + datasets.cic_dataset.data, + datasets.bambenek_dataset.data, + datasets.dga_dataset.data, + datasets.dgarchive_dataset.data, + ] + ), + max_rows=100, + ) + data = dataset.data + print(data) + np.random.seed(None) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket: client_socket.connect((str("127.0.0.1"), 9998)) while True: for i in range(0, 10): - logline = generate_dns_log_line() + random_domain = data.sample(n=1) + logline = generate_dns_log_line(random_domain["query"].item()) client_socket.send(logline.encode("utf-8")) logger.info(f"Sent logline: {logline}") time.sleep(0.1) diff --git a/src/mock/log_generator.py b/src/mock/log_generator.py index d9e56ed..4543053 100644 --- a/src/mock/log_generator.py +++ b/src/mock/log_generator.py @@ -19,7 +19,7 @@ def random_ipv6(): RECORD_TYPES = ["AAAA", "A", "PR", "CNAME"] -def generate_dns_log_line(): +def generate_dns_log_line(domain: str): timestamp = ( datetime.datetime.now() + datetime.timedelta(0, 0, random.randint(0, 900)) ).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" @@ -30,7 +30,7 @@ def generate_dns_log_line(): response = IP[random.randint(0, 1)]() size = f"{random.randint(50, 255)}b" - return f"{timestamp} {status} {client_ip} {server_ip} random-ip.de {record_type} {response} {size}" + return f"{timestamp} {status} {client_ip} {server_ip} {domain} {record_type} {response} {size}" if __name__ == "__main__": diff --git a/src/train/dataset.py b/src/train/dataset.py index 5b3d05e..40f8a0d 100644 --- a/src/train/dataset.py +++ b/src/train/dataset.py @@ -1,15 +1,28 @@ -import logging +import sys +import os from dataclasses import dataclass -from typing import Any, Callable, List +from typing import Callable, List -import numpy as np import polars as pl import sklearn.model_selection -from fe_polars.encoding.one_hot_encoding import OneHotEncoder from torch.utils.data.dataset import Dataset +sys.path.append(os.getcwd()) +from src.base.log_config import get_logger + +logger = get_logger("train.feature") + def preprocess(x: pl.DataFrame): + """Preprocesses a `pl.DataFrame` into a basic data set for later transformation. 
+ + Args: + x (pl.DataFrame): Data sets for preprocessing + + Returns: + pl.DataFrame: Preprocessed data set + """ + logger.debug("Start preprocessing data.") x = x.with_columns( [ (pl.col("query").str.split(".").alias("labels")), @@ -22,6 +35,7 @@ def preprocess(x: pl.DataFrame): ] ) + logger.debug("Start preprocessing FQDN.") x = x.with_columns( [ # FQDN @@ -29,6 +43,9 @@ def preprocess(x: pl.DataFrame): ] ) + x = x.filter(pl.col("labels").list.len().ne(1)) + + logger.debug("Start preprocessing Second-level domain.") x = x.with_columns( [ # Second-level domain @@ -41,6 +58,7 @@ def preprocess(x: pl.DataFrame): ] ) + logger.debug("Start preprocessing Third-level domain.") x = x.with_columns( [ # Third-level domain @@ -56,6 +74,7 @@ def preprocess(x: pl.DataFrame): ), ] ) + logger.debug("Start preprocessing class.") x = x.with_columns( [ ( @@ -66,103 +85,314 @@ def preprocess(x: pl.DataFrame): ) ] ) - + logger.debug("End preprocessing data.") return x -def cast_dga(data_path: str): +def cast_dga(data_path: str, max_rows: int) -> pl.DataFrame: + """Cast dga data set. + + Args: + data_path (str): Data path to data set + max_rows (int): Maximum rows. + + Returns: + pl.DataFrame: Loaded pl.DataFrame. + """ + logger.info(f"Start casting data set {data_path}.") df = pl.read_csv(data_path) df = df.rename({"Domain": "query"}) df = df.drop(["DGA_family", "Type"]) df = df.with_columns([pl.lit("malicious").alias("class")]) df = preprocess(df) - return df + + df_legit = df.filter(pl.col("class").eq(0))[:max_rows] + df_malicious = df.filter(pl.col("class").eq(1))[:max_rows] + + logger.info(f"Data loaded with shape {df.shape}") + return pl.concat([df_legit, df_malicious]) -def cast_bambenek(data_path: str): +def cast_bambenek(data_path: str, max_rows: int) -> pl.DataFrame: + """Cast Bambenek data set. + + Args: + data_path (str): Data path to data set + max_rows (int): Maximum rows. + + Returns: + pl.DataFrame: Loaded pl.DataFrame. + """ + logger.info(f"Start casting data set {data_path}.") df = pl.read_csv(data_path) df = df.rename({"Domain": "query"}) df = df.drop(["DGA_family", "Type"]) df = df.with_columns([pl.lit("malicious").alias("class")]) df = preprocess(df) - return df + df_legit = df.filter(pl.col("class").eq(0))[:max_rows] + df_malicious = df.filter(pl.col("class").eq(1))[:max_rows] + + logger.info(f"Data loaded with shape {df.shape}") + return pl.concat([df_legit, df_malicious]) + + +def cast_cic(data_path: List[str], max_rows: int) -> pl.DataFrame: + """Cast CIC data set. -def cast_cic(data_path: List[str]): + Args: + data_path (str): Data path to data set + max_rows (int): Maximum rows. + + Returns: + pl.DataFrame: Loaded pl.DataFrame. + """ dataframes = [] for data in data_path: + logger.info(f"Start casting data set {data}.") y = data.split("_")[-1].split(".")[0] - df = pl.read_csv(data, has_header=False) + df = pl.read_csv( + data, has_header=False, n_rows=max_rows if max_rows > 0 else None + ) if y == "benign": df = df.with_columns([pl.lit("legit").alias("class")]) else: df = df.with_columns([pl.lit(y).alias("class")]) df = df.rename({"column_1": "query"}) df = preprocess(df) + + logger.info(f"Data loaded with shape {df.shape}") + dataframes.append(df) + + return pl.concat(dataframes) + + +def cast_dgarchive(data_path: List[str], max_rows: int) -> pl.DataFrame: + """Cast DGArchive data set. + + Args: + data_path (str): Data path to data set + max_rows (int): Maximum rows. + + Returns: + pl.DataFrame: Loaded pl.DataFrame. 
+ """ + dataframes = [] + for data in data_path: + logger.info(f"Start casting data set {data}.") + df = pl.read_csv( + data, + has_header=False, + separator=",", + n_rows=max_rows if max_rows > 0 else None, + ) + df = df.rename({"column_1": "query"}) + df = df.select("query") + df = df.with_columns([pl.lit("1").alias("class")]) + df = preprocess(df) + logger.info(f"Data loaded with shape {df.shape}") dataframes.append(df) return pl.concat(dataframes) -def cast_dgta(data_path: str) -> pl.DataFrame: +def cast_dgta(data_path: str, max_rows: int) -> pl.DataFrame: + """Cast DGTA data set. + + Args: + data_path (str): Data path to data set + max_rows (int): Maximum rows. + + Returns: + pl.DataFrame: Loaded pl.DataFrame. + """ + def __custom_decode(data): - retL = [None] * len(data) - for i, datum in enumerate(data): - retL[i] = str(datum.decode("latin-1").encode("utf-8").decode("utf-8")) + """Custom decode function. + + Args: + data (str): Str to decode. + + Returns: + str: Decoded str. + """ + return str(data.decode("latin-1").encode("utf-8").decode("utf-8")) - return pl.Series(retL) + logger.info(f"Start casting data set {data_path}.") df = pl.read_parquet(data_path) df = df.rename({"domain": "query"}) # Drop unnecessary column df = df.drop("__index_level_0__") - df = df.with_columns([pl.col("query").map(__custom_decode)]) + df = df.with_columns( + pl.col("query").map_elements(__custom_decode, return_dtype=pl.Utf8) + ) df = preprocess(df) - return df + df_legit = df.filter(pl.col("class").eq(0))[:max_rows] + df_malicious = df.filter(pl.col("class").eq(1))[:max_rows] + + logger.info(f"Data loaded with shape {df.shape}") + return pl.concat([df_legit, df_malicious]) class DatasetLoader: - def __init__(self) -> None: + """DatasetLoader for Training.""" + + def __init__(self, base_path: str = "", max_rows: int = -1) -> None: + """Initialise data sets. + + Args: + base_path (str, optional): Base path to data set folder. Defaults to "". + max_rows (int, optional): Maximum rows to consider. Defaults to -1. 
+ """ + logger.info("Initialise DatasetLoader") + self.base_path = base_path + self.max_rows = max_rows + logger.info("Finished initialisation.") + + @property + def dgta_dataset(self) -> Dataset: self.dgta_data = Dataset( - data_path="/home/smachmeier/projects/heiDGA/data/dgta/dgta-benchmark.parquet", + data_path=f"{self.base_path}/dgta/dgta-benchmark.parquet", cast_dataset=cast_dgta, + max_rows=self.max_rows, ) + return self.dgta_data + @property + def dga_dataset(self) -> Dataset: self.dga_data = Dataset( - data_path="/home/smachmeier/projects/heiDGA/data/360_dga_domain.csv", + data_path=f"{self.base_path}/360_dga_domain.csv", cast_dataset=cast_dga, + max_rows=self.max_rows, ) + return self.dga_data + @property + def bambenek_dataset(self) -> Dataset: self.bambenek_data = Dataset( - data_path="/home/smachmeier/projects/heiDGA/data/bambenek_dga_domain.csv", + data_path=f"{self.base_path}/bambenek_dga_domain.csv", cast_dataset=cast_bambenek, + max_rows=self.max_rows, ) + return self.bambenek_data + @property + def cic_dataset(self) -> Dataset: self.cic_data = Dataset( data_path=[ - "/home/smachmeier/projects/heiDGA/example/CICBellDNS2021_CSV_benign.csv", - "/home/smachmeier/projects/heiDGA/example/CICBellDNS2021_CSV_malware.csv", - "/home/smachmeier/projects/heiDGA/example/CICBellDNS2021_CSV_phishing.csv", - "/home/smachmeier/projects/heiDGA/example/CICBellDNS2021_CSV_spam.csv", + f"{self.base_path}/cic/CICBellDNS2021_CSV_benign.csv", + f"{self.base_path}/cic/CICBellDNS2021_CSV_malware.csv", + f"{self.base_path}/cic/CICBellDNS2021_CSV_phishing.csv", + f"{self.base_path}/cic/CICBellDNS2021_CSV_spam.csv", ], cast_dataset=cast_cic, + max_rows=self.max_rows, ) + return self.cic_data @property - def dgta_dataset(self) -> Dataset: - return self.dgta_data - - @property - def dga_dataset(self) -> Dataset: - return self.dga_data - - @property - def bambenek_dataset(self) -> Dataset: - return self.bambenek_data - - @property - def cic_dataset(self) -> Dataset: - return self.cic_data + def dgarchive_dataset(self) -> Dataset: + self.dgarchive_data = Dataset( + data_path=[ + f"{self.base_path}/dgarchive/bamital_dga.csv", + f"{self.base_path}/dgarchive/banjori_dga.csv", + f"{self.base_path}/dgarchive/bedep_dga.csv", + f"{self.base_path}/dgarchive/beebone_dga.csv", + f"{self.base_path}/dgarchive/blackhole_dga.csv", + f"{self.base_path}/dgarchive/bobax_dga.csv", + f"{self.base_path}/dgarchive/ccleaner_dga.csv", + f"{self.base_path}/dgarchive/chinad_dga.csv", + f"{self.base_path}/dgarchive/chir_dga.csv", + f"{self.base_path}/dgarchive/conficker_dga.csv", + f"{self.base_path}/dgarchive/corebot_dga.csv", + f"{self.base_path}/dgarchive/cryptolocker_dga.csv", + f"{self.base_path}/dgarchive/darkshell_dga.csv", + f"{self.base_path}/dgarchive/diamondfox_dga.csv", + f"{self.base_path}/dgarchive/dircrypt_dga.csv", + f"{self.base_path}/dgarchive/dmsniff_dga.csv", + f"{self.base_path}/dgarchive/dnsbenchmark_dga.csv", + f"{self.base_path}/dgarchive/dnschanger_dga.csv", + f"{self.base_path}/dgarchive/downloader_dga.csv", + f"{self.base_path}/dgarchive/dyre_dga.csv", + f"{self.base_path}/dgarchive/ebury_dga.csv", + f"{self.base_path}/dgarchive/ekforward_dga.csv", + f"{self.base_path}/dgarchive/emotet_dga.csv", + f"{self.base_path}/dgarchive/feodo_dga.csv", + f"{self.base_path}/dgarchive/fobber_dga.csv", + f"{self.base_path}/dgarchive/gameover_dga.csv", + f"{self.base_path}/dgarchive/gameover_p2p.csv", + f"{self.base_path}/dgarchive/gozi_dga.csv", + f"{self.base_path}/dgarchive/goznym_dga.csv", + 
f"{self.base_path}/dgarchive/gspy_dga.csv", + f"{self.base_path}/dgarchive/hesperbot_dga.csv", + f"{self.base_path}/dgarchive/infy_dga.csv", + f"{self.base_path}/dgarchive/locky_dga.csv", + f"{self.base_path}/dgarchive/madmax_dga.csv", + f"{self.base_path}/dgarchive/makloader_dga.csv", + f"{self.base_path}/dgarchive/matsnu_dga.csv", + f"{self.base_path}/dgarchive/mirai_dga.csv", + f"{self.base_path}/dgarchive/modpack_dga.csv", + f"{self.base_path}/dgarchive/monerominer_dga.csv", + f"{self.base_path}/dgarchive/murofet_dga.csv", + f"{self.base_path}/dgarchive/murofetweekly_dga.csv", + f"{self.base_path}/dgarchive/mydoom_dga.csv", + f"{self.base_path}/dgarchive/necurs_dga.csv", + f"{self.base_path}/dgarchive/nymaim2_dga.csv", + f"{self.base_path}/dgarchive/nymaim_dga.csv", + f"{self.base_path}/dgarchive/oderoor_dga.csv", + f"{self.base_path}/dgarchive/omexo_dga.csv", + f"{self.base_path}/dgarchive/padcrypt_dga.csv", + f"{self.base_path}/dgarchive/pandabanker_dga.csv", + f"{self.base_path}/dgarchive/pitou_dga.csv", + f"{self.base_path}/dgarchive/proslikefan_dga.csv", + f"{self.base_path}/dgarchive/pushdo_dga.csv", + f"{self.base_path}/dgarchive/pushdotid_dga.csv", + f"{self.base_path}/dgarchive/pykspa2_dga.csv", + f"{self.base_path}/dgarchive/pykspa2s_dga.csv", + f"{self.base_path}/dgarchive/pykspa_dga.csv", + f"{self.base_path}/dgarchive/qadars_dga.csv", + f"{self.base_path}/dgarchive/qakbot_dga.csv", + f"{self.base_path}/dgarchive/qhost_dga.csv", + f"{self.base_path}/dgarchive/qsnatch_dga.csv", + f"{self.base_path}/dgarchive/ramdo_dga.csv", + f"{self.base_path}/dgarchive/ramnit_dga.csv", + f"{self.base_path}/dgarchive/ranbyus_dga.csv", + f"{self.base_path}/dgarchive/randomloader_dga.csv", + f"{self.base_path}/dgarchive/redyms_dga.csv", + f"{self.base_path}/dgarchive/rovnix_dga.csv", + f"{self.base_path}/dgarchive/shifu_dga.csv", + f"{self.base_path}/dgarchive/simda_dga.csv", + f"{self.base_path}/dgarchive/sisron_dga.csv", + f"{self.base_path}/dgarchive/sphinx_dga.csv", + f"{self.base_path}/dgarchive/suppobox_dga.csv", + f"{self.base_path}/dgarchive/sutra_dga.csv", + f"{self.base_path}/dgarchive/symmi_dga.csv", + f"{self.base_path}/dgarchive/szribi_dga.csv", + f"{self.base_path}/dgarchive/tempedreve_dga.csv", + f"{self.base_path}/dgarchive/tempedrevetdd_dga.csv", + f"{self.base_path}/dgarchive/tinba_dga.csv", + f"{self.base_path}/dgarchive/tinynuke_dga.csv", + f"{self.base_path}/dgarchive/tofsee_dga.csv", + f"{self.base_path}/dgarchive/torpig_dga.csv", + f"{self.base_path}/dgarchive/tsifiri_dga.csv", + f"{self.base_path}/dgarchive/ud2_dga.csv", + f"{self.base_path}/dgarchive/ud3_dga.csv", + f"{self.base_path}/dgarchive/ud4_dga.csv", + f"{self.base_path}/dgarchive/urlzone_dga.csv", + f"{self.base_path}/dgarchive/vawtrak_dga.csv", + f"{self.base_path}/dgarchive/vidro_dga.csv", + f"{self.base_path}/dgarchive/vidrotid_dga.csv", + f"{self.base_path}/dgarchive/virut_dga.csv", + f"{self.base_path}/dgarchive/volatilecedar_dga.csv", + f"{self.base_path}/dgarchive/wd_dga.csv", + f"{self.base_path}/dgarchive/xshellghost_dga.csv", + f"{self.base_path}/dgarchive/xxhex_dga.csv", + ], + cast_dataset=cast_dgarchive, + max_rows=self.max_rows, + ) + return self.dgarchive_data @dataclass @@ -170,7 +400,11 @@ class Dataset: """Dataset class.""" def __init__( - self, data_path: Any, data: pl.DataFrame = None, cast_dataset: Callable = None + self, + data_path: List[str], + data: pl.DataFrame = None, + cast_dataset: Callable = None, + max_rows: int = -1, ) -> None: """Initializes data. 
@@ -185,15 +419,19 @@ def __init__( Raises: NotImplementedError: _description_ """ - if cast_dataset != None: - self.data = cast_dataset(data_path) + if cast_dataset != None and data_path != "": + logger.info("Cast function provided, load data set.") + self.data = cast_dataset(data_path, max_rows) elif data_path != "": + logger.info("Data path provided, load data set.") self.data = pl.read_csv(data_path) elif not data is None: + logger.info("Data set provided, load data set.") self.data = data else: + logger.error("No data given!") raise NotImplementedError("No data given") - self.label_encoder = OneHotEncoder(features_to_encode=["class"]) + self.X_train, self.X_val, self.X_test, self.Y_train, self.Y_val, self.Y_test = ( self.__train_test_val_split() ) @@ -208,7 +446,14 @@ def __len__(self) -> int: def __train_test_val_split( self, train_frac: float = 0.8, random_state: int = None - ) -> tuple[list, list, list, list, list, list]: + ) -> tuple[ + pl.DataFrame, + pl.DataFrame, + pl.DataFrame, + pl.DataFrame, + pl.DataFrame, + pl.DataFrame, + ]: """Splits data set in train, test, and validation set Args: @@ -219,6 +464,8 @@ def __train_test_val_split( tuple[list, list, list, list, list, list]: X_train, X_val, X_test, Y_train, Y_val, Y_test """ + logger.info("Create train, validation, and test split.") + self.data = self.data.filter(pl.col("query").str.len_chars() > 0) self.data = self.data.unique(subset="query") diff --git a/src/train/feature.py b/src/train/feature.py index eb4856d..4bea63a 100644 --- a/src/train/feature.py +++ b/src/train/feature.py @@ -1,31 +1,38 @@ -import logging +import sys +import os import math from string import ascii_lowercase as alc from typing import List import polars as pl +sys.path.append(os.getcwd()) +from src.base.log_config import get_logger -class Preprocessor: +logger = get_logger("train.feature") + + +class Processor: + """Processor for data set. Extracts features from data space.""" def __init__(self, features_to_drop: List): """Init. Args: - feature_to_drop (list): list of feature to drop + feature_to_drop (list): List of feature to drop """ self.features_to_drop = features_to_drop def transform(self, x: pl.DataFrame) -> pl.DataFrame: - """Transform our dataset with new features + """Transform our dataset with new features. Args: - x (pl.DataFrame): dataframe with our features + x (pl.DataFrame): pl.DataFrame with our features. Returns: - pl.DataFrame: preprocessed dataframe + pl.DataFrame: Preprocessed dataframe. 
""" - logging.debug("Start data transformation") + logger.debug("Start data transformation") x = x.with_columns( [ (pl.col("query").str.split(".").list.len().alias("label_length")), @@ -44,7 +51,7 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: ), ] ) - # Get letter frequency + logger.debug("Get letter frequency") for i in alc: x = x.with_columns( [ @@ -56,96 +63,50 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: ).alias(f"freq_{i}"), ] ) + logger.debug("Get full, alpha, special, and numeric count.") + for level in ["thirdleveldomain", "secondleveldomain", "fqdn"]: + x = x.with_columns( + [ + ( + pl.when(pl.col(level).str.len_chars().eq(0)) + .then(pl.lit(0)) + .otherwise( + pl.col(level) + .str.len_chars() + .truediv(pl.col(level).str.len_chars()) + ) + ).alias(f"{level}_full_count"), + ( + pl.when(pl.col(level).str.len_chars().eq(0)) + .then(pl.lit(0)) + .otherwise( + pl.col(level) + .str.count_matches(r"[a-zA-Z]") + .truediv(pl.col(level).str.len_chars()) + ) + ).alias(f"{level}_alpha_count"), + ( + pl.when(pl.col(level).str.len_chars().eq(0)) + .then(pl.lit(0)) + .otherwise( + pl.col(level) + .str.count_matches(r"[0-9]") + .truediv(pl.col(level).str.len_chars()) + ) + ).alias(f"{level}_numeric_count"), + ( + pl.when(pl.col(level).str.len_chars().eq(0)) + .then(pl.lit(0)) + .otherwise( + pl.col(level) + .str.count_matches(r"[^\w\s]") + .truediv(pl.col(level).str.len_chars()) + ) + ).alias(f"{level}_special_count"), + ] + ) - x = x.with_columns( - [ - # FQDN - (pl.col("query").str.len_chars().alias("fqdn_full_count")), - ( - pl.col("query") - .str.count_matches(r"[a-zA-Z]") - .truediv(pl.col("query").str.len_chars()) - ).alias("fqdn_alpha_count"), - ( - pl.col("query") - .str.count_matches(r"[0-9]") - .truediv(pl.col("query").str.len_chars()) - ).alias("fqdn_numeric_count"), - ( - pl.col("query") - .str.count_matches(r"[^\w\s]") - .truediv(pl.col("query").str.len_chars()) - ).alias("fqdn_special_count"), - ] - ) - - x = x.with_columns( - [ - ( - pl.col("secondleveldomain") - .str.len_chars() - .truediv(pl.col("secondleveldomain").str.len_chars()) - .alias("secondleveldomain_full_count") - ), - ( - pl.col("secondleveldomain") - .str.count_matches(r"[a-zA-Z]") - .truediv(pl.col("secondleveldomain").str.len_chars()) - ).alias("secondleveldomain_alpha_count"), - ( - pl.col("secondleveldomain") - .str.count_matches(r"[0-9]") - .truediv(pl.col("secondleveldomain").str.len_chars()) - ).alias("secondleveldomain_numeric_count"), - ( - pl.col("secondleveldomain") - .str.count_matches(r"[^\w\s]") - .truediv(pl.col("secondleveldomain").str.len_chars()) - ).alias("secondleveldomain_special_count"), - ] - ) - - x = x.with_columns( - [ - ( - pl.when(pl.col("thirdleveldomain").str.len_chars().eq(0)) - .then(pl.lit(0)) - .otherwise( - pl.col("thirdleveldomain") - .str.len_chars() - .truediv(pl.col("thirdleveldomain").str.len_chars()) - ) - ).alias("thirdleveldomain_full_count"), - ( - pl.when(pl.col("thirdleveldomain").str.len_chars().eq(0)) - .then(pl.lit(0)) - .otherwise( - pl.col("thirdleveldomain") - .str.count_matches(r"[a-zA-Z]") - .truediv(pl.col("thirdleveldomain").str.len_chars()) - ) - ).alias("thirdleveldomain_alpha_count"), - ( - pl.when(pl.col("thirdleveldomain").str.len_chars().eq(0)) - .then(pl.lit(0)) - .otherwise( - pl.col("thirdleveldomain") - .str.count_matches(r"[0-9]") - .truediv(pl.col("thirdleveldomain").str.len_chars()) - ) - ).alias("thirdleveldomain_numeric_count"), - ( - pl.when(pl.col("thirdleveldomain").str.len_chars().eq(0)) - .then(pl.lit(0)) - .otherwise( 
- pl.col("thirdleveldomain") - .str.count_matches(r"[^\w\s]") - .truediv(pl.col("thirdleveldomain").str.len_chars()) - ) - ).alias("thirdleveldomain_special_count"), - ] - ) - + logger.debug("Get frequency standard deviation, median, variance, and mean.") x = x.with_columns( [ ( @@ -171,6 +132,9 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: ] ) + logger.debug( + "Get standard deviation, median, variance, and mean for full, alpha, special, and numeric count." + ) for level in ["thirdleveldomain", "secondleveldomain", "fqdn"]: x = x.with_columns( [ @@ -225,7 +189,7 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: ] ) - logging.debug("Start entropy calculation") + logger.debug("Start entropy calculation") for ent in ["fqdn", "thirdleveldomain", "secondleveldomain"]: x = x.with_columns( [ @@ -234,7 +198,8 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: lambda x: [ float(str(x).count(c)) / len(str(x)) for c in dict.fromkeys(list(str(x))) - ] + ], + return_dtype=pl.List(pl.Float64), ) ).alias("prob"), ] @@ -253,13 +218,14 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: ] ) x = x.drop("prob") - logging.debug("Finished entropy calculation") + logger.debug("Finished entropy calculation") - # Fill NaN + logger.debug("Fill NaN.") x = x.fill_nan(0) - # Drop features not useful anymore + + logger.debug("Drop features that are not useful.") x = x.drop(self.features_to_drop) - logging.debug("Finished data transformation") + logger.debug("Finished data transformation") return x diff --git a/src/train/model.py b/src/train/model.py index a6c2e34..ff44452 100644 --- a/src/train/model.py +++ b/src/train/model.py @@ -1,91 +1,350 @@ +from abc import ABCMeta, abstractmethod +import hashlib +import pickle import sys import os -import logging -from time import time -from typing import Any +import tempfile -import numpy as np -import polars as pl +from sklearn.metrics import make_scorer +import xgboost as xgb +import optuna import torch -from fe_polars.encoding.target_encoding import TargetEncoder -from fe_polars.imputing.base_imputing import Imputer +import numpy as np from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import RandomizedSearchCV -from xgboost import XGBClassifier, XGBRFClassifier +from sklearn.model_selection import cross_val_score + sys.path.append(os.getcwd()) -from src.train.feature import Preprocessor +from src.train.feature import Processor +from src.base.log_config import get_logger +from src.train.dataset import Dataset + +logger = get_logger("train.model") + +SEED = 108 +N_FOLDS = 5 +CV_RESULT_DIR = "./results" class Pipeline: """Pipeline for training models.""" def __init__( - self, - preprocessor: Preprocessor, - mean_imputer: Imputer, - target_encoder: TargetEncoder, - clf: Any, + self, processor: Processor, model: str, dataset: Dataset, model_output_path: str ): """Initializes preprocessors, encoder, and model. Args: - preprocessor (Preprocessor): Preprocessor to transform input data into features. + processor (processor): Processor to transform input data into features. mean_imputer (Imputer): Mean imputer to handle null values. target_encoder (TargetEncoder): Target encoder for non-numeric values. clf (torch.nn.Modul): torch.nn.Modul for training. 
""" - self.preprocessor = preprocessor - self.mean_imputer = mean_imputer - self.target_encoder = target_encoder - self.clf = clf + self.processor = processor + self.dataset = dataset + self.model_output_path = model_output_path + logger.info("Start data set transformation.") + x_train = self.processor.transform(x=self.dataset.X_train) + logger.info(f"End data set transformation with shape {x_train.shape}.") + + self.x_train = x_train.to_numpy() + self.y_train = self.dataset.Y_train.to_numpy().ravel() + match model: + case "rf": + self.model = RandomForestModel( + processor=self.processor, x_train=self.x_train, y_train=self.y_train + ) + case "xg": + self.model = XGBoostModel( + processor=self.processor, x_train=self.x_train, y_train=self.y_train + ) + case _: + raise NotImplementedError(f"Model not implemented!") + + def fit(self): + """Fits models to training data. + + Args: + x_train (np.array): X data. + y_train (np.array): Y labels. + """ + if not os.path.exists(CV_RESULT_DIR): + os.mkdir(CV_RESULT_DIR) + + study = optuna.create_study(direction="maximize") + study.optimize(self.model.objective, n_trials=20, timeout=600) + + logger.info(f"Number of finished trials: {len(study.trials)}") + logger.info("Best trial:") + trial = study.best_trial + + logger.info(f" Value: {trial.value}") + logger.info(f" Params: ") + for key, value in trial.params.items(): + logger.info(f" {key}: {value}") + + self.model.train(trial, self.model_output_path) + + def predict(self, x): + """Predicts given X. + + Args: + x (np.array): X data + Returns: + np.array: Model output. + """ + return self.model.predict(x) + + +class Model(metaclass=ABCMeta): + def __init__( + self, processor: Processor, x_train: np.ndarray, y_train: np.ndarray + ) -> None: + self.processor = processor + self.x_train = x_train + self.y_train = y_train # setting device on GPU if available, else CPU self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logging.info(f"Using device: {self.device}") + logger.info(f"Using device: {self.device}") if torch.cuda.is_available(): - logging.debug("GPU detected") - logging.debug(f"\t{torch.cuda.get_device_name(0)}") + logger.info("GPU detected") + logger.info(f"\t{torch.cuda.get_device_name(0)}") if self.device.type == "cuda": - logging.debug("Memory Usage:") - logging.debug( + logger.info("Memory Usage:") + logger.info( f"\tAllocated: {round(torch.cuda.memory_allocated(0)/1024**3,1)} GB" ) - logging.debug( + logger.info( f"\tCached: {round(torch.cuda.memory_reserved(0)/1024**3,1)} GB" ) + self.device = "gpu" + else: + self.device = "cpu" - def fit(self, x_train: pl.DataFrame, y_train: pl.DataFrame): - """Fits models to training data. + def sha256sum(self, file_path: str) -> str: + """Return a SHA265 sum check to validate the model. Args: - x_train (np.array): X data. - y_train (np.array): Y labels. + file_path (str): File path of model. + + Returns: + str: SHA256 sum """ - x_train = self.preprocessor.transform(x=x_train) - x_train = self.target_encoder.fit_transform(x=x_train, y=y_train) - x_train = self.mean_imputer.fit_transform(x=x_train) - - clf = RandomizedSearchCV( - self.clf["model"], - self.clf["search"], - random_state=42, - n_iter=100, - cv=3, - verbose=2, - n_jobs=-1, + h = hashlib.sha256() + + with open(file_path, "rb") as file: + while True: + # Reading is buffered, so we can read smaller chunks. 
+ chunk = file.read(h.block_size) + if not chunk: + break + h.update(chunk) + + return h.hexdigest() + + @abstractmethod + def objective(self, trial): + pass + + @abstractmethod + def predict(self, x): + pass + + @abstractmethod + def train(self, trial, output_path): + pass + + +class XGBoostModel(Model): + def __init__( + self, processor: Processor, x_train: np.ndarray, y_train: np.ndarray + ) -> None: + super().__init__(processor, x_train, y_train) + + def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float]: + # Get the true labels + labels = dtrain.get_label() + + # Threshold predictions to get binary outcomes (assuming binary classification with 0.5 threshold) + preds_binary = (preds > 0.5).astype(int) + + # Calculate False Positives (FP) and True Positives (TP) + FP = np.sum((preds_binary == 1) & (labels == 0)) + TP = np.sum((preds_binary == 1) & (labels == 1)) + + # Avoid division by zero + if FP + TP == 0: + fdr = 0.0 + else: + fdr = FP / (FP + TP) + + # Return the result in the format (name, value) + return ( + "fdr", + 1 - fdr, + ) # -1 is essentiell since XGBoost wants a scoring value (higher is better). However, FDR represents a loss function. + + def objective(self, trial): + dtrain = xgb.DMatrix(self.x_train, label=self.y_train) + + param = { + "verbosity": 0, + "objective": "binary:logistic", + "eval_metric": "auc", + "device": self.device, + "booster": trial.suggest_categorical( + "booster", ["gbtree", "gblinear", "dart"] + ), + "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True), + "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True), + # sampling ratio for training data. + "subsample": trial.suggest_float("subsample", 0.2, 1.0), + # sampling according to each tree. + "colsample_bytree": trial.suggest_float("colsample_bytree", 0.2, 1.0), + } + + if param["booster"] == "gbtree" or param["booster"] == "dart": + param["max_depth"] = trial.suggest_int("max_depth", 1, 9) + # minimum child weight, larger the term more conservative the tree. + param["min_child_weight"] = trial.suggest_int("min_child_weight", 2, 10) + param["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True) + param["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True) + param["grow_policy"] = trial.suggest_categorical( + "grow_policy", ["depthwise", "lossguide"] + ) + + if param["booster"] == "dart": + param["sample_type"] = trial.suggest_categorical( + "sample_type", ["uniform", "weighted"] + ) + param["normalize_type"] = trial.suggest_categorical( + "normalize_type", ["tree", "forest"] + ) + param["rate_drop"] = trial.suggest_float("rate_drop", 1e-8, 1.0, log=True) + param["skip_drop"] = trial.suggest_float("skip_drop", 1e-8, 1.0, log=True) + + xgb_cv_results = xgb.cv( + params=param, + dtrain=dtrain, + num_boost_round=10000, + nfold=N_FOLDS, + stratified=True, + early_stopping_rounds=100, + seed=SEED, + verbose_eval=False, + custom_metric=self.fdr_metric, + ) + + # Set n_estimators as a trial attribute; Accessible via study.trials_dataframe(). + trial.set_user_attr("n_estimators", len(xgb_cv_results)) + + # Save cross-validation results. + filepath = os.path.join(CV_RESULT_DIR, "{}.csv".format(trial.number)) + xgb_cv_results.to_csv(filepath, index=False) + + # Extract the best score. + best_fdr = xgb_cv_results["test-auc-mean"].values[-1] + return best_fdr + + def predict(self, x): + """Predicts given X. + + Args: + x (np.array): X data + + Returns: + np.array: Model output. 
+ """ + x = self.processor.transform(x=x) + # dtest = xgb.DMatrix(x.to_numpy()) + return self.clf.predict(x) + + def train(self, trial, output_path): + logger.info("Number of estimators: {}".format(trial.user_attrs["n_estimators"])) + + # dtrain = xgb.DMatrix(self.x_train, label=self.y_train) + + params = { + "verbosity": 0, + "objective": "binary:logistic", + "eval_metric": "auc", + "device": self.device, + } + + self.clf = xgb.XGBClassifier( + n_estimators=trial.user_attrs["n_estimators"], **trial.params, **params ) + self.clf.fit(self.x_train, self.y_train) + + logger.info("Save trained model to a file.") + with open( + os.path.join(tempfile.gettempdir(), f"xg_{trial.number}.pickle"), "wb" + ) as fout: + pickle.dump(self.clf, fout) - start = time() - model = clf.fit(x_train.to_numpy(), y_train.to_numpy().ravel()) - logging.info( - f"GridSearchCV took {time() - start:.2f} seconds for {len(clf.cv_results_['params']):d} candidate parameter settings." + sha256sum = self.sha256sum( + os.path.join(tempfile.gettempdir(), f"xg_{trial.number}.pickle") ) - logging.info(model.best_params_) + with open(os.path.join(output_path, f"xg_{sha256sum}.pickle"), "wb") as fout: + pickle.dump(self.clf, fout) + + +class RandomForestModel(Model): + def __init__( + self, processor: Processor, x_train: np.ndarray, y_train: np.ndarray + ) -> None: + super().__init__(processor, x_train, y_train) + + # Define the custom FDR metric + def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray): + # False Positives (FP): cases where the model predicted 1 but the actual label is 0 + FP = np.sum((y_pred == 1) & (y_true == 0)) + + # True Positives (TP): cases where the model correctly predicted 1 + TP = np.sum((y_pred == 1) & (y_true == 1)) + + # Compute FDR, avoiding division by zero + if FP + TP == 0: + fdr = 0.0 + else: + fdr = FP / (FP + TP) - self.clf = model + return fdr + + def objective(self, trial): + # Define hyperparameters to optimize + n_estimators = trial.suggest_int("n_estimators", 50, 300) + max_depth = trial.suggest_int("max_depth", 2, 20) + min_samples_split = trial.suggest_int("min_samples_split", 2, 20) + min_samples_leaf = trial.suggest_int("min_samples_leaf", 1, 20) + max_features = trial.suggest_categorical("max_features", ["sqrt", "log2", None]) + + # Create model with suggested hyperparameters + classifier_obj = RandomForestClassifier( + n_estimators=n_estimators, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + max_features=max_features, + random_state=SEED, + ) + + # Create a scorer using make_scorer, setting greater_is_better to False since lower FDR is better + fdr_scorer = make_scorer(self.fdr_metric, greater_is_better=False) + + score = cross_val_score( + classifier_obj, + self.x_train, + self.y_train, + n_jobs=-1, + cv=N_FOLDS, + scoring=fdr_scorer, + ) + fdr = score.mean() + return fdr def predict(self, x): """Predicts given X. @@ -96,62 +355,21 @@ def predict(self, x): Returns: np.array: Model output. 
""" - x = self.preprocessor.transform(x=x) - x = self.target_encoder.transform(x=x) - x = self.mean_imputer.transform(x=x) - return self.clf.predict(X=x.to_numpy()) - - -xgboost_params = { - "max_leaves": 2**8, - "alpha": 0.9, - "scale_pos_weight": 0.5, - "objective": "binary:logistic", - "tree_method": "hist", - "device": "cuda", -} - -xgboost_rf_params = { - "colsample_bynode": 0.8, - "learning_rate": 1, - "max_depth": 5, - "num_parallel_tree": 100, - "objective": "binary:logistic", - "subsample": 0.8, - "tree_method": "hist", - "booster": "gbtree", - "device": "cuda", -} - -xgboost_model = { - "model": XGBClassifier(**xgboost_params), - "search": { - "eta": list(np.linspace(0.1, 0.6, 6)), - "gamma": [int(x) for x in np.linspace(0, 10, 10)], - "learning_rate": [0.03, 0.01, 0.003, 0.001], - "min_child_weight": [1, 3, 5, 7, 10], - "subsample": [0.6, 0.8, 1.0, 1.2, 1.4], - "colsample_bytree": [0.6, 0.8, 1.0, 1.2, 1.4], - "max_depth": [3, 4, 5, 6, 7, 8, 9, 10, 12, 14], - "reg_lambda": np.array([0.4, 0.6, 0.8, 1, 1.2, 1.4]), - }, -} -xgboost_rf_model = { - "model": XGBRFClassifier(**xgboost_rf_params), - "search": { - "max_depth": [3, 6, 9], - "eta": list(np.linspace(0.1, 0.6, 6)), - "gamma": [int(x) for x in np.linspace(0, 10, 10)], - }, -} -random_forest_model = { - "model": RandomForestClassifier(), - "search": { - "n_estimators": [int(x) for x in np.linspace(start=200, stop=1000, num=10)], - "max_features": [42], - "max_depth": [int(x) for x in np.linspace(10, 110, num=11)], # .append(None), - "min_samples_split": [2, 5, 10], - "min_samples_leaf": [1, 2, 4], - "bootstrap": [True, False], - }, -} + x = self.processor.transform(x=x) + return self.clf.predict(x) + + def train(self, trial, output_path): + self.clf = RandomForestClassifier(**trial.params) + self.clf.fit(self.x_train, self.y_train) + + logger.info("Save trained model to a file.") + with open( + os.path.join(tempfile.gettempdir(), f"rf_{trial.number}.pickle"), "wb" + ) as fout: + pickle.dump(self.clf, fout) + + sha256sum = self.sha256sum( + os.path.join(tempfile.gettempdir(), f"rf_{trial.number}.pickle") + ) + with open(os.path.join(output_path, f"rf_{sha256sum}.pickle"), "wb") as fout: + pickle.dump(self.clf, fout) diff --git a/src/train/train.py b/src/train/train.py index 30dabf1..193b0bc 100644 --- a/src/train/train.py +++ b/src/train/train.py @@ -1,52 +1,58 @@ +import argparse import sys import os -import logging from enum import Enum, unique -import joblib +import click import numpy as np import polars as pl import torch -from fe_polars.encoding.target_encoding import TargetEncoder -from fe_polars.imputing.base_imputing import Imputer from sklearn.metrics import classification_report sys.path.append(os.getcwd()) -from src.train.dataset import Dataset, DatasetLoader -from src.train.feature import Preprocessor +from src.train.dataset import Dataset, DatasetLoader, Dataset +from src.train.feature import Processor from src.train.model import ( Pipeline, - random_forest_model, - xgboost_model, - xgboost_rf_model, ) +from src.base.log_config import get_logger + +logger = get_logger("train.train") @unique -class Dataset(str, Enum): +class DatasetEnum(str, Enum): ALL = "all" CIC = "cic" DGTA = "dgta" + DGARCHIVE = "dgarchive" @unique -class Model(str, Enum): +class ModelEnum(str, Enum): RANDOM_FOREST_CLASSIFIER = "rf" XG_BOOST_CLASSIFIER = "xg" - XG_BOOST_RANDOM_FOREST_CLASSIFIER = "xg-rf" -class DNSAnalyzerTraining: +class DetectorTraining: def __init__( - self, model: Model.RANDOM_FOREST_CLASSIFIER, dataset: Dataset = Dataset.ALL 
+ self, + model: ModelEnum.RANDOM_FOREST_CLASSIFIER, + model_output_path: str = "./", + dataset: DatasetEnum = DatasetEnum.ALL, + data_base_path: str = "./data", + max_rows: int = -1, ) -> None: """Trainer class to fit models on data sets. Args: model (torch.nn.Module): Fit model. - dataset (heidgaf.datasets.Dataset): Data set for training. + dataset (src.train.datasets.Dataset): Data set for training. + data_base_path(src.train.train.DatasetEnum): """ - self.datasets = DatasetLoader() + logger.info("Get DatasetLoader.") + self.datasets = DatasetLoader(base_path=data_base_path, max_rows=max_rows) + self.model_output_path = model_output_path match dataset: case "all": self.dataset = Dataset( @@ -57,25 +63,20 @@ def __init__( self.datasets.cic_dataset.data, self.datasets.bambenek_dataset.data, self.datasets.dga_dataset.data, + self.datasets.dgarchive_dataset.data, ] ), + max_rows=max_rows, ) case "cic": - self.dataset = self.datasets.cic_dataset.data + self.dataset = self.datasets.cic_dataset case "dgta": - self.dataset = self.datasets.dgta_dataset.data + self.dataset = self.datasets.dgta_dataset + case "dgarchive": + self.dataset = self.datasets.dgarchive_data case _: raise NotImplementedError(f"Dataset not implemented!") - - match model: - case "rf": - self.model = random_forest_model - case "xg": - self.model = xgboost_model - case "xg-rf": - self.model = xgboost_rf_model - case _: - raise NotImplementedError(f"Model not implemented!") + self.model = model def train(self, seed=42, output_path: str = "model.pkl"): """Starts training of the model. Checks prior if GPU is available. @@ -87,11 +88,10 @@ def train(self, seed=42, output_path: str = "model.pkl"): np.random.seed(seed) torch.manual_seed(seed) - logging.info(f"Loading data sets") - # Training model + logger.info(f"Set up Pipeline.") model_pipeline = Pipeline( - preprocessor=Preprocessor( + processor=Processor( features_to_drop=[ "query", "labels", @@ -101,29 +101,68 @@ def train(self, seed=42, output_path: str = "model.pkl"): "tld", ] ), - mean_imputer=Imputer(features_to_impute=[], strategy="mean"), - target_encoder=TargetEncoder(smoothing=100, features_to_encode=[]), - clf=self.model, - ) - processor = Preprocessor( - features_to_drop=[ - "query", - "labels", - "thirdleveldomain", - "secondleveldomain", - "fqdn", - "tld", - ] + model=self.model, + dataset=self.dataset, + model_output_path=self.model_output_path, ) - data = processor.transform(self.dataset.data) - data.write_csv("full_data.csv") - model_pipeline.fit(x_train=self.dataset.X_train, y_train=self.dataset.Y_train) + logger.info("Fit model.") + model_pipeline.fit() + logger.info("Validate test set") y_pred = model_pipeline.predict(self.dataset.X_test) - logging.info(classification_report(self.dataset.Y_test, y_pred, labels=[0, 1])) + y_pred = [round(value) for value in y_pred] + logger.info(classification_report(self.dataset.Y_test, y_pred, labels=[0, 1])) + logger.info("Test validation test.") y_pred = model_pipeline.predict(self.dataset.X_val) - logging.info(classification_report(self.dataset.Y_val, y_pred, labels=[0, 1])) + y_pred = [round(value) for value in y_pred] + logger.info(classification_report(self.dataset.Y_val, y_pred, labels=[0, 1])) + - joblib.dump(model_pipeline.clf, output_path) +@click.command() +@click.option( + "-m", + "--model", + type=click.Choice(["xg", "rf"]), + help="Model to train, choose between XGBoost and RandomForest classifier", +) +@click.option( + "-ds", + "--dataset", + default="all", + type=click.Choice(["all", "dgarchive", "cic", 
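
Because predict() can return continuous scores, the training routine rounds them to hard 0/1 labels before handing them to classification_report. The same step in isolation, with made-up values:

from sklearn.metrics import classification_report

y_true = [0, 1, 1, 0]
y_scores = [0.2, 0.7, 0.9, 0.4]        # toy model outputs
y_pred = [round(v) for v in y_scores]  # -> [0, 1, 1, 0]
print(classification_report(y_true, y_pred, labels=[0, 1]))
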
"dgta"]), + help="Data set to train model, choose between all available datasets, DGArchive, CIC and DGTA.", +) +@click.option( + "-ds_path", + "--dataset_path", + type=click.Path(exists=True), + help="Dataset path, follow folder structure.", +) +@click.option( + "-ds_max_rows", + "--dataset_max_rows", + default=-1, + type=int, + help="Maximum rows to load from each dataset.", +) +@click.option( + "-m_output_path", + "--model_output_path", + type=click.Path(exists=True), + help="Model output path. Stores model with {{MODEL}}_{{SHA256}}.pickle.", +) +def main(model, dataset, dataset_path, dataset_max_rows, model_output_path): + trainer = DetectorTraining( + model=model, + dataset=dataset, + data_base_path=dataset_path, + max_rows=dataset_max_rows, + model_output_path=model_output_path, + ) + trainer.train() + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/tests/test_detector.py b/tests/test_detector.py index e56dda3..b3fb72b 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -1,6 +1,8 @@ +import os +import tempfile import unittest from datetime import datetime, timedelta -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, mock_open from requests import HTTPError @@ -30,6 +32,30 @@ def test_sha256_not_existing_file(self, mock_kafka_consume_handler): sut._sha256sum("not_existing") +class TestFeatures(unittest.TestCase): + def setUp(self): + patcher = patch("src.detector.detector.logger") + self.mock_logger = patcher.start() + self.addCleanup(patcher.stop) + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_get_model(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut._get_features("google.de") + + class TestGetModel(unittest.TestCase): def setUp(self): patcher = patch("src.detector.detector.logger") @@ -38,9 +64,9 @@ def setUp(self): @patch( "src.detector.detector.CHECKSUM", - "21d1f40c9e186a08e9d2b400cea607f4163b39d187a9f9eca3da502b21cf3b9b", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", ) - @patch("src.detector.detector.MODEL", "xg") + @patch("src.detector.detector.MODEL", "rf") @patch( "src.detector.detector.MODEL_BASE_URL", "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", @@ -56,7 +82,7 @@ def test_get_model(self, mock_kafka_consume_handler): "src.detector.detector.CHECKSUM", "WRONG", ) - @patch("src.detector.detector.MODEL", "xg") + @patch("src.detector.detector.MODEL", "rf") @patch( "src.detector.detector.MODEL_BASE_URL", "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", @@ -71,7 +97,7 @@ def test_get_model_not_existing(self, mock_kafka_consume_handler): @patch( "src.detector.detector.CHECKSUM", - "21d1f40c9e186a08e9d2b400cea607f4163b39d187a9f9eca3da502b21cf3b9b", + "04970cd6fe0be5369248d24541c7b8faf69718706019f80280a0a687884f35fb", ) @patch("src.detector.detector.MODEL", "WRONG") @patch( @@ -193,6 +219,98 @@ def test_get_data_while_busy(self, mock_kafka_consume_handler, mock_logger): self.assertEqual([{"test": "test_message_2"}], sut.messages) +class TestSendWarning(unittest.TestCase): + def setUp(self): + patcher = 
patch("src.detector.detector.logger") + self.mock_logger = patcher.start() + self.addCleanup(patcher.stop) + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_warning(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [ + { + "request": "google.de", + "probability": 0.8765, + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + }, + { + "request": "request.de", + "probability": 0.12388, + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + }, + ] + open_mock = mock_open() + with patch("src.detector.detector.open", open_mock, create=True): + sut.send_warning() + + open_mock.assert_called_with( + os.path.join(tempfile.gettempdir(), "warnings.json"), "a+" + ) + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_empty_warning(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [] + open_mock = mock_open() + with patch("src.detector.detector.open", open_mock, create=True): + sut.send_warning() + + open_mock.assert_not_called() + + @patch( + "src.detector.detector.CHECKSUM", + "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + ) + @patch("src.detector.detector.MODEL", "rf") + @patch( + "src.detector.detector.MODEL_BASE_URL", + "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/", + ) + @patch("src.detector.detector.KafkaConsumeHandler") + def test_save_warning_error(self, mock_kafka_consume_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + + sut = Detector() + sut.warnings = [ + { + "request": "request.de", + "probability": "INVALID", + "model": "rf", + "sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06", + } + ] + with self.assertRaises(Exception): + sut.send_warning() + + class TestClearData(unittest.TestCase): def setUp(self): patcher = patch("src.detector.detector.logger") diff --git a/tests/test_inspector.py b/tests/test_inspector.py index f559f53..37d49a7 100644 --- a/tests/test_inspector.py +++ b/tests/test_inspector.py @@ -1,10 +1,37 @@ import unittest from datetime import datetime, timedelta from unittest.mock import MagicMock, patch +import numpy as np +import json +from streamad.model import ZScoreDetector, RShashDetector from src.base import Batch from src.inspector.inspector import Inspector, main +DEFAULT_DATA = { + "client_ip": "192.168.0.167", + "dns_ip": "10.10.0.10", + "response_ip": "252.79.173.222", + "timestamp": "", + "status": "NXDOMAIN", + "host_domain_name": "24sata.info", + "record_type": "A", + "size": "100b", +} + +TIMESTAMP_FORMAT = 
"%Y-%m-%dT%H:%M:%S.%fZ" + + +def get_batch(data): + begin = datetime.now() + end = begin + timedelta(0, 3) + test_batch = Batch( + begin_timestamp=begin, + end_timestamp=end, + data=data if data != None else [], + ) + return test_batch + class TestInit(unittest.TestCase): @patch("src.inspector.inspector.KafkaProduceHandler") @@ -29,16 +56,11 @@ class TestGetData(unittest.TestCase): def test_get_data_without_return_data( self, mock_kafka_consume_handler, mock_produce_handler, mock_logger ): - test_batch = Batch( - begin_timestamp=datetime.now(), - end_timestamp=datetime.now() + timedelta(0, 3), - data=[], - ) mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( None, - test_batch, + get_batch(None), ) mock_produce_handler_instance = MagicMock() mock_produce_handler.return_value = mock_produce_handler_instance @@ -54,13 +76,7 @@ def test_get_data_without_return_data( def test_get_data_with_return_data( self, mock_kafka_consume_handler, mock_produce_handler, mock_logger ): - begin = datetime.now() - end = begin + timedelta(0, 3) - test_batch = Batch( - begin_timestamp=begin, - end_timestamp=end, - data=[{"test": "test_message_1"}, {"test": "test_message_2"}], - ) + test_batch = get_batch([{"test": "test_message_1"}, {"test": "test_message_2"}]) mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( @@ -74,12 +90,37 @@ def test_get_data_with_return_data( sut.messages = [] sut.get_and_fill_data() - self.assertEqual(begin, sut.begin_timestamp) - self.assertEqual(end, sut.end_timestamp) + self.assertEqual(test_batch.begin_timestamp, sut.begin_timestamp) + self.assertEqual(test_batch.end_timestamp, sut.end_timestamp) self.assertEqual( [{"test": "test_message_1"}, {"test": "test_message_2"}], sut.messages ) + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + def test_get_data_with_no_return_data( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + begin = None + end = None + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + None, + None, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.messages = [] + sut.get_and_fill_data() + + self.assertEqual(begin, sut.begin_timestamp) + self.assertEqual(end, sut.end_timestamp) + self.assertEqual([], sut.messages) + @patch("src.inspector.inspector.KafkaProduceHandler") @patch("src.inspector.inspector.KafkaConsumeHandler") def test_get_data_while_busy( @@ -166,22 +207,16 @@ def test_count_errors(self, mock_kafka_consume_handler, mock_produce_handler): mock_produce_handler.return_value = mock_produce_handler_instance sut = Inspector() - begin_timestamp = "2024-07-02T12:52:45.000Z" - end_timestamp = "2024-07-02T12:52:55.000Z" - messages = [ - { - "client_ip": "192.168.0.167", - "dns_ip": "10.10.0.10", - "response_ip": "252.79.173.222", - "timestamp": "2024-07-02T12:52:50.988Z", - "status": "NXDOMAIN", - "host_domain_name": "24sata.info", - "record_type": "A", - "size": "111b", - }, 
- ] - self.assertIsNotNone( - sut._count_errors(messages, begin_timestamp, end_timestamp) + begin_timestamp = datetime.now() + end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + messages = [data] + np.testing.assert_array_equal( + np.asarray([[1.0], [0.0]]), + sut._count_errors(messages, begin_timestamp, end_timestamp), ) @patch("src.inspector.inspector.KafkaProduceHandler") @@ -195,22 +230,16 @@ def test_mean_packet_size(self, mock_kafka_consume_handler, mock_produce_handler mock_produce_handler.return_value = mock_produce_handler_instance sut = Inspector() - begin_timestamp = "2024-07-02T12:52:45.000Z" - end_timestamp = "2024-07-02T12:52:55.000Z" - messages = [ - { - "client_ip": "192.168.0.167", - "dns_ip": "10.10.0.10", - "response_ip": "252.79.173.222", - "timestamp": "2024-07-02T12:52:50.988Z", - "status": "NXDOMAIN", - "host_domain_name": "24sata.info", - "record_type": "A", - "size": "111b", - }, - ] - self.assertIsNotNone( - sut._mean_packet_size(messages, begin_timestamp, end_timestamp) + begin_timestamp = datetime.now() + end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + messages = [data] + np.testing.assert_array_equal( + np.asarray([[100], [0.0]]), + sut._mean_packet_size(messages, begin_timestamp, end_timestamp), ) @patch("src.inspector.inspector.KafkaProduceHandler") @@ -224,11 +253,15 @@ def test_count_errors_empty_messages( mock_produce_handler.return_value = mock_produce_handler_instance sut = Inspector() - begin_timestamp = "2024-07-02T12:52:45.000Z" - end_timestamp = "2024-07-02T12:52:55.000Z" - messages = [] - self.assertIsNotNone( - sut._count_errors(messages, begin_timestamp, end_timestamp) + begin_timestamp = datetime.now() + end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + np.testing.assert_array_equal( + np.asarray([[0.0], [0.0]]), + sut._count_errors([], begin_timestamp, end_timestamp), ) @patch("src.inspector.inspector.KafkaProduceHandler") @@ -242,15 +275,65 @@ def test_mean_packet_size_empty_messages( mock_produce_handler.return_value = mock_produce_handler_instance sut = Inspector() - begin_timestamp = "2024-07-02T12:52:45.000Z" - end_timestamp = "2024-07-02T12:52:55.000Z" - messages = [] - self.assertIsNotNone( - sut._mean_packet_size(messages, begin_timestamp, end_timestamp) + begin_timestamp = datetime.now() + end_timestamp = begin_timestamp + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + np.testing.assert_array_equal( + np.asarray([[0.0], [0.0]]), + sut._mean_packet_size([], begin_timestamp, end_timestamp), ) class TestInspectFunction(unittest.TestCase): + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + None, + ) + def test_inspect_none_models( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + 
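
One detail that makes these feature tests easier to read: timedelta takes (days, seconds, microseconds) as positional arguments, so the begin/end windows built with timedelta(0, 0, 2) are only two microseconds wide, and the stamped message lands one microsecond inside the window:

from datetime import datetime, timedelta

TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
begin = datetime.now()
end = begin + timedelta(0, 0, 2)                                 # +2 microseconds
stamp = (begin + timedelta(0, 0, 1)).strftime(TIMESTAMP_FORMAT)  # +1 microsecond, inside the window
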
"test", + get_batch(None), + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + with self.assertRaises(NotImplementedError): + sut.inspect() + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + "", + ) + def test_inspect_empy_models( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + get_batch(None), + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + with self.assertRaises(NotImplementedError): + sut.inspect() + @patch("src.inspector.inspector.logger") @patch("src.inspector.inspector.KafkaProduceHandler") @patch("src.inspector.inspector.KafkaConsumeHandler") @@ -258,14 +341,96 @@ class TestInspectFunction(unittest.TestCase): "src.inspector.inspector.MODELS", [{"model": "ZScoreDetector", "module": "streamad.model", "model_args": {}}], ) + @patch("src.inspector.inspector.TIME_TYPE", "ms") + @patch("src.inspector.inspector.TIME_RANGE", 1) + def test_inspect_univariate( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertEqual([0, 0], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + { + "model": "ZScoreDetector", + "module": "streamad.model", + "model_args": {"window_len": 10}, + } + ], + ) + @patch("src.inspector.inspector.TIME_TYPE", "ms") + @patch("src.inspector.inspector.TIME_RANGE", 1) def test_inspect_univariate( self, mock_kafka_consume_handler, mock_produce_handler, mock_logger ): - test_batch = Batch( - begin_timestamp=datetime.now(), - end_timestamp=datetime.now() + timedelta(0, 3), - data=[], + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + 
mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertNotEqual([None, None], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + {"model": "ZScoreDetector", "module": "streamad.model", "model_args": {}}, + {"model": "KNNDetector", "module": "streamad.model", "model_args": {}}, + ], + ) + @patch("src.inspector.inspector.TIME_TYPE", "ms") + @patch("src.inspector.inspector.TIME_RANGE", 1) + def test_inspect_univariate_two_models( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT ) + test_batch.data = [data] mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( @@ -278,7 +443,8 @@ def test_inspect_univariate( sut = Inspector() sut.get_and_fill_data() sut.inspect() - self.assertIsNotNone(sut.anomalies) + self.assertEqual([0, 0], sut.anomalies) + self.assertTrue(isinstance(sut.model, ZScoreDetector)) @patch("src.inspector.inspector.logger") @patch("src.inspector.inspector.KafkaProduceHandler") @@ -291,11 +457,14 @@ def test_inspect_univariate( def test_inspect_multivariate( self, mock_kafka_consume_handler, mock_produce_handler, mock_logger ): - test_batch = Batch( - begin_timestamp=datetime.now(), - end_timestamp=datetime.now() + timedelta(0, 3), - data=[], + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT ) + test_batch.data = [data] mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( @@ -308,7 +477,223 @@ def test_inspect_multivariate( sut = Inspector() sut.get_and_fill_data() sut.inspect() - self.assertIsNotNone(sut.anomalies) + self.assertEqual([0, 0], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + { + "model": "RShashDetector", + "module": "streamad.model", + "model_args": {"window_len": 10}, + } + ], + ) + @patch("src.inspector.inspector.MODE", "multivariate") + def test_inspect_multivariate_window_len( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + 
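
The MODELS entries exercised here name a detector class, its module, and constructor kwargs. The following is an assumed sketch of how the Inspector presumably builds a streamad detector from such an entry and scores one time step at a time; it is not the actual Inspector code:

import importlib
import numpy as np

MODELS = [
    {"model": "ZScoreDetector", "module": "streamad.model", "model_args": {"window_len": 10}}
]

cfg = MODELS[0]
detector = getattr(importlib.import_module(cfg["module"]), cfg["model"])(**cfg["model_args"])

X = np.zeros((2, 1))                         # two time steps of one feature, placeholder values
scores = [detector.fit_score(x) for x in X]  # may be None while the window is still warming up
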
mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertNotEqual([None, None], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + {"model": "RShashDetector", "module": "streamad.model", "model_args": {}}, + {"model": "xStreamDetector", "module": "streamad.model", "model_args": {}}, + ], + ) + @patch("src.inspector.inspector.MODE", "multivariate") + def test_inspect_multivariate_two_models( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertEqual([0, 0], sut.anomalies) + self.assertTrue(isinstance(sut.model, RShashDetector)) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + {"model": "KNNDetector", "module": "streamad.model", "model_args": {}}, + {"model": "SpotDetector", "module": "streamad.model", "model_args": {}}, + ], + ) + @patch( + "src.inspector.inspector.ENSEMBLE", + { + "model": "WeightEnsemble", + "module": "streamad.process", + "model_args": {"ensemble_weights": [0.6, 0.4]}, + }, + ) + @patch("src.inspector.inspector.MODE", "ensemble") + def test_inspect_ensemble( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertEqual([0, 0], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + { + "model": "KNNDetector", + "module": "streamad.model", + "model_args": {"window_len": 10}, + }, + { + "model": "SpotDetector", + "module": "streamad.model", + "model_args": 
{"window_len": 10}, + }, + ], + ) + @patch( + "src.inspector.inspector.ENSEMBLE", + { + "model": "WeightEnsemble", + "module": "streamad.process", + "model_args": {"ensemble_weights": [0.6, 0.4]}, + }, + ) + @patch("src.inspector.inspector.MODE", "ensemble") + def test_inspect_ensemble_window_len( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + sut.inspect() + self.assertNotEqual([None, None], sut.anomalies) + + @patch("src.inspector.inspector.logger") + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch( + "src.inspector.inspector.MODELS", + [ + {"model": "RShashDetector", "module": "streamad.model", "model_args": {}}, + {"model": "SpotDetector", "module": "streamad.model", "model_args": {}}, + ], + ) + @patch( + "src.inspector.inspector.ENSEMBLE", + { + "model": "WeightEnsemble", + "module": "streamad.process", + "model_args": {"ensemble_weights": [0.6, 0.4]}, + }, + ) + @patch("src.inspector.inspector.MODE", "ensemble") + def test_inspect_ensemble_invalid( + self, mock_kafka_consume_handler, mock_produce_handler, mock_logger + ): + test_batch = get_batch(None) + test_batch.begin_timestamp = datetime.now() + test_batch.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + test_batch.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + test_batch.data = [data] + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_kafka_consume_handler_instance.consume_and_return_object.return_value = ( + "test", + test_batch, + ) + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.get_and_fill_data() + with self.assertRaises(NotImplementedError): + sut.inspect() @patch("src.inspector.inspector.logger") @patch("src.inspector.inspector.KafkaProduceHandler") @@ -383,6 +768,42 @@ def test_invalid_mode(self, mock_kafka_consume_handler, mock_produce_handler): sut.inspect() +class TestSend(unittest.TestCase): + @patch("src.inspector.inspector.KafkaProduceHandler") + @patch("src.inspector.inspector.KafkaConsumeHandler") + @patch("src.inspector.inspector.SCORE_THRESHOLD", 0.1) + @patch("src.inspector.inspector.ANOMALY_THRESHOLD", 0.01) + def test_send(self, mock_kafka_consume_handler, mock_produce_handler): + mock_kafka_consume_handler_instance = MagicMock() + mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance + mock_produce_handler_instance = MagicMock() + mock_produce_handler.return_value = mock_produce_handler_instance + + sut = Inspector() + sut.anomalies = [0.9, 0.9] + sut.X = np.array([[0.0], [0.0]]) + sut.begin_timestamp = datetime.now() + 
sut.end_timestamp = datetime.now() + timedelta(0, 0, 2) + data = DEFAULT_DATA + data["timestamp"] = datetime.strftime( + sut.begin_timestamp + timedelta(0, 0, 1), TIMESTAMP_FORMAT + ) + sut.messages = [data] + sut.send_data() + + mock_produce_handler_instance.send.assert_called_once_with( + topic="Detector", + data=json.dumps( + { + "begin_timestamp": sut.begin_timestamp.strftime(TIMESTAMP_FORMAT), + "end_timestamp": sut.end_timestamp.strftime(TIMESTAMP_FORMAT), + "data": [data], + } + ), + key="192.168.0.167", + ) + + class TestMainFunction(unittest.TestCase): @patch("src.inspector.inspector.logger") @patch("src.inspector.inspector.Inspector") diff --git a/tests/test_kafka_consume_handler.py b/tests/test_kafka_consume_handler.py index c4d69a6..a94ddcc 100644 --- a/tests/test_kafka_consume_handler.py +++ b/tests/test_kafka_consume_handler.py @@ -157,5 +157,33 @@ def test_del_with_existing_consumer(self, mock_consumer): mock_consumer_instance.close.assert_not_called() +class TestDict(unittest.TestCase): + @patch("src.base.kafka_handler.CONSUMER_GROUP_ID", "test_group_id") + @patch( + "src.base.kafka_handler.KAFKA_BROKERS", + [ + { + "hostname": "127.0.0.1", + "port": 9999, + }, + { + "hostname": "127.0.0.2", + "port": 9998, + }, + { + "hostname": "127.0.0.3", + "port": 9997, + }, + ], + ) + @patch("src.base.kafka_handler.Consumer") + def test_dict(self, mock_consumer): + mock_consumer_instance = MagicMock() + mock_consumer.return_value = mock_consumer_instance + + sut = KafkaConsumeHandler(topic="test_topic") + self.assertTrue(sut._is_dicts([{}, {}])) + + if __name__ == "__main__": unittest.main()
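
TestSend pins down the payload produced for the Detector topic: a JSON object with begin_timestamp, end_timestamp and data, keyed by the client IP. The new TestDict case only asserts that _is_dicts accepts a list of dictionaries; a plausible implementation consistent with that expectation (the real helper may differ) is:

def _is_dicts(items) -> bool:
    return all(isinstance(item, dict) for item in items)

assert _is_dicts([{}, {}]) is True
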