Expt/regularization of parameters (#157)
* add regularization of parameters

* add gamma for regularization

* L2 regularization

* apply L2 regularization based on init_w

* update default gamma

* update hyper-parameters & bump version
L-M-Sherlock authored Jan 10, 2025
1 parent df93cfe commit a2bbc6b
Showing 2 changed files with 43 additions and 4 deletions.
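The gist of the change: during training, the parameters w are now pulled back toward the initial parameters init_w by an L2 penalty, with each parameter's squared deviation divided by a per-parameter standard deviation and the whole term weighted by the new gamma hyper-parameter. A minimal standalone sketch of the idea (illustrative values only; the real tensors, gamma, and scaling come from the diff below):

import torch

w = torch.tensor([0.5, 1.2], requires_grad=True)  # current model parameters (illustrative)
init_w = torch.tensor([0.4, 1.0])                  # parameters the optimizer started from
stddev = torch.tensor([0.1, 0.5])                  # assumed per-parameter standard deviations
gamma = 2.0                                        # regularization weight (new hyper-parameter)

data_loss = torch.tensor(0.7)                      # stand-in for the summed review loss
real_batch_size, epoch_len = 512, 10_000           # stand-ins for the training-time scaling

penalty = torch.sum(torch.square(w - init_w) / torch.square(stddev))
loss = data_loss + penalty * gamma * real_batch_size / epoch_len
loss.backward()  # the penalty pulls w toward init_w, more strongly where stddev is small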
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.6.5"
version = "5.7.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
src/fsrs_optimizer/fsrs_optimizer.py (45 changes: 42 additions & 3 deletions)
@@ -67,6 +67,32 @@
S_MIN = 0.01


+DEFAULT_PARAMS_STDDEV_TENSOR = torch.tensor(
+    [
+        6.61,
+        9.52,
+        17.69,
+        27.74,
+        0.55,
+        0.28,
+        0.67,
+        0.12,
+        0.4,
+        0.18,
+        0.34,
+        0.27,
+        0.08,
+        0.14,
+        0.57,
+        0.25,
+        1.03,
+        0.27,
+        0.39,
+    ],
+    dtype=torch.float,
+)
+
+
class FSRS(nn.Module):
def __init__(self, w: List[float], float_delta_t: bool = False):
super(FSRS, self).__init__()
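Note on the new constant: DEFAULT_PARAMS_STDDEV_TENSOR holds one value per FSRS parameter (19 entries, matching w) and is only used below to scale the penalty, so each parameter is penalized relative to how widely it is presumed to vary rather than in absolute terms; the commit does not say how these values were obtained. As a worked example, a drift of 6.61 in w[0] (stddev 6.61) contributes (6.61 / 6.61)^2 = 1 to the penalty sum, while the same absolute drift in w[7] (stddev 0.12) would contribute (6.61 / 0.12)^2 ≈ 3034.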
@@ -299,9 +325,10 @@ def __init__(
train_set: pd.DataFrame,
test_set: Optional[pd.DataFrame],
init_w: List[float],
-n_epoch: int = 1,
-lr: float = 1e-2,
-batch_size: int = 256,
+n_epoch: int = 5,
+lr: float = 4e-2,
+gamma: float = 2,
+batch_size: int = 512,
max_seq_len: int = 64,
float_delta_t: bool = False,
enable_short_term: bool = True,
@@ -310,8 +337,10 @@ def __init__(
init_w[17] = 0
init_w[18] = 0
self.model = FSRS(init_w, float_delta_t)
+self.init_w_tensor = torch.tensor(init_w, dtype=torch.float)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
self.clipper = ParameterClipper()
+self.gamma = gamma
self.batch_size = batch_size
self.max_seq_len = max_seq_len
self.build_dataset(train_set, test_set)
@@ -362,6 +391,11 @@ def train(self, verbose: bool = True):
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
loss = (self.loss_fn(retentions, labels) * weights).sum()
+penalty = torch.sum(
+    torch.square(self.model.w - self.init_w_tensor)
+    / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
+)
+loss += penalty * self.gamma * real_batch_size / epoch_len
loss.backward()
if self.float_delta_t or not self.enable_short_term:
for param in self.model.parameters():
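Note on the training-time weighting: scaling the penalty by real_batch_size / epoch_len spreads it across the epoch. Treating the penalty as roughly constant within an epoch and summing the term over all batches gives

penalty * gamma * (sum of real_batch_size over batches) / epoch_len ≈ penalty * gamma

per epoch, assuming epoch_len counts the total number of training reviews so that the batch sizes sum to it.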
@@ -412,6 +446,11 @@ def eval(self):
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
loss = (self.loss_fn(retentions, labels) * weights).mean()
+penalty = torch.sum(
+    torch.square(self.model.w - self.init_w_tensor)
+    / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR)
+)
+loss += penalty * self.gamma / len(self.train_set.y_train)
losses.append(loss)
self.avg_train_losses.append(losses[0])
self.avg_eval_losses.append(losses[1])
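The eval() counterpart mirrors this at per-review scale: the review loss there is a mean, and the penalty is divided by len(self.train_set.y_train), so the reported value is approximately the per-review data loss plus the gamma-weighted penalty spread over the full training set. Train and eval therefore apply the same effective regularization strength, just reported on different scales (summed per batch vs. averaged per review).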