Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/base/clickhouse_kafka_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def __init__(self, table_name: str):
)()

def insert(self, data: dict):
"""Produces the insert operation to Kafka."""
"""
Produces the insert operation to Kafka.

Args:
data (dict): content to write into the Kafka queue
"""
self.kafka_producer.produce(
topic=f"clickhouse_{self.table_name}",
data=self.data_schema.dumps(data),
Expand Down
4 changes: 4 additions & 0 deletions src/base/data_classes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

@dataclass
class Batch:
"""
Class definition of a batch, used to divide the log input into smaller units.
"""

batch_id: uuid.UUID = field(
metadata={"marshmallow_field": marshmallow.fields.UUID()}
)
Expand Down
10 changes: 9 additions & 1 deletion src/base/kafka_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,15 @@ def consume_as_json(self) -> tuple[None | str, dict]:
except Exception:
raise ValueError("Unknown data format")

def _all_topics_created(self, topics):
def _all_topics_created(self, topics) -> bool:
"""
Checks whether each topic in a list of topics was created. If not, retries a set number of times.

Args:
topics (list): List of topics to check
Returns:
bool: True if all topics were created, False otherwise
"""
number_of_retries_left = 30
all_topics_created = False
while not all_topics_created: # try for 15 seconds
Expand Down
2 changes: 2 additions & 0 deletions src/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def calculate_entropy(s: str) -> float:
return all_features.reshape(1, -1)

def detect(self) -> None: # pragma: no cover
"""Method to detect malicious requests in the network flows"""
logger.info("Start detecting malicious requests.")
for message in self.messages:
# TODO predict all messages
Expand All @@ -317,6 +318,7 @@ def detect(self) -> None: # pragma: no cover
self.warnings.append(warning)

def send_warning(self) -> None:
"""Dispatch warnings saved to the object's warning list"""
logger.info("Store alert.")
if len(self.warnings) > 0:
overall_score = median(
Expand Down
34 changes: 26 additions & 8 deletions src/inspector/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,20 @@ def inspect(self):
raise NotImplementedError(f"Mode {MODE} is not supported!")

def _inspect_multivariate(self, model: str):
"""
Method to inspect multivariate data for anomalies using a StreamAD Model
Errors are counted in the time window and the model is fitted to retrieve scores.

Args:
model (str): Model name (should be capable of handling multivariate data)

"""
logger.debug(f"Load Model: {model['model']} from {model['module']}.")
if not model["model"] in VALID_MULTIVARIATE_MODELS:
logger.error(f"Model {model} is not a valid univariate model.")
raise NotImplementedError(f"Model {model} is not a valid univariate model.")
logger.error(f"Model {model} is not a valid multivariate model.")
raise NotImplementedError(
f"Model {model} is not a valid multivariate model."
)

module = importlib.import_module(model["module"])
module_model = getattr(module, model["model"])
Expand Down Expand Up @@ -367,11 +377,19 @@ def _inspect_multivariate(self, model: str):
self.anomalies.append(0)

def _inspect_ensemble(self, models: str):
"""
Method to inspect data for anomalies using ensembles of two StreamAD models
Errors are counted in the time window and the models are fitted to retrieve scores.

Args:
models (list): Model names (should be valid ensemble models)

"""
logger.debug(f"Load Model: {ENSEMBLE['model']} from {ENSEMBLE['module']}.")
if not ENSEMBLE["model"] in VALID_ENSEMBLE_MODELS:
logger.error(f"Model {ENSEMBLE} is not a valid univariate model.")
logger.error(f"Model {ENSEMBLE} is not a valid ensemble model.")
raise NotImplementedError(
f"Model {ENSEMBLE} is not a valid univariate model."
f"Model {ENSEMBLE} is not a valid ensemble model."
)

module = importlib.import_module(ENSEMBLE["module"])
Expand All @@ -389,9 +407,9 @@ def _inspect_ensemble(self, models: str):
for model in models:
logger.debug(f"Load Model: {model['model']} from {model['module']}.")
if not model["model"] in VALID_UNIVARIATE_MODELS:
logger.error(f"Model {models} is not a valid univariate model.")
logger.error(f"Model {models} is not a valid ensemble model.")
raise NotImplementedError(
f"Model {models} is not a valid univariate model."
f"Model {models} is not a valid ensemble model."
)

module = importlib.import_module(model["module"])
Expand All @@ -415,8 +433,7 @@ def _inspect_univariate(self, model: str):
Errors are counted in the time window and the model is fitted to retrieve scores.

Args:
model (BaseDetector): StreamAD model.
model_args (dict): Arguments passed to the StreamAD model.
model (str): StreamAD model name.
"""

logger.debug(f"Load Model: {model['model']} from {model['module']}.")
Expand Down Expand Up @@ -445,6 +462,7 @@ def _inspect_univariate(self, model: str):
self.anomalies.append(0)

def send_data(self):
"""Pass the anomalous data to the detector unit for further processing"""
total_anomalies = np.count_nonzero(
np.greater_equal(np.array(self.anomalies), SCORE_THRESHOLD)
)
Expand Down
20 changes: 20 additions & 0 deletions src/logcollector/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,12 @@ def add_message(self, key: str, message: str) -> None:
self._reset_timer()

def _send_all_batches(self, reset_timer: bool = True) -> None:
"""
Dispatch all batches to the Kafka queue

Args:
reset_timer (bool): whether or not the timer should be reset
"""
number_of_keys = 0
total_number_of_batch_messages = self.batch.get_message_count_for_batch()
total_number_of_buffer_messages = self.batch.get_message_count_for_buffer()
Expand Down Expand Up @@ -438,6 +444,12 @@ def _send_all_batches(self, reset_timer: bool = True) -> None:
)

def _send_batch_for_key(self, key: str) -> None:
"""
Send one batch based on the key

Args:
key (str): Key to identify the batch
"""
try:
data = self.batch.complete_batch(key)
except ValueError as e:
Expand All @@ -447,6 +459,13 @@ def _send_batch_for_key(self, key: str) -> None:
self._send_data_packet(key, data)

def _send_data_packet(self, key: str, data: dict) -> None:
"""
Sends a packet of a batch to the defined Kafka topic

Args:
key (str): key to identify the batch
data (dict): the batch data to send
"""
batch_schema = marshmallow_dataclass.class_schema(Batch)()

self.kafka_produce_handler.produce(
Expand All @@ -456,6 +475,7 @@ def _send_data_packet(self, key: str, data: dict) -> None:
)

def _reset_timer(self) -> None:
"""Restarts the internal timer of the object"""
if self.timer:
self.timer.cancel()

Expand Down
23 changes: 21 additions & 2 deletions src/monitoring/clickhouse_batch_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class Table:
columns: dict[str, type]

def verify(self, data: dict[str, Any]):
"""
Verify if the data has the correct columns and types.

Args:
data (dict): The values for each cell
"""
if len(data) != len(self.columns):
raise ValueError(
f"Wrong number of fields in data: Expected {len(self.columns)}, got {len(data)}"
Expand Down Expand Up @@ -182,7 +188,14 @@ def __del__(self):
self.insert_all()

def add(self, table_name: str, data: dict[str, Any]):
"""Adds the data to the batch for the table. Verifies the fields first."""
"""
Adds the data to the batch for the table. Verifies the fields first.

Args:
table_name (str): Name of the table to add data to
data (dict): The values for each cell in the table

"""
self.tables.get(table_name).verify(data)
self.batch.get(table_name).append(list(data.values()))

Expand All @@ -192,7 +205,12 @@ def add(self, table_name: str, data: dict[str, Any]):
self._start_timer()

def insert(self, table_name: str):
"""Inserts the batch for the given table."""
"""
Inserts the batch for the given table.

Args:
table_name (str): Name of the table to insert data to
"""
if self.batch[table_name]:
with self.lock:
self._client.insert(
Expand All @@ -216,6 +234,7 @@ def insert_all(self):
self.timer = None

def _start_timer(self):
"""Set the timer for batch processing of data insertion"""
if self.timer:
self.timer.cancel()

Expand Down
52 changes: 52 additions & 0 deletions src/train/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def __init__(
super().__init__(processor, x_train, y_train)

def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float]:
"""
Custom FDR metric to evaluate model performance based on False Discovery Rate.

Args:
preds (np.ndarray): The predicted values.
dtrain (xgb.DMatrix): The training data matrix.

Returns:
tuple: A tuple containing the metric name ("fdr") and its value.
"""
# Get the true labels
labels = dtrain.get_label()

Expand All @@ -188,6 +198,15 @@ def fdr_metric(self, preds: np.ndarray, dtrain: xgb.DMatrix) -> tuple[str, float
) # -1 is essentiell since XGBoost wants a scoring value (higher is better). However, FDR represents a loss function.

def objective(self, trial):
"""
Optimizes the XGBoost model hyperparameters using cross-validation.

Args:
trial: A trial object from the optimization framework (e.g., Optuna).

Returns:
float: The best FDR value after cross-validation.
"""
dtrain = xgb.DMatrix(self.x_train, label=self.y_train)

param = {
Expand Down Expand Up @@ -263,6 +282,13 @@ def predict(self, x):
return self.clf.predict(x)

def train(self, trial, output_path):
"""
Trains the XGBoost model and saves the trained model to a file.

Args:
trial: A trial object from the optimization framework.
output_path (str): The directory path to save the trained model.
"""
logger.info("Number of estimators: {}".format(trial.user_attrs["n_estimators"]))

# dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
Expand Down Expand Up @@ -300,6 +326,16 @@ def __init__(

# Define the custom FDR metric
def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray):
"""
Custom FDR metric to evaluate the performance of the Random Forest model.

Args:
y_true (np.ndarray): The true labels.
y_pred (np.ndarray): The predicted labels.

Returns:
float: The False Discovery Rate (FDR).
"""
# False Positives (FP): cases where the model predicted 1 but the actual label is 0
FP = np.sum((y_pred == 1) & (y_true == 0))

Expand All @@ -315,6 +351,15 @@ def fdr_metric(self, y_true: np.ndarray, y_pred: np.ndarray):
return fdr

def objective(self, trial):
"""
Optimizes the Random Forest model hyperparameters using cross-validation.

Args:
trial: A trial object from the optimization framework (e.g., Optuna).

Returns:
float: The best FDR value after cross-validation.
"""
# Define hyperparameters to optimize
n_estimators = trial.suggest_int("n_estimators", 50, 300)
max_depth = trial.suggest_int("max_depth", 2, 20)
Expand Down Expand Up @@ -359,6 +404,13 @@ def predict(self, x):
return self.clf.predict(x)

def train(self, trial, output_path):
"""
Trains the Random Forest model and saves the trained model to a file.

Args:
trial: A trial object from the optimization framework.
output_path (str): The directory path to save the trained model.
"""
self.clf = RandomForestClassifier(**trial.params)
self.clf.fit(self.x_train, self.y_train)

Expand Down
Loading