Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ The full list of configuration parameters is available at the [documentation](ht
| `pipeline.data_inspection.inspector.models` | List of models to use for data inspection (e.g., anomaly detection). | Array of model definitions (e.g., `{"model": "ZScoreDetector", "module": "streamad.model", "model_args": {"is_global": false}}`)|
| `pipeline.data_inspection.inspector.anomaly_threshold` | Threshold for classifying an observation as an anomaly. | `0.01` |
| `pipeline.data_analysis.detector.model` | Model to use for data analysis (e.g., DGA detection). | `rf` (Random Forest) option: `XGBoost` |
| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06` |
| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc` |
| `pipeline.data_analysis.detector.base_url` | Base URL for downloading the model if not present locally. | `https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/` |

<p align="right">(<a href="#readme-top">back to top</a>)</p>
Expand Down
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pipeline:
data_analysis:
detector:
model: rf # XGBoost
checksum: ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06
checksum: 021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc
base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/
threshold: 0.5

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.inspector.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
PyYAML~=6.0.1
confluent-kafka~=2.4.0
colorlog~=6.8.2
scipy~=1.12.0
streamad~=0.3.1
numpy~=1.26.4
marshmallow_dataclass~=8.7.1
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.train.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numpy
xgboost
scikit-learn~=1.5.2
scipy
scipy~=1.12.0
torch
pyarrow
polars
Expand Down
49 changes: 28 additions & 21 deletions src/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ def __init__(self) -> None:
self.begin_timestamp = None
self.end_timestamp = None
self.model_path = os.path.join(
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}.pickle"
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_model.pickle"
)
self.scaler_path = os.path.join(
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_scaler.pickle"
)

self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(CONSUME_TOPIC)

self.model = self._get_model()
self.model, self.scaler = self._get_model()

# databases
self.suspicious_batch_timestamps = ClickHouseKafkaSender(
Expand Down Expand Up @@ -159,14 +162,28 @@ def _get_model(self):
logger.info(f"Get model: {MODEL} with checksum {CHECKSUM}")
if not os.path.isfile(self.model_path):
response = requests.get(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1"
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
)
logger.info(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
)
logger.info(f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1")
response.raise_for_status()

with open(self.model_path, "wb") as f:
f.write(response.content)

if not os.path.isfile(self.scaler_path):
response = requests.get(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
)
logger.info(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
)
response.raise_for_status()

with open(self.scaler_path, "wb") as f:
f.write(response.content)

# Check file sha256
local_checksum = self._sha256sum(self.model_path)

Expand All @@ -181,7 +198,10 @@ def _get_model(self):
with open(self.model_path, "rb") as input_file:
clf = pickle.load(input_file)

return clf
with open(self.scaler_path, "rb") as input_file:
scaler = pickle.load(input_file)

return clf, scaler

def clear_data(self):
"""Clears the data in the internal data structures."""
Expand Down Expand Up @@ -237,12 +257,6 @@ def calculate_counts(level: str) -> np.ndarray:
for level, level_value in levels.items()
}

logger.debug("Get frequency standard deviation, median, variance, and mean.")
freq_std = np.std(freq)
freq_var = np.var(freq)
freq_median = np.median(freq)
freq_mean = np.mean(freq)

logger.debug(
"Get standard deviation, median, variance, and mean for full, alpha, special, and numeric count."
)
Expand All @@ -268,16 +282,9 @@ def calculate_entropy(s: str) -> float:

# Final feature aggregation as a NumPy array
basic_features = np.array([label_length, label_max, label_average])
freq_features = np.array([freq_std, freq_var, freq_median, freq_mean])

# Flatten counts and stats for each level into arrays
level_features = np.hstack([counts[level] for level in levels.keys()])
stats_features = np.array(
[stats[f"{level}_std"] for level in levels.keys()]
+ [stats[f"{level}_var"] for level in levels.keys()]
+ [stats[f"{level}_median"] for level in levels.keys()]
+ [stats[f"{level}_mean"] for level in levels.keys()]
)

# Entropy features
entropy_features = np.array([entropy[level] for level in levels.keys()])
Expand All @@ -287,9 +294,9 @@ def calculate_entropy(s: str) -> float:
[
basic_features,
freq,
freq_features,
# freq_features,
level_features,
stats_features,
# stats_features,
entropy_features,
]
)
Expand All @@ -304,7 +311,7 @@ def detect(self) -> None: # pragma: no cover
for message in self.messages:
# TODO predict all messages
y_pred = self.model.predict_proba(
self._get_features(message["domain_name"])
self.scaler.transform(self._get_features(message["domain_name"]))
)
logger.info(f"Prediction: {y_pred}")
if np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > THRESHOLD:
Expand Down
47 changes: 39 additions & 8 deletions tests/detector/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
from src.base.data_classes.batch import Batch
from src.detector.detector import Detector, WrongChecksum

DEFAULT_DATA = {
"client_ip": "192.168.0.167",
"dns_ip": "10.10.0.10",
"response_ip": "252.79.173.222",
"timestamp": "",
"status": "NXDOMAIN",
"domain_name": "IF356gEnJHPdRxnkDId4RDUSgtqxx9I+pZ5n1V53MdghOGQncZWAQgAPRx3kswi.750jnH6iSqmiAAeyDUMX0W6SHGpVsVsKSX8ZkKYDs0GFh/9qU5N9cwl00XSD8ID.NNhBdHZIb7nc0hDQXFPlABDLbRwkJS38LZ8RMX4yUmR2Mb6YqTTJBn+nUcB9P+v.jBQdwdS53XV9W2p1BHjh.16.f.1.6037.tunnel.example.org",
"record_type": "A",
"size": "100b",
}


class TestSha256Sum(unittest.TestCase):
@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
Expand Down Expand Up @@ -45,7 +56,7 @@ def setUp(self):

@patch(
"src.detector.detector.CHECKSUM",
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
Expand All @@ -70,7 +81,7 @@ def setUp(self):

@patch(
"src.detector.detector.CHECKSUM",
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
Expand Down Expand Up @@ -246,7 +257,7 @@ def setUp(self):

@patch(
"src.detector.detector.CHECKSUM",
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
Expand All @@ -265,13 +276,13 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler):
"request": "google.de",
"probability": 0.8765,
"model": "rf",
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
},
{
"request": "request.de",
"probability": 0.12388,
"model": "rf",
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
},
]
sut.messages = [{"logline_id": "test_id"}]
Expand All @@ -285,7 +296,27 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler):

@patch(
"src.detector.detector.CHECKSUM",
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
"src.detector.detector.MODEL_BASE_URL",
"https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/",
)
@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
@patch("src.detector.detector.ClickHouseKafkaSender")
def test_prediction(self, mock_clickhouse, mock_kafka_consume_handler):
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

sut = Detector()
sut.messages = [DEFAULT_DATA]
sut.detect()
self.assertNotEqual([], sut.warnings)

@patch(
"src.detector.detector.CHECKSUM",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
Expand All @@ -309,7 +340,7 @@ def test_save_empty_warning(self, mock_clickhouse, mock_kafka_consume_handler):

@patch(
"src.detector.detector.CHECKSUM",
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
)
@patch("src.detector.detector.MODEL", "rf")
@patch(
Expand All @@ -328,7 +359,7 @@ def test_save_warning_error(self, mock_clickhouse, mock_kafka_consume_handler):
"request": "request.de",
"probability": "INVALID",
"model": "rf",
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
}
]
with self.assertRaises(Exception):
Expand Down
Loading