@@ -501,13 +501,15 @@ def upload_artifact(
501501 self .log_stored_logs_to_mlflow ()
502502
503503 @rank_zero_only
504- def log_matrix (self , matrix : np .ndarray , name : str ) -> None :
504+ def log_matrix (self , matrix : np .ndarray , name : str , step : int ) -> None :
505505 """Logs confusion matrix to the logging service.
506506
507507 @type matrix: np.ndarray
508508 @param matrix: The confusion matrix to log.
509509 @type name: str
510510 @param name: The name used to log the matrix.
511+ @type step: int
512+ @param step: The current step.
511513 """
512514 if self .is_mlflow :
513515 matrix_data = {
@@ -518,7 +520,7 @@ def log_matrix(self, matrix: np.ndarray, name: str) -> None:
518520
519521 if self .is_tensorboard :
520522 matrix_str = np .array2string (matrix , separator = ", " )
521- self .experiment ["tensorboard" ].add_text (name , matrix_str )
523+ self .experiment ["tensorboard" ].add_text (name , matrix_str , step )
522524
523525 if self .is_wandb :
524526 import wandb
@@ -529,7 +531,7 @@ def log_matrix(self, matrix: np.ndarray, name: str) -> None:
529531 )
530532 for i , row in enumerate (matrix ):
531533 table .add_data (i , * row )
532- self .experiment ["wandb" ].log ({f"{ name } _table" : table })
534+ self .experiment ["wandb" ].log ({f"{ name } _table" : table }, step = step )
533535
534536 @rank_zero_only
535537 def log_images (self , imgs : Dict [str , np .ndarray ], step : int ) -> None :
0 commit comments