From 5fb95e13ff20eaa7edc72c9c293569b938f32955 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 25 Mar 2024 16:08:43 +0800 Subject: [PATCH] Feat/allow test set=none (#97) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 59 +++++++++++++--------------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5f30958..990ec98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.26.8" +version = "4.27.0" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 6052c34..6526825 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -247,7 +247,7 @@ class Trainer: def __init__( self, train_set: pd.DataFrame, - test_set: pd.DataFrame, + test_set: Optional[pd.DataFrame], init_w: List[float], n_epoch: int = 1, lr: float = 1e-2, @@ -267,7 +267,7 @@ def __init__( self.avg_eval_losses = [] self.loss_fn = nn.BCELoss(reduction="none") - def build_dataset(self, train_set: pd.DataFrame, test_set: pd.DataFrame): + def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]): pre_train_set = train_set[train_set["i"] == 2] self.pre_train_set = BatchDataset(pre_train_set, batch_size=self.batch_size) self.pre_train_data_loader = BatchLoader(self.pre_train_set) @@ -279,8 +279,11 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: pd.DataFrame): self.train_set = BatchDataset(train_set, batch_size=self.batch_size) self.train_data_loader = BatchLoader(self.train_set) - self.test_set = BatchDataset(test_set, batch_size=self.batch_size) - self.test_data_loader = BatchLoader(self.test_set) + self.test_set = ( + [] + if test_set is None + else BatchDataset(test_set, batch_size=self.batch_size) + ) def train(self, verbose: bool = True): self.verbose = verbose @@ -333,33 +336,25 @@ def train(self, verbose: bool = True): def eval(self): self.model.eval() with torch.no_grad(): - sequences, delta_ts, labels, seq_lens = ( - self.train_set.x_train, - self.train_set.t_train, - self.train_set.y_train, - self.train_set.seq_len, - ) - real_batch_size = seq_lens.shape[0] - outputs, _ = self.model(sequences.transpose(0, 1)) - stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] - retentions = power_forgetting_curve(delta_ts, stabilities) - train_loss = self.loss_fn(retentions, labels).mean() - if self.verbose: - tqdm.write(f"train loss: {train_loss:.6f}") - self.avg_train_losses.append(train_loss) - - sequences, delta_ts, labels, seq_lens = ( - self.test_set.x_train, - self.test_set.t_train, - self.test_set.y_train, - self.test_set.seq_len, - ) - real_batch_size = seq_lens.shape[0] - outputs, _ = self.model(sequences.transpose(0, 1)) - stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] - retentions = power_forgetting_curve(delta_ts, stabilities) - test_loss = self.loss_fn(retentions, labels).mean() - self.avg_eval_losses.append(test_loss) + losses = [] + for dataset in (self.train_set, self.test_set): + if len(dataset) == 0: + losses.append(0) + continue + sequences, delta_ts, labels, seq_lens = ( + dataset.x_train, + dataset.t_train, + dataset.y_train, + dataset.seq_len, + ) + real_batch_size = seq_lens.shape[0] + outputs, _ = self.model(sequences.transpose(0, 1)) + stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] + retentions = power_forgetting_curve(delta_ts, stabilities) + loss = self.loss_fn(retentions, labels).mean() + losses.append(loss) + self.avg_train_losses.append(losses[0]) + self.avg_eval_losses.append(losses[1]) w = list( map( @@ -369,7 +364,7 @@ def eval(self): ) weighted_loss = ( - train_loss * len(self.train_set) + test_loss * len(self.test_set) + losses[0] * len(self.train_set) + losses[1] * len(self.test_set) ) / (len(self.train_set) + len(self.test_set)) return weighted_loss, w