
Commit 6c4d05b
Add AUC metrics and improve calibration reporting (#149)
L-M-Sherlock authored Dec 6, 2024
1 parent 0b880f9 commit 6c4d05b
Showing 3 changed files with 36 additions and 5 deletions.
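The core addition is an AUC (area under the ROC curve) metric guarded against single-class data: sklearn's roc_auc_score raises a ValueError when y_true contains only one class, so the commit falls back to NaN in that case. A minimal standalone sketch of that guard pattern (the helper name safe_auc is illustrative, not part of the commit):

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

def safe_auc(y_true: pd.Series, y_score: pd.Series) -> float:
    # AUC is only defined when both outcomes (0 and 1) appear in y_true.
    if len(y_true.unique()) == 2:
        return roc_auc_score(y_true=y_true, y_score=y_score)
    return np.nan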
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "FSRS-Optimizer"
-version = "5.4.0"
+version = "5.4.1"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
6 changes: 6 additions & 0 deletions src/fsrs_optimizer/__main__.py
@@ -146,6 +146,12 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
     print(f"Loss before training: {loss_before:.4f}")
     print(f"Loss after training: {loss_after:.4f}")
     metrics, figures = optimizer.calibration_graph()
+    for partition in metrics:
+        print(f"Last rating = {partition}")
+        for metric in metrics[partition]:
+            print(f"{metric}: {metrics[partition][metric]:.4f}")
+        print()
+
     metrics["Log loss"] = loss_after
     if save_graphs:
         for i, f in enumerate(figures):
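The reporting loop above assumes the new return shape of calibration_graph(): a dict of per-partition metric dicts ("all" plus last ratings 1-4) rather than a single flat dict. A runnable sketch of the loop with hypothetical values in that shape (the numbers are made up for illustration):

# Hypothetical values in the nested shape calibration_graph() now returns.
metrics = {
    "all": {"R-squared": 0.95, "MAE": 0.012, "AUC": 0.70},
    1: {"R-squared": 0.91, "MAE": 0.021, "AUC": 0.66},
}
for partition in metrics:
    print(f"Last rating = {partition}")
    for metric in metrics[partition]:
        print(f"{metric}: {metrics[partition][metric]:.4f}")
    print()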
33 changes: 29 additions & 4 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -23,6 +23,7 @@
     mean_absolute_error,
     mean_absolute_percentage_error,
     r2_score,
+    roc_auc_score,
 )
 from scipy.optimize import minimize  # type: ignore
 from itertools import accumulate
@@ -1561,10 +1562,17 @@ def calibration_graph(self, dataset=None, verbose=True):
         rmse = rmse_matrix(dataset)
         if verbose:
             tqdm.write(f"RMSE(bins): {rmse:.4f}")
+        metrics_all = {}
         metrics = plot_brier(
             dataset["p"], dataset["y"], bins=20, ax=fig1.add_subplot(111)
         )
-        metrics["rmse"] = rmse
+        metrics["RMSE(bins)"] = rmse
+        metrics["AUC"] = (
+            roc_auc_score(y_true=dataset["y"], y_score=dataset["p"])
+            if len(dataset["y"].unique()) == 2
+            else np.nan
+        )
+        metrics_all["all"] = metrics
         fig2 = plt.figure(figsize=(16, 12))
         for last_rating in (1, 2, 3, 4):
             calibration_data = dataset[dataset["last_rating"] == last_rating]
@@ -1574,13 +1582,23 @@
             if verbose:
                 tqdm.write(f"\nLast rating: {last_rating}")
                 tqdm.write(f"RMSE(bins): {rmse:.4f}")
-            plot_brier(
+            metrics = plot_brier(
                 calibration_data["p"],
                 calibration_data["y"],
                 bins=20,
                 ax=fig2.add_subplot(2, 2, int(last_rating)),
                 title=f"Last rating: {last_rating}",
             )
+            metrics["RMSE(bins)"] = rmse
+            metrics["AUC"] = (
+                roc_auc_score(
+                    y_true=calibration_data["y"],
+                    y_score=calibration_data["p"],
+                )
+                if len(calibration_data["y"].unique()) == 2
+                else np.nan
+            )
+            metrics_all[last_rating] = metrics

         fig3 = plt.figure()
         self.calibration_helper(
@@ -1611,7 +1629,7 @@
             False,
             fig5.add_subplot(111),
         )
-        return metrics, (fig1, fig2, fig3, fig4, fig5)
+        return metrics_all, (fig1, fig2, fig3, fig4, fig5)

     def calibration_helper(self, calibration_data, key, bin_func, semilogx, ax1):
         ax2 = ax1.twinx()
@@ -1925,7 +1943,14 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
     ax2.legend(loc="lower center")
     if title:
         ax.set_title(title)
-    metrics = {"R-squared": r2, "MAE": mae, "ICI": ici, "E50": e_50, "E90": e_90, "EMax": e_max}
+    metrics = {
+        "R-squared": r2,
+        "MAE": mae,
+        "ICI": ici,
+        "E50": e_50,
+        "E90": e_90,
+        "EMax": e_max,
+    }
     return metrics


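For reference, the names plot_brier reports (ICI, E50, E90, EMax) are conventionally calibration summaries: the mean, median, 90th-percentile, and maximum absolute difference between a smoothed observed-outcome curve and the predictions. A sketch under that assumption, with synthetic data standing in for the smoothed curve (which in practice would come from something like a lowess fit of outcomes against predictions):

import numpy as np

rng = np.random.default_rng(42)
predictions = rng.uniform(size=1000)
# Stand-in for a smoothed observed-retention curve; here just the
# predictions plus small noise, clipped to [0, 1].
observed_smoothed = np.clip(predictions + rng.normal(0, 0.02, size=1000), 0, 1)

errors = np.abs(observed_smoothed - predictions)
ici = errors.mean()               # ICI: mean absolute calibration error
e_50 = np.percentile(errors, 50)  # E50: median error
e_90 = np.percentile(errors, 90)  # E90: 90th-percentile error
e_max = errors.max()              # EMax: worst-case error
print(f"ICI={ici:.4f}  E50={e_50:.4f}  E90={e_90:.4f}  EMax={e_max:.4f}")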
