Skip to content

Commit

Permalink
Remove stratified group k fold (#101)
Browse files: browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Apr 2, 2024
1 parent 11db4ec commit e81789e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 53 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.27.2"
version = "4.27.3"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
82 changes: 30 additions & 52 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import Tensor
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import StratifiedGroupKFold, TimeSeriesSplit
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
from scipy.optimize import minimize
from itertools import accumulate
Expand Down Expand Up @@ -977,7 +977,6 @@ def train(
self,
lr: float = 4e-2,
n_epoch: int = 5,
n_splits: int = 5,
batch_size: int = 512,
verbose: bool = True,
split_by_time: bool = False,
Expand All @@ -993,59 +992,38 @@ def train(

w = []
plots = []
if n_splits > 1:
if split_by_time:
tscv = TimeSeriesSplit(n_splits=n_splits)
self.dataset.sort_values(by=["review_time"], inplace=True)
for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
if verbose:
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(
train_set,
test_set,
self.init_w,
n_epoch=n_epoch,
lr=lr,
batch_size=batch_size,
)
w.append(trainer.train(verbose=verbose))
self.w = w[-1]
self.evaluate()
metrics, figures = self.calibration_graph(
self.dataset.iloc[test_index]
)
for j, f in enumerate(figures):
f.savefig(f"graph_{j}_test_{i}.png")
plt.close(f)
if verbose:
print(metrics)
plots.append(trainer.plot())
else:
sgkf = StratifiedGroupKFold(n_splits=n_splits)
for train_index, test_index in sgkf.split(
self.dataset, self.dataset["i"], self.dataset["group"]
):
if verbose:
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(
train_set,
test_set,
self.init_w,
n_epoch=n_epoch,
lr=lr,
batch_size=batch_size,
)
w.append(trainer.train(verbose=verbose))
if verbose:
plots.append(trainer.plot())
if split_by_time:
tscv = TimeSeriesSplit(n_splits=5)
self.dataset.sort_values(by=["review_time"], inplace=True)
for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
if verbose:
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(
train_set,
test_set,
self.init_w,
n_epoch=n_epoch,
lr=lr,
batch_size=batch_size,
)
w.append(trainer.train(verbose=verbose))
self.w = w[-1]
self.evaluate()
metrics, figures = self.calibration_graph(
self.dataset.iloc[test_index]
)
for j, f in enumerate(figures):
f.savefig(f"graph_{j}_test_{i}.png")
plt.close(f)
if verbose:
print(metrics)
plots.append(trainer.plot())
else:
trainer = Trainer(
self.dataset,
self.dataset,
None,
self.init_w,
n_epoch=n_epoch,
lr=lr,
Expand Down

0 comments on commit e81789e

Please sign in to comment.