Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions luxonis_ml/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,30 @@ def upload_artifact(
) # Stores details for retrying later
self.log_stored_logs_to_mlflow()

@rank_zero_only
def log_matrix(
self, matrix: np.ndarray, name: str = "confusion_matrix"
) -> None:
"""Logs a confusion matrix as a JSON artifact by flattening the
matrix and saving its shape.

@type matrix: np.ndarray
@param matrix: The confusion matrix to log.
@type name: str
@param name: The name of the artifact.
"""
matrix_data = {
"flat_array": matrix.flatten().tolist(),
"shape": matrix.shape,
}
tmp_file_path = f"{name}.json"
try:
with open(tmp_file_path, "w") as tmp_file:
json.dump(matrix_data, tmp_file)
self.upload_artifact(path=tmp_file_path)
finally:
Path(tmp_file_path).unlink(missing_ok=True)

@rank_zero_only
def log_images(self, imgs: Dict[str, np.ndarray], step: int) -> None:
"""Logs multiple images.
Expand Down
Loading