Skip to content

Commit f6462cb

Browse files
Merge pull request #86 from stefanDeveloper/add-scaler
Add scaler transform in detector
2 parents a417234 + ce08746 commit f6462cb

File tree

6 files changed

+71
-32
lines changed

6 files changed

+71
-32
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ The full list of configuration parameters is available at the [documentation](ht
162162
| `pipeline.data_inspection.inspector.models` | List of models to use for data inspection (e.g., anomaly detection). | Array of model definitions (e.g., `{"model": "ZScoreDetector", "module": "streamad.model", "model_args": {"is_global": false}}`)|
163163
| `pipeline.data_inspection.inspector.anomaly_threshold` | Threshold for classifying an observation as an anomaly. | `0.01` |
164164
| `pipeline.data_analysis.detector.model` | Model to use for data analysis (e.g., DGA detection). | `rf` (Random Forest) option: `XGBoost` |
165-
| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06` |
165+
| `pipeline.data_analysis.detector.checksum` | Checksum for the model file to ensure integrity. | `021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc` |
166166
| `pipeline.data_analysis.detector.base_url` | Base URL for downloading the model if not present locally. | `https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/` |
167167

168168
<p align="right">(<a href="#readme-top">back to top</a>)</p>

config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pipeline:
5959
data_analysis:
6060
detector:
6161
model: rf # XGBoost
62-
checksum: ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06
62+
checksum: 021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc
6363
base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/
6464
threshold: 0.5
6565

requirements/requirements.inspector.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
PyYAML~=6.0.1
22
confluent-kafka~=2.4.0
33
colorlog~=6.8.2
4+
scipy~=1.12.0
45
streamad~=0.3.1
56
numpy~=1.26.4
67
marshmallow_dataclass~=8.7.1

requirements/requirements.train.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
numpy
22
xgboost
33
scikit-learn~=1.5.2
4-
scipy
4+
scipy~=1.12.0
55
torch
66
pyarrow
77
polars

src/detector/detector.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ def __init__(self) -> None:
5656
self.begin_timestamp = None
5757
self.end_timestamp = None
5858
self.model_path = os.path.join(
59-
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}.pickle"
59+
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_model.pickle"
60+
)
61+
self.scaler_path = os.path.join(
62+
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_scaler.pickle"
6063
)
6164

6265
self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(CONSUME_TOPIC)
6366

64-
self.model = self._get_model()
67+
self.model, self.scaler = self._get_model()
6568

6669
# databases
6770
self.suspicious_batch_timestamps = ClickHouseKafkaSender(
@@ -159,14 +162,28 @@ def _get_model(self):
159162
logger.info(f"Get model: {MODEL} with checksum {CHECKSUM}")
160163
if not os.path.isfile(self.model_path):
161164
response = requests.get(
162-
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1"
165+
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
166+
)
167+
logger.info(
168+
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
163169
)
164-
logger.info(f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}_{CHECKSUM}.pickle&dl=1")
165170
response.raise_for_status()
166171

167172
with open(self.model_path, "wb") as f:
168173
f.write(response.content)
169174

175+
if not os.path.isfile(self.scaler_path):
176+
response = requests.get(
177+
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
178+
)
179+
logger.info(
180+
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
181+
)
182+
response.raise_for_status()
183+
184+
with open(self.scaler_path, "wb") as f:
185+
f.write(response.content)
186+
170187
# Check file sha256
171188
local_checksum = self._sha256sum(self.model_path)
172189

@@ -181,7 +198,10 @@ def _get_model(self):
181198
with open(self.model_path, "rb") as input_file:
182199
clf = pickle.load(input_file)
183200

184-
return clf
201+
with open(self.scaler_path, "rb") as input_file:
202+
scaler = pickle.load(input_file)
203+
204+
return clf, scaler
185205

186206
def clear_data(self):
187207
"""Clears the data in the internal data structures."""
@@ -237,12 +257,6 @@ def calculate_counts(level: str) -> np.ndarray:
237257
for level, level_value in levels.items()
238258
}
239259

240-
logger.debug("Get frequency standard deviation, median, variance, and mean.")
241-
freq_std = np.std(freq)
242-
freq_var = np.var(freq)
243-
freq_median = np.median(freq)
244-
freq_mean = np.mean(freq)
245-
246260
logger.debug(
247261
"Get standard deviation, median, variance, and mean for full, alpha, special, and numeric count."
248262
)
@@ -268,16 +282,9 @@ def calculate_entropy(s: str) -> float:
268282

269283
# Final feature aggregation as a NumPy array
270284
basic_features = np.array([label_length, label_max, label_average])
271-
freq_features = np.array([freq_std, freq_var, freq_median, freq_mean])
272285

273286
# Flatten counts and stats for each level into arrays
274287
level_features = np.hstack([counts[level] for level in levels.keys()])
275-
stats_features = np.array(
276-
[stats[f"{level}_std"] for level in levels.keys()]
277-
+ [stats[f"{level}_var"] for level in levels.keys()]
278-
+ [stats[f"{level}_median"] for level in levels.keys()]
279-
+ [stats[f"{level}_mean"] for level in levels.keys()]
280-
)
281288

282289
# Entropy features
283290
entropy_features = np.array([entropy[level] for level in levels.keys()])
@@ -287,9 +294,9 @@ def calculate_entropy(s: str) -> float:
287294
[
288295
basic_features,
289296
freq,
290-
freq_features,
297+
# freq_features,
291298
level_features,
292-
stats_features,
299+
# stats_features,
293300
entropy_features,
294301
]
295302
)
@@ -304,7 +311,7 @@ def detect(self) -> None: # pragma: no cover
304311
for message in self.messages:
305312
# TODO predict all messages
306313
y_pred = self.model.predict_proba(
307-
self._get_features(message["domain_name"])
314+
self.scaler.transform(self._get_features(message["domain_name"]))
308315
)
309316
logger.info(f"Prediction: {y_pred}")
310317
if np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > THRESHOLD:

tests/detector/test_detector.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@
1010
from src.base.data_classes.batch import Batch
1111
from src.detector.detector import Detector, WrongChecksum
1212

13+
DEFAULT_DATA = {
14+
"client_ip": "192.168.0.167",
15+
"dns_ip": "10.10.0.10",
16+
"response_ip": "252.79.173.222",
17+
"timestamp": "",
18+
"status": "NXDOMAIN",
19+
"domain_name": "IF356gEnJHPdRxnkDId4RDUSgtqxx9I+pZ5n1V53MdghOGQncZWAQgAPRx3kswi.750jnH6iSqmiAAeyDUMX0W6SHGpVsVsKSX8ZkKYDs0GFh/9qU5N9cwl00XSD8ID.NNhBdHZIb7nc0hDQXFPlABDLbRwkJS38LZ8RMX4yUmR2Mb6YqTTJBn+nUcB9P+v.jBQdwdS53XV9W2p1BHjh.16.f.1.6037.tunnel.example.org",
20+
"record_type": "A",
21+
"size": "100b",
22+
}
23+
1324

1425
class TestSha256Sum(unittest.TestCase):
1526
@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
@@ -45,7 +56,7 @@ def setUp(self):
4556

4657
@patch(
4758
"src.detector.detector.CHECKSUM",
48-
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
59+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
4960
)
5061
@patch("src.detector.detector.MODEL", "rf")
5162
@patch(
@@ -70,7 +81,7 @@ def setUp(self):
7081

7182
@patch(
7283
"src.detector.detector.CHECKSUM",
73-
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
84+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
7485
)
7586
@patch("src.detector.detector.MODEL", "rf")
7687
@patch(
@@ -246,7 +257,7 @@ def setUp(self):
246257

247258
@patch(
248259
"src.detector.detector.CHECKSUM",
249-
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
260+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
250261
)
251262
@patch("src.detector.detector.MODEL", "rf")
252263
@patch(
@@ -265,13 +276,13 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler):
265276
"request": "google.de",
266277
"probability": 0.8765,
267278
"model": "rf",
268-
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
279+
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
269280
},
270281
{
271282
"request": "request.de",
272283
"probability": 0.12388,
273284
"model": "rf",
274-
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
285+
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
275286
},
276287
]
277288
sut.messages = [{"logline_id": "test_id"}]
@@ -285,7 +296,27 @@ def test_save_warning(self, mock_clickhouse, mock_kafka_consume_handler):
285296

286297
@patch(
287298
"src.detector.detector.CHECKSUM",
288-
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
299+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
300+
)
301+
@patch("src.detector.detector.MODEL", "rf")
302+
@patch(
303+
"src.detector.detector.MODEL_BASE_URL",
304+
"https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/",
305+
)
306+
@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
307+
@patch("src.detector.detector.ClickHouseKafkaSender")
308+
def test_prediction(self, mock_clickhouse, mock_kafka_consume_handler):
309+
mock_kafka_consume_handler_instance = MagicMock()
310+
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance
311+
312+
sut = Detector()
313+
sut.messages = [DEFAULT_DATA]
314+
sut.detect()
315+
self.assertNotEqual([], sut.warnings)
316+
317+
@patch(
318+
"src.detector.detector.CHECKSUM",
319+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
289320
)
290321
@patch("src.detector.detector.MODEL", "rf")
291322
@patch(
@@ -309,7 +340,7 @@ def test_save_empty_warning(self, mock_clickhouse, mock_kafka_consume_handler):
309340

310341
@patch(
311342
"src.detector.detector.CHECKSUM",
312-
"ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
343+
"021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
313344
)
314345
@patch("src.detector.detector.MODEL", "rf")
315346
@patch(
@@ -328,7 +359,7 @@ def test_save_warning_error(self, mock_clickhouse, mock_kafka_consume_handler):
328359
"request": "request.de",
329360
"probability": "INVALID",
330361
"model": "rf",
331-
"sha256": "ba1f718179191348fe2abd51644d76191d42a5d967c6844feb3371b6f798bf06",
362+
"sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc",
332363
}
333364
]
334365
with self.assertRaises(Exception):

0 commit comments

Comments
 (0)