Skip to content

Commit e81789e

Browse files
authored
Remove stratified group k fold (#101)
1 parent 11db4ec commit e81789e

File tree

2 files changed, with 31 additions and 53 deletions.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4 4

5 5
[project]
6 6
name = "FSRS-Optimizer"
7-
version = "4.27.2"
7+
version = "4.27.3"
8 8
readme = "README.md"
9 9
dependencies = [
10 10
"matplotlib>=3.7.0",

src/fsrs_optimizer/fsrs_optimizer.py

Lines changed: 30 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
16 16
from torch import Tensor
17 17
from torch.utils.data import Dataset
18 18
from torch.nn.utils.rnn import pad_sequence
19-
from sklearn.model_selection import StratifiedGroupKFold, TimeSeriesSplit
19+
from sklearn.model_selection import TimeSeriesSplit
20 20
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
21 21
from scipy.optimize import minimize
22 22
from itertools import accumulate
@@ -977,7 +977,6 @@ def train(
977 977
self,
978 978
lr: float = 4e-2,
979 979
n_epoch: int = 5,
980-
n_splits: int = 5,
981 980
batch_size: int = 512,
982 981
verbose: bool = True,
983 982
split_by_time: bool = False,
@@ -993,59 +992,38 @@ def train(
993 992

994 993
w = []
995 994
plots = []
996-
if n_splits > 1:
997-
if split_by_time:
998-
tscv = TimeSeriesSplit(n_splits=n_splits)
999-
self.dataset.sort_values(by=["review_time"], inplace=True)
1000-
for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
1001-
if verbose:
1002-
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
1003-
train_set = self.dataset.iloc[train_index].copy()
1004-
test_set = self.dataset.iloc[test_index].copy()
1005-
trainer = Trainer(
1006-
train_set,
1007-
test_set,
1008-
self.init_w,
1009-
n_epoch=n_epoch,
1010-
lr=lr,
1011-
batch_size=batch_size,
1012-
)
1013-
w.append(trainer.train(verbose=verbose))
1014-
self.w = w[-1]
1015-
self.evaluate()
1016-
metrics, figures = self.calibration_graph(
1017-
self.dataset.iloc[test_index]
1018-
)
1019-
for j, f in enumerate(figures):
1020-
f.savefig(f"graph_{j}_test_{i}.png")
1021-
plt.close(f)
1022-
if verbose:
1023-
print(metrics)
1024-
plots.append(trainer.plot())
1025-
else:
1026-
sgkf = StratifiedGroupKFold(n_splits=n_splits)
1027-
for train_index, test_index in sgkf.split(
1028-
self.dataset, self.dataset["i"], self.dataset["group"]
1029-
):
1030-
if verbose:
1031-
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
1032-
train_set = self.dataset.iloc[train_index].copy()
1033-
test_set = self.dataset.iloc[test_index].copy()
1034-
trainer = Trainer(
1035-
train_set,
1036-
test_set,
1037-
self.init_w,
1038-
n_epoch=n_epoch,
1039-
lr=lr,
1040-
batch_size=batch_size,
1041-
)
1042-
w.append(trainer.train(verbose=verbose))
1043-
if verbose:
1044-
plots.append(trainer.plot())
995+
if split_by_time:
996+
tscv = TimeSeriesSplit(n_splits=5)
997+
self.dataset.sort_values(by=["review_time"], inplace=True)
998+
for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
999+
if verbose:
1000+
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
1001+
train_set = self.dataset.iloc[train_index].copy()
1002+
test_set = self.dataset.iloc[test_index].copy()
1003+
trainer = Trainer(
1004+
train_set,
1005+
test_set,
1006+
self.init_w,
1007+
n_epoch=n_epoch,
1008+
lr=lr,
1009+
batch_size=batch_size,
1010+
)
1011+
w.append(trainer.train(verbose=verbose))
1012+
self.w = w[-1]
1013+
self.evaluate()
1014+
metrics, figures = self.calibration_graph(
1015+
self.dataset.iloc[test_index]
1016+
)
1017+
for j, f in enumerate(figures):
1018+
f.savefig(f"graph_{j}_test_{i}.png")
1019+
plt.close(f)
1020+
if verbose:
1021+
print(metrics)
1022+
plots.append(trainer.plot())
1045 1023
else:
1046 1024
trainer = Trainer(
1047 1025
self.dataset,
1048-
self.dataset,
1026+
None,
1049 1027
self.init_w,
1050 1028
n_epoch=n_epoch,
1051 1029
lr=lr,

0 commit comments

Comments
 (0)