From 6c4d05b86b1eb65e706526f193d842d2a319880d Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Fri, 6 Dec 2024 10:56:22 +0800
Subject: [PATCH] Add AUC metrics and improve calibration reporting (#149)

---
 pyproject.toml                       |  2 +-
 src/fsrs_optimizer/__main__.py       |  6 +++++
 src/fsrs_optimizer/fsrs_optimizer.py | 33 ++++++++++++++++++++++++----
 3 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e081f23..834d2c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.4.0"
+version = "5.4.1"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
diff --git a/src/fsrs_optimizer/__main__.py b/src/fsrs_optimizer/__main__.py
index bcf0eff..998bdb9 100644
--- a/src/fsrs_optimizer/__main__.py
+++ b/src/fsrs_optimizer/__main__.py
@@ -146,6 +146,12 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
     print(f"Loss before training: {loss_before:.4f}")
     print(f"Loss after training: {loss_after:.4f}")
     metrics, figures = optimizer.calibration_graph()
+    for partition in metrics:
+        print(f"Last rating = {partition}")
+        for metric in metrics[partition]:
+            print(f"{metric}: {metrics[partition][metric]:.4f}")
+        print()
+
     metrics["Log loss"] = loss_after
     if save_graphs:
         for i, f in enumerate(figures):
diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index 1dfc8a8..652a780 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -23,6 +23,7 @@
     mean_absolute_error,
     mean_absolute_percentage_error,
     r2_score,
+    roc_auc_score,
 )
 from scipy.optimize import minimize  # type: ignore
 from itertools import accumulate
@@ -1561,10 +1562,17 @@ def calibration_graph(self, dataset=None, verbose=True):
         rmse = rmse_matrix(dataset)
         if verbose:
             tqdm.write(f"RMSE(bins): {rmse:.4f}")
+        metrics_all = {}
         metrics = plot_brier(
             dataset["p"], dataset["y"], bins=20, ax=fig1.add_subplot(111)
         )
-        metrics["rmse"] = rmse
+        metrics["RMSE(bins)"] = rmse
+        metrics["AUC"] = (
+            roc_auc_score(y_true=dataset["y"], y_score=dataset["p"])
+            if len(dataset["y"].unique()) == 2
+            else np.nan
+        )
+        metrics_all["all"] = metrics
         fig2 = plt.figure(figsize=(16, 12))
         for last_rating in (1, 2, 3, 4):
             calibration_data = dataset[dataset["last_rating"] == last_rating]
@@ -1574,13 +1582,23 @@ def calibration_graph(self, dataset=None, verbose=True):
             if verbose:
                 tqdm.write(f"\nLast rating: {last_rating}")
                 tqdm.write(f"RMSE(bins): {rmse:.4f}")
-            plot_brier(
+            metrics = plot_brier(
                 calibration_data["p"],
                 calibration_data["y"],
                 bins=20,
                 ax=fig2.add_subplot(2, 2, int(last_rating)),
                 title=f"Last rating: {last_rating}",
             )
+            metrics["RMSE(bins)"] = rmse
+            metrics["AUC"] = (
+                roc_auc_score(
+                    y_true=calibration_data["y"],
+                    y_score=calibration_data["p"],
+                )
+                if len(calibration_data["y"].unique()) == 2
+                else np.nan
+            )
+            metrics_all[last_rating] = metrics
         fig3 = plt.figure()
         self.calibration_helper(
@@ -1611,7 +1629,7 @@ def calibration_graph(self, dataset=None, verbose=True):
             False,
             fig5.add_subplot(111),
         )
-        return metrics, (fig1, fig2, fig3, fig4, fig5)
+        return metrics_all, (fig1, fig2, fig3, fig4, fig5)
 
     def calibration_helper(self, calibration_data, key, bin_func, semilogx, ax1):
         ax2 = ax1.twinx()
@@ -1925,7 +1943,14 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
     ax2.legend(loc="lower center")
     if title:
         ax.set_title(title)
-    metrics = {"R-squared": r2, "MAE": mae, "ICI": ici, "E50": e_50, "E90": e_90, "EMax": e_max}
+    metrics = {
+        "R-squared": r2,
+        "MAE": mae,
+        "ICI": ici,
+        "E50": e_50,
+        "E90": e_90,
+        "EMax": e_max,
+    }
     return metrics
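
Not part of the patch: a minimal usage sketch of how a caller might consume the new return shape of calibration_graph() after this change. Per the diff, the first return value is now a dict keyed by "all" plus each last rating (1-4), and each entry maps metric names (R-squared, MAE, ICI, E50, E90, EMax, RMSE(bins), AUC) to floats, with AUC set to np.nan when a partition contains only one outcome class. The optimizer instance name `opt` below is illustrative, not from the patch.

    import math

    # Assumption: `opt` is an already-fitted fsrs_optimizer.Optimizer instance.
    metrics_all, figures = opt.calibration_graph(verbose=False)

    for partition, metrics in metrics_all.items():
        print(f"Partition: {partition}")
        for name, value in metrics.items():
            # Guard against the np.nan AUC returned for single-class
            # partitions before formatting the value.
            if isinstance(value, float) and math.isnan(value):
                print(f"  {name}: n/a (only one outcome class)")
            else:
                print(f"  {name}: {value:.4f}")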