To facilitate grid searching for optimal hyper-parameters, early stopping is now available in SSLRec for methods trained by the basic trainer.
To use early stopping, you only need to add the patience keyword under the train keyword in the yaml file to define the endurance number n.
When n consecutive validation results are worse than the current best result, the trainer stops training and loads the best model parameters for testing on the test set.
Here is an example:
train:
  epoch: 3000
  batch_size: 4096
  save_model: false
  loss: pairwise
  log_loss: false
  test_step: 3
  patience: 3
If there is no patience keyword, the trainer will train the model for the fixed number of epochs defined by epoch.
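For reference, the patience logic can be pictured as in the following minimal sketch. It is only an illustration of the behaviour described above: the train_with_early_stop function and the trainer.train_epoch / trainer.evaluate calls are assumed names, not the actual API of SSLRec's basic trainer.

import copy

def train_with_early_stop(model, trainer, max_epoch, test_step, patience):
    # Illustrative early-stopping loop; the names here are assumptions, not SSLRec's API.
    best_metric, best_state, bad_rounds = float('-inf'), None, 0
    for epoch in range(max_epoch):
        trainer.train_epoch(model, epoch)
        if (epoch + 1) % test_step != 0:
            continue                               # validate only every test_step epochs
        metric = trainer.evaluate(model)           # validation metric, e.g. Recall@20
        if metric > best_metric:
            best_metric = metric
            best_state = copy.deepcopy(model.state_dict())
            bad_rounds = 0                         # an improvement resets the counter
        else:
            bad_rounds += 1
            if bad_rounds >= patience:             # n consecutive worse validation results
                break
    if best_state is not None:
        model.load_state_dict(best_state)          # restore the best parameters for testing
    return model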
Also, note that the data_handler file must load the correct validation and test sets. Otherwise, the algorithm will loosely use the test set as the validation set.
Here is an example in data_handler_general_cf.py:
def load_data(self):
    # Load the training, test, and validation interaction matrices separately.
    trn_mat = self._load_one_mat(self.trn_file)
    tst_mat = self._load_one_mat(self.tst_file)
    val_mat = self._load_one_mat(self.val_file)
    self.trn_mat = trn_mat
    configs['data']['user_num'], configs['data']['item_num'] = trn_mat.shape
    self.torch_adj = self._make_torch_adj(trn_mat)
    # Build the training dataset according to the configured loss.
    if configs['train']['loss'] == 'pairwise':
        trn_data = PairwiseTrnData(trn_mat)
    elif configs['train']['loss'] == 'pairwise_with_epoch_flag':
        trn_data = PairwiseWEpochFlagTrnData(trn_mat)
    # Wrap the validation and test matrices for all-rank evaluation.
    val_data = AllRankTstData(val_mat, trn_mat)
    tst_data = AllRankTstData(tst_mat, trn_mat)
    self.valid_dataloader = data.DataLoader(val_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
    self.test_dataloader = data.DataLoader(tst_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
    self.train_dataloader = data.DataLoader(trn_data, batch_size=configs['train']['batch_size'], shuffle=True, num_workers=0)
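If a dataset ships without a separate validation split, the loose fallback mentioned above (reusing the test set for validation) could look like the following sketch. This is an assumption for illustration only: the load_val_or_fallback helper and the os.path.exists check are not part of SSLRec's actual code.

import os

def load_val_or_fallback(handler, tst_mat):
    # Hypothetical helper: prefer a real validation split, otherwise fall back to the test set.
    if os.path.exists(handler.val_file):
        return handler._load_one_mat(handler.val_file)
    return tst_mat  # validation metrics would then mirror test-set metrics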