From e81789e6d51bfbc32ad00851c17cc408b386ac5f Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 2 Apr 2024 10:29:10 +0800 Subject: [PATCH] Remove stratified group k fold (#101) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 82 ++++++++++------------------ 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6af356c..8503653 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.27.2" +version = "4.27.3" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index f42b8b9..ff89cc5 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -16,7 +16,7 @@ from torch import Tensor from torch.utils.data import Dataset from torch.nn.utils.rnn import pad_sequence -from sklearn.model_selection import StratifiedGroupKFold, TimeSeriesSplit +from sklearn.model_selection import TimeSeriesSplit from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score from scipy.optimize import minimize from itertools import accumulate @@ -977,7 +977,6 @@ def train( self, lr: float = 4e-2, n_epoch: int = 5, - n_splits: int = 5, batch_size: int = 512, verbose: bool = True, split_by_time: bool = False, @@ -993,59 +992,38 @@ def train( w = [] plots = [] - if n_splits > 1: - if split_by_time: - tscv = TimeSeriesSplit(n_splits=n_splits) - self.dataset.sort_values(by=["review_time"], inplace=True) - for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)): - if verbose: - tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}") - train_set = self.dataset.iloc[train_index].copy() - test_set = self.dataset.iloc[test_index].copy() - trainer = Trainer( - train_set, - test_set, - self.init_w, - n_epoch=n_epoch, - lr=lr, - batch_size=batch_size, - ) - w.append(trainer.train(verbose=verbose)) - self.w = w[-1] - self.evaluate() - metrics, figures = self.calibration_graph( - self.dataset.iloc[test_index] - ) - for j, f in enumerate(figures): - f.savefig(f"graph_{j}_test_{i}.png") - plt.close(f) - if verbose: - print(metrics) - plots.append(trainer.plot()) - else: - sgkf = StratifiedGroupKFold(n_splits=n_splits) - for train_index, test_index in sgkf.split( - self.dataset, self.dataset["i"], self.dataset["group"] - ): - if verbose: - tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}") - train_set = self.dataset.iloc[train_index].copy() - test_set = self.dataset.iloc[test_index].copy() - trainer = Trainer( - train_set, - test_set, - self.init_w, - n_epoch=n_epoch, - lr=lr, - batch_size=batch_size, - ) - w.append(trainer.train(verbose=verbose)) - if verbose: - plots.append(trainer.plot()) + if split_by_time: + tscv = TimeSeriesSplit(n_splits=5) + self.dataset.sort_values(by=["review_time"], inplace=True) + for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)): + if verbose: + tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}") + train_set = self.dataset.iloc[train_index].copy() + test_set = self.dataset.iloc[test_index].copy() + trainer = Trainer( + train_set, + test_set, + self.init_w, + n_epoch=n_epoch, + lr=lr, + batch_size=batch_size, + ) + w.append(trainer.train(verbose=verbose)) + self.w = w[-1] + self.evaluate() + metrics, figures = self.calibration_graph( + self.dataset.iloc[test_index] + ) + for j, f in enumerate(figures): + f.savefig(f"graph_{j}_test_{i}.png") + plt.close(f) + if verbose: + print(metrics) + plots.append(trainer.plot()) else: trainer = Trainer( self.dataset, - self.dataset, + None, self.init_w, n_epoch=n_epoch, lr=lr,