Commit 4032352

Feat/linear damping && w[7] >= 0.001 (#143)
* Feat/linear damping
* update default parameters
* update formula of simulator
* add more tests
* add test for loss and grad
* update ParameterClipper
* bump version
1 parent 6f1d4a9 commit 4032352

File tree

5 files changed: +181 −53 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.2.1"
+version = "5.2.2"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
src/fsrs_optimizer/fsrs_optimizer.py

Lines changed: 36 additions & 49 deletions
@@ -5,7 +5,7 @@
 import numpy as np
 import os
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from datetime import timedelta, datetime
 from collections import defaultdict
 import statsmodels.api as sm  # type: ignore
@@ -42,25 +42,25 @@
     Relearning = 3
 
 DEFAULT_PARAMETER = [
-    0.4072,
-    1.1829,
-    3.1262,
-    15.4722,
-    7.2102,
-    0.5316,
-    1.0651,
-    0.0234,
-    1.616,
-    0.1544,
-    1.0824,
-    1.9813,
-    0.0953,
-    0.2975,
-    2.2042,
-    0.2407,
-    2.9466,
-    0.5034,
-    0.6567,
+    0.40255,
+    1.18385,
+    3.173,
+    15.69105,
+    7.1949,
+    0.5345,
+    1.4604,
+    0.0046,
+    1.54575,
+    0.1192,
+    1.01925,
+    1.9395,
+    0.11,
+    0.29605,
+    2.2698,
+    0.2315,
+    2.9898,
+    0.51655,
+    0.6621,
 ]
 
 S_MIN = 0.01
@@ -105,8 +105,12 @@ def init_d(self, rating: Tensor) -> Tensor:
         new_d = self.w[4] - torch.exp(self.w[5] * (rating - 1)) + 1
         return new_d
 
+    def linear_damping(self, delta_d: Tensor, old_d: Tensor) -> Tensor:
+        return delta_d * (10 - old_d) / 9
+
     def next_d(self, state: Tensor, rating: Tensor) -> Tensor:
-        new_d = state[:, 1] - self.w[6] * (rating - 3)
+        delta_d = -self.w[6] * (rating - 3)
+        new_d = state[:, 1] + self.linear_damping(delta_d, state[:, 1])
         new_d = self.mean_reversion(self.init_d(4), new_d)
         return new_d
 
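A quick hand-worked sketch of what the damping does (values computed from the formula above with the new default w[6] = 1.4604; the mean_reversion step that follows in next_d is omitted). The difficulty shift keeps its full size at d = 1 and vanishes as d approaches the ceiling of 10:

    import torch

    w6 = 1.4604  # new default value of w[6]
    old_d = torch.tensor([1.0, 5.0, 9.0])
    delta_d = -w6 * (1 - 3)  # rating = 1 ("Again") raises difficulty by 2 * w6

    # before this commit: a constant shift, relying on downstream clipping at 10
    print(old_d + delta_d)                     # tensor([ 3.9208,  7.9208, 11.9208])
    # after this commit: the shift is scaled by (10 - d) / 9
    print(old_d + delta_d * (10 - old_d) / 9)  # tensor([3.9208, 6.6227, 9.3245])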
@@ -151,7 +155,9 @@ def step(self, X: Tensor, state: Tensor) -> Tensor:
         new_s = new_s.clamp(S_MIN, 36500)
         return torch.stack([new_s, new_d], dim=1)
 
-    def forward(self, inputs: Tensor, state: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self, inputs: Tensor, state: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
         """
         :param inputs: shape[seq_len, batch_size, 2]
         """
@@ -179,16 +185,16 @@ def __call__(self, module):
             w[2] = w[2].clamp(S_MIN, 100)
             w[3] = w[3].clamp(S_MIN, 100)
             w[4] = w[4].clamp(1, 10)
-            w[5] = w[5].clamp(0.01, 4)
-            w[6] = w[6].clamp(0.01, 4)
-            w[7] = w[7].clamp(0, 0.75)
+            w[5] = w[5].clamp(0.001, 4)
+            w[6] = w[6].clamp(0.001, 4)
+            w[7] = w[7].clamp(0.001, 0.75)
             w[8] = w[8].clamp(0, 4.5)
             w[9] = w[9].clamp(0, 0.8)
-            w[10] = w[10].clamp(0.01, 3.5)
-            w[11] = w[11].clamp(0.1, 5)
-            w[12] = w[12].clamp(0.01, 0.25)
-            w[13] = w[13].clamp(0.01, 0.9)
-            w[14] = w[14].clamp(0.01, 4)
+            w[10] = w[10].clamp(0.001, 3.5)
+            w[11] = w[11].clamp(0.001, 5)
+            w[12] = w[12].clamp(0.001, 0.25)
+            w[13] = w[13].clamp(0.001, 0.9)
+            w[14] = w[14].clamp(0, 4)
             w[15] = w[15].clamp(0, 1)
             w[16] = w[16].clamp(1, 6)
             w[17] = w[17].clamp(0, 2)
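Given the __call__(self, module) signature, the clipper is presumably applied with nn.Module.apply after each optimizer step, the conventional PyTorch pattern; a minimal sketch of that usage (the constructor and training-loop details here are assumptions, not the repo's actual loop):

    model = FSRS(DEFAULT_PARAMETER)
    clipper = ParameterClipper()  # assuming a no-argument constructor
    # ... loss.backward(); optimizer.step() ...
    model.apply(clipper)        # clamps every w[i] back into its allowed range
    assert model.w[7] >= 0.001  # the new lower bound that names this PR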
@@ -2075,22 +2081,3 @@ def wrap_short_term_ratings(r_history, t_history):
         else:
             result.pop()
     return "".join(result)
-
-
-if __name__ == "__main__":
-    model = FSRS(DEFAULT_PARAMETER)
-    stability = torch.tensor([5.0] * 4)
-    difficulty = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    retention = torch.tensor([0.9, 0.8, 0.7, 0.6])
-    rating = torch.tensor([1, 2, 3, 4])
-    state = torch.stack([stability, difficulty]).unsqueeze(0)
-    s_recall = model.stability_after_success(state, retention, rating)
-    print(s_recall)
-    s_forget = model.stability_after_failure(state, retention)
-    print(s_forget)
-
-    retentions = torch.tensor([0.1, 0.2, 0.3, 0.4])
-    labels = torch.tensor([0.0, 1.0, 0.0, 1.0])
-    loss_fn = nn.BCELoss()
-    loss = loss_fn(retentions, labels)
-    print(loss)

src/fsrs_optimizer/fsrs_simulator.py

Lines changed: 5 additions & 1 deletion
@@ -127,8 +127,12 @@ def init_d_with_short_term(rating):
         new_d = init_d(rating) - w[6] * rating_offset
         return np.clip(new_d, 1, 10)
 
+    def linear_damping(delta_d, old_d):
+        return delta_d * (10 - old_d) / 9
+
     def next_d(d, rating):
-        new_d = d - w[6] * (rating - 3)
+        delta_d = -w[6] * (rating - 3)
+        new_d = d + linear_damping(delta_d, d)
         new_d = mean_reversion(init_d(4), new_d)
         return np.clip(new_d, 1, 10)
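This mirrors the optimizer's change above, keeping the two code paths on the same formula. Expanding the two new lines gives the single closed form new_d = d - w[6] * (rating - 3) * (10 - d) / 9, i.e. the pre-commit update scaled by (10 - d) / 9; a quick numpy sanity check of that equivalence:

    import numpy as np

    w6, d, rating = 1.4604, np.array([1.0, 5.0, 9.0]), 1
    damped = d + (-w6 * (rating - 3)) * (10 - d) / 9  # as written in this commit
    closed = d - w6 * (rating - 3) * (10 - d) / 9     # algebraically identical
    assert np.allclose(damped, closed)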

tests/model_test.py

Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
+from src.fsrs_optimizer import *
+
+
+class Test_Model:
+    def test_next_stability(self):
+        model = FSRS(DEFAULT_PARAMETER)
+        stability = torch.tensor([5.0] * 4)
+        difficulty = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        retention = torch.tensor([0.9, 0.8, 0.7, 0.6])
+        rating = torch.tensor([1, 2, 3, 4])
+        state = torch.stack([stability, difficulty]).unsqueeze(0)
+        s_recall = model.stability_after_success(state, retention, rating)
+        assert torch.allclose(
+            s_recall, torch.tensor([25.7761, 14.1219, 60.4044, 208.9760]), atol=1e-4
+        )
+        s_forget = model.stability_after_failure(state, retention)
+        assert torch.allclose(
+            s_forget, torch.tensor([1.7029, 1.9799, 2.3760, 2.8885]), atol=1e-4
+        )
+        s_short_term = model.stability_short_term(state, rating)
+        assert torch.allclose(
+            s_short_term, torch.tensor([2.5051, 4.1992, 7.0389, 11.7988]), atol=1e-4
+        )
+
+    def test_next_difficulty(self):
+        model = FSRS(DEFAULT_PARAMETER)
+        stability = torch.tensor([5.0] * 4)
+        difficulty = torch.tensor([5.0] * 4)
+        rating = torch.tensor([1, 2, 3, 4])
+        state = torch.stack([stability, difficulty]).unsqueeze(0)
+        d_recall = model.next_d(state, rating)
+        assert torch.allclose(
+            d_recall,
+            torch.tensor([6.6070, 5.7994, 4.9918, 4.1842]),
+            atol=1e-4,
+        )
+
+    def test_power_forgetting_curve(self):
+        delta_t = torch.tensor([0, 1, 2, 3, 4, 5])
+        stability = torch.tensor([1, 2, 3, 4, 4, 2])
+        retention = power_forgetting_curve(delta_t, stability)
+        assert torch.allclose(
+            retention,
+            torch.tensor([1.0, 0.946059, 0.9299294, 0.9221679, 0.90000004, 0.79394597]),
+            atol=1e-4,
+        )
+
+    def test_forward(self):
+        model = FSRS(DEFAULT_PARAMETER)
+        delta_ts = torch.tensor(
+            [
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [1.0, 1.0, 1.0, 1.0, 2.0, 2.0],
+            ]
+        )
+        ratings = torch.tensor(
+            [
+                [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
+                [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
+            ]
+        )
+        inputs = torch.stack([delta_ts, ratings], dim=2)
+        _, state = model.forward(inputs)
+        stability = state[:, 0]
+        difficulty = state[:, 1]
+        assert torch.allclose(
+            stability,
+            torch.tensor([0.2619, 1.7073, 5.8691, 25.0123, 0.3403, 2.1482]),
+            atol=1e-4,
+        )
+        assert torch.allclose(
+            difficulty,
+            torch.tensor([8.0827, 7.0405, 5.2729, 2.1301, 8.0827, 7.0405]),
+            atol=1e-4,
+        )
+
+    def test_loss_and_grad(self):
+        model = FSRS(DEFAULT_PARAMETER)
+        loss_fn = nn.BCELoss(reduction="none")
+        t_histories = torch.tensor(
+            [
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+                [0.0, 1.0, 1.0, 3.0],
+                [1.0, 3.0, 3.0, 5.0],
+                [3.0, 6.0, 6.0, 12.0],
+            ]
+        )
+        r_histories = torch.tensor(
+            [
+                [1.0, 2.0, 3.0, 4.0],
+                [3.0, 4.0, 2.0, 4.0],
+                [1.0, 4.0, 4.0, 3.0],
+                [4.0, 3.0, 3.0, 3.0],
+                [3.0, 1.0, 3.0, 3.0],
+                [2.0, 3.0, 3.0, 4.0],
+            ]
+        )
+        delta_ts = torch.tensor([4.0, 11.0, 12.0, 23.0])
+        labels = torch.tensor([1, 1, 1, 0], dtype=torch.float32, requires_grad=False)
+        inputs = torch.stack([t_histories, r_histories], dim=2)
+        seq_lens = inputs.shape[0]
+        real_batch_size = inputs.shape[1]
+        outputs, _ = model.forward(inputs)
+        stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
+        retentions = power_forgetting_curve(delta_ts, stabilities)
+        loss = loss_fn(retentions, labels).sum()
+        assert round(loss.item(), 4) == 4.4467
+        loss.backward()
+        assert torch.allclose(
+            model.w.grad,
+            torch.tensor(
+                [
+                    -0.0583,
+                    -0.0068,
+                    -0.0026,
+                    0.0105,
+                    -0.0513,
+                    1.3643,
+                    0.0837,
+                    -0.9502,
+                    0.5345,
+                    -2.8929,
+                    0.5142,
+                    -0.0131,
+                    0.0419,
+                    -0.1183,
+                    -0.0009,
+                    -0.1445,
+                    0.2024,
+                    0.2141,
+                    0.0323,
+                ]
+            ),
+            atol=1e-4,
+        )
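A note on test_power_forgetting_curve: the expected values are consistent with the FSRS curve R(t, S) = (1 + 19/81 * t / S) ** -0.5, under which R(S, S) = (100/81) ** -0.5 = 0.9 exactly, hence the 0.90000004 at delta_t = stability = 4. A minimal sketch assuming that definition (the repo may spell the constants differently):

    import torch

    FACTOR, DECAY = 19 / 81, -0.5  # assumed constants, inferred from the expected values

    def power_forgetting_curve(delta_t, stability):
        return (1 + FACTOR * delta_t / stability) ** DECAY

    print(power_forgetting_curve(torch.tensor([4.0]), torch.tensor([4.0])))  # tensor([0.9000])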

tests/simulator_test.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ def test_simulate(self):
             memorized_cnt_per_day,
             cost_per_day,
         ) = simulate(w=DEFAULT_PARAMETER, request_retention=0.9)
-        assert memorized_cnt_per_day[-1] == 5918.574208243532
+        assert memorized_cnt_per_day[-1] == 5875.025236206539
 
     def test_optimal_retention(self):
         default_params = {
@@ -24,4 +24,4 @@ def test_optimal_retention(self):
             "loss_aversion": 2.5,
         }
         r = optimal_retention(**default_params)
-        assert r == 0.8346739534878145
+        assert r == 0.8263932
