16
16
from torch import Tensor
17
17
from torch .utils .data import Dataset
18
18
from torch .nn .utils .rnn import pad_sequence
19
- from sklearn .model_selection import StratifiedGroupKFold , TimeSeriesSplit
19
+ from sklearn .model_selection import TimeSeriesSplit
20
20
from sklearn .metrics import root_mean_squared_error , mean_absolute_error , r2_score
21
21
from scipy .optimize import minimize
22
22
from itertools import accumulate
@@ -977,7 +977,6 @@ def train(
977
977
self ,
978
978
lr : float = 4e-2 ,
979
979
n_epoch : int = 5 ,
980
- n_splits : int = 5 ,
981
980
batch_size : int = 512 ,
982
981
verbose : bool = True ,
983
982
split_by_time : bool = False ,
@@ -993,59 +992,38 @@ def train(
993
992
994
993
w = []
995
994
plots = []
996
- if n_splits > 1 :
997
- if split_by_time :
998
- tscv = TimeSeriesSplit (n_splits = n_splits )
999
- self .dataset .sort_values (by = ["review_time" ], inplace = True )
1000
- for i , (train_index , test_index ) in enumerate (tscv .split (self .dataset )):
1001
- if verbose :
1002
- tqdm .write (f"TRAIN: { len (train_index )} TEST: { len (test_index )} " )
1003
- train_set = self .dataset .iloc [train_index ].copy ()
1004
- test_set = self .dataset .iloc [test_index ].copy ()
1005
- trainer = Trainer (
1006
- train_set ,
1007
- test_set ,
1008
- self .init_w ,
1009
- n_epoch = n_epoch ,
1010
- lr = lr ,
1011
- batch_size = batch_size ,
1012
- )
1013
- w .append (trainer .train (verbose = verbose ))
1014
- self .w = w [- 1 ]
1015
- self .evaluate ()
1016
- metrics , figures = self .calibration_graph (
1017
- self .dataset .iloc [test_index ]
1018
- )
1019
- for j , f in enumerate (figures ):
1020
- f .savefig (f"graph_{ j } _test_{ i } .png" )
1021
- plt .close (f )
1022
- if verbose :
1023
- print (metrics )
1024
- plots .append (trainer .plot ())
1025
- else :
1026
- sgkf = StratifiedGroupKFold (n_splits = n_splits )
1027
- for train_index , test_index in sgkf .split (
1028
- self .dataset , self .dataset ["i" ], self .dataset ["group" ]
1029
- ):
1030
- if verbose :
1031
- tqdm .write (f"TRAIN: { len (train_index )} TEST: { len (test_index )} " )
1032
- train_set = self .dataset .iloc [train_index ].copy ()
1033
- test_set = self .dataset .iloc [test_index ].copy ()
1034
- trainer = Trainer (
1035
- train_set ,
1036
- test_set ,
1037
- self .init_w ,
1038
- n_epoch = n_epoch ,
1039
- lr = lr ,
1040
- batch_size = batch_size ,
1041
- )
1042
- w .append (trainer .train (verbose = verbose ))
1043
- if verbose :
1044
- plots .append (trainer .plot ())
995
+ if split_by_time :
996
+ tscv = TimeSeriesSplit (n_splits = 5 )
997
+ self .dataset .sort_values (by = ["review_time" ], inplace = True )
998
+ for i , (train_index , test_index ) in enumerate (tscv .split (self .dataset )):
999
+ if verbose :
1000
+ tqdm .write (f"TRAIN: { len (train_index )} TEST: { len (test_index )} " )
1001
+ train_set = self .dataset .iloc [train_index ].copy ()
1002
+ test_set = self .dataset .iloc [test_index ].copy ()
1003
+ trainer = Trainer (
1004
+ train_set ,
1005
+ test_set ,
1006
+ self .init_w ,
1007
+ n_epoch = n_epoch ,
1008
+ lr = lr ,
1009
+ batch_size = batch_size ,
1010
+ )
1011
+ w .append (trainer .train (verbose = verbose ))
1012
+ self .w = w [- 1 ]
1013
+ self .evaluate ()
1014
+ metrics , figures = self .calibration_graph (
1015
+ self .dataset .iloc [test_index ]
1016
+ )
1017
+ for j , f in enumerate (figures ):
1018
+ f .savefig (f"graph_{ j } _test_{ i } .png" )
1019
+ plt .close (f )
1020
+ if verbose :
1021
+ print (metrics )
1022
+ plots .append (trainer .plot ())
1045
1023
else :
1046
1024
trainer = Trainer (
1047
1025
self .dataset ,
1048
- self . dataset ,
1026
+ None ,
1049
1027
self .init_w ,
1050
1028
n_epoch = n_epoch ,
1051
1029
lr = lr ,
0 commit comments