# 2023-06-28

To facilitate grid searching of optimal hyper-parameters, early stopping is now available in SSLRec for methods trained by the basic trainer.

To use early stopping, simply add the `patience` keyword under the `train` keyword in the yaml file to define the endurance number n. When n consecutive validation results fail to improve on the best result so far, training stops and the best model parameters are restored for testing on the test set.

Here is an example:

```yaml
train:
  epoch: 3000
  batch_size: 4096
  save_model: false
  loss: pairwise
  log_loss: false
  test_step: 3
  patience: 3
```

If the `patience` keyword is absent, the trainer trains the model for the fixed number of epochs defined by `epoch`.
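
For reference, the patience mechanism boils down to counting how many consecutive validation evaluations fail to beat the best score so far. The sketch below is only a minimal illustration of that idea mirroring the yaml keys above, not SSLRec's actual trainer code; `model`, `train_one_epoch`, and `evaluate_on_validation` are placeholders.

```python
import copy

# Minimal sketch of patience-based early stopping (illustrative only, not the
# SSLRec trainer). `model`, `train_one_epoch`, and `evaluate_on_validation`
# are placeholders supplied by the caller.
def fit_with_early_stop(model, train_one_epoch, evaluate_on_validation, train_cfg):
    best_score, best_state, bad_evals = float('-inf'), None, 0
    patience = train_cfg.get('patience')              # absent -> plain fixed-epoch training

    for epoch in range(train_cfg['epoch']):
        train_one_epoch(model)
        if (epoch + 1) % train_cfg['test_step'] != 0:
            continue                                   # evaluate only every `test_step` epochs
        score = evaluate_on_validation(model)
        if score > best_score:
            best_score, bad_evals = score, 0
            best_state = copy.deepcopy(model.state_dict())  # remember the best parameters
        else:
            bad_evals += 1
            if patience is not None and bad_evals >= patience:
                break                                  # n evaluations without improvement: stop

    if best_state is not None:
        model.load_state_dict(best_state)              # restore the best model before testing
    return model
```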

Also note that the data_handler file must load the correct validation and test sets; otherwise, the trainer will simply reuse the test set as the validation set.

Here is an example in `data_handler_general_cf.py`:

```python
def load_data(self):
    # Load the training, test, and (separately) validation interaction matrices
    trn_mat = self._load_one_mat(self.trn_file)
    tst_mat = self._load_one_mat(self.tst_file)
    val_mat = self._load_one_mat(self.val_file)

    self.trn_mat = trn_mat
    configs['data']['user_num'], configs['data']['item_num'] = trn_mat.shape
    self.torch_adj = self._make_torch_adj(trn_mat)

    # Build the training dataset according to the configured loss
    if configs['train']['loss'] == 'pairwise':
        trn_data = PairwiseTrnData(trn_mat)
    elif configs['train']['loss'] == 'pairwise_with_epoch_flag':
        trn_data = PairwiseWEpochFlagTrnData(trn_mat)
    # Wrap the validation and test sets separately for all-rank evaluation
    val_data = AllRankTstData(val_mat, trn_mat)
    tst_data = AllRankTstData(tst_mat, trn_mat)

    self.valid_dataloader = data.DataLoader(val_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
    self.test_dataloader = data.DataLoader(tst_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
    self.train_dataloader = data.DataLoader(trn_data, batch_size=configs['train']['batch_size'], shuffle=True, num_workers=0)
```
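
The matrices above are read from `self.trn_file`, `self.val_file`, and `self.tst_file`. The snippet below is only a hypothetical sketch of where such paths could be set; the attribute names come from the excerpt above, but the config key and directory layout are assumptions, not SSLRec's actual values.

```python
# Hypothetical sketch of where the file paths could be defined (e.g. in __init__);
# the config key 'name' and the directory layout are assumptions. Adjust them
# to point at your own train/validation/test splits.
def __init__(self):
    data_name = configs['data']['name']                # assumed config key
    base_dir = f'./datasets/general_cf/{data_name}/'   # assumed directory layout
    self.trn_file = base_dir + 'trn_mat.pkl'           # training interactions
    self.val_file = base_dir + 'val_mat.pkl'           # validation interactions (needed for early stop)
    self.tst_file = base_dir + 'tst_mat.pkl'           # test interactions
```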