Feat/extrapolate values of s0 (#20)
L-M-Sherlock authored Aug 10, 2023
1 parent 3f944d9 commit e326562
Showing 2 changed files with 53 additions and 61 deletions.
pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "4.8.0"
+version = "4.9.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
src/fsrs_optimizer/fsrs_optimizer.py: 112 changes (52 additions & 60 deletions)
@@ -392,6 +392,8 @@ def create_time_series(self, timezone: str, revlog_start_date: str, next_day_sta
         df['i'] = df.groupby('card_id').cumcount() + 1
         df.loc[df['i'] == 1, 'delta_t'] = 0
         df = df.groupby('card_id').filter(lambda group: group['review_state'].iloc[0] == Learning)
+        if df.empty:
+            raise ValueError('Training data is inadequate.')
         df['prev_review_state'] = df.groupby('card_id')['review_state'].shift(1).fillna(Learning).astype(int)
         df['helper'] = ((df['review_state'] == Learning) & ((df['prev_review_state'] == Review) | (df['prev_review_state'] == Relearning)) & (df['i'] > 1)).astype(int)
         df['helper'] = df.groupby('card_id')['helper'].cumsum()
@@ -467,6 +469,8 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
             return group
 
         df = df.groupby(by=['r_history'], group_keys=False).progress_apply(cal_stability)
+        if df.empty:
+            return "No enough data for stability calculation."
         tqdm.write("Stability calculated.")
         df.reset_index(drop = True, inplace = True)
         df.drop_duplicates(inplace=True)
@@ -505,7 +509,7 @@ def pretrain(self, verbose=True):
         average_recall = self.dataset['y'].mean()
         plots = []
         s0_size = self.S0_dataset_group.shape[0]
-        rating_s0 = {
+        r_s0_default = {
             "1": 0.4,
             "2": 0.6,
             "3": 2.4,
@@ -522,7 +526,7 @@ def pretrain(self, verbose=True):
             count = group['y']['count']
             total_count = sum(count)
 
-            init_s0 = rating_s0[first_rating]
+            init_s0 = r_s0_default[first_rating]
 
             def loss(stability):
                 y_pred = power_forgetting_curve(delta_t, stability)
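For orientation (not part of the diff): power_forgetting_curve is FSRS's power-law retention curve; in the v4-era optimizer it is assumed here to be R(t, S) = (1 + t / (9 * S)) ** -1, so predicted retention is 0.9 when the elapsed time t equals the stability S. The loss above presumably measures how well a candidate stability's predicted retention matches the observed first-review recall. A minimal standalone sketch of such a fit, with made-up data and illustrative bounds:

import numpy as np
from scipy.optimize import minimize

# Assumed FSRS v4 retention curve: R(t, S) = (1 + t / (9 * S)) ** -1 (R = 0.9 at t = S)
def power_forgetting_curve(delta_t, stability):
    return (1 + delta_t / (9 * stability)) ** -1

# Made-up first-review data: elapsed days, observed recall, and review counts
delta_t = np.array([1, 3, 7, 15])
recall = np.array([0.95, 0.90, 0.80, 0.65])
count = np.array([100, 80, 60, 40])

def loss(stability):
    # Count-weighted log loss between predicted retention and observed recall
    y_pred = power_forgetting_curve(delta_t, stability[0])
    eps = 1e-6
    logloss = -(recall * np.log(y_pred + eps) + (1 - recall) * np.log(1 - y_pred + eps))
    return float(np.sum(logloss * count))

# Illustrative starting point and bounds, not the repository's exact settings
res = minimize(loss, x0=[2.4], bounds=((0.1, 365),))
print(round(res.x[0], 2))  # best-fit initial stability for this made-up data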
@@ -554,70 +558,58 @@ def loss(stability):
                 plots.append(fig)
             tqdm.write(str(rating_stability))
 
-            for a, b in ((1, 2), (2, 3), (3, 4)):
-                if a in rating_stability and b in rating_stability:
-                    if rating_stability[a] > rating_stability[b]:
-                        if rating_count[a] > rating_count[b]:
-                            rating_stability[b] = rating_stability[a]
+            for small_rating, big_rating in ((1, 2), (2, 3), (3, 4), (1, 3), (2, 4), (1, 4)):
+                if small_rating in rating_stability and big_rating in rating_stability:
+                    if rating_stability[small_rating] > rating_stability[big_rating]:
+                        if rating_count[small_rating] > rating_count[big_rating]:
+                            rating_stability[big_rating] = rating_stability[small_rating]
                         else:
-                            rating_stability[a] = rating_stability[b]
+                            rating_stability[small_rating] = rating_stability[big_rating]
+
+        w1 = 3/5
+        w2 = 3/5
 
         if len(rating_stability) == 0:
             raise Exception("Not enough data for pretraining!")
         elif len(rating_stability) == 1:
-            init_stability = round(list(rating_stability.values())[0], 2)
-            self.init_w[0:4] = [init_stability] * 4
-        elif len(rating_stability) == 4:
-            for rating, stability in rating_stability.items():
-                self.init_w[rating-1] = round(stability, 2)
-            tqdm.write(f"Pretrain finished!")
-            return plots
-
-        def S0_rating_curve(rating, a, b, c):
-            return np.exp(a + b * rating) + c
-
-        params, covs = curve_fit(S0_rating_curve, list(rating_stability.keys()), list(rating_stability.values()), sigma=1/np.sqrt(list(rating_count.values())), method='dogbox', bounds=((-15, 0.03, -5), (15, 7, 30)))
-        if verbose:
-            tqdm.write(f'Weighted fit parameters: {params}')
-            predict_stability = S0_rating_curve(np.array(list(rating_stability.keys())), *params)
-            tqdm.write(f"Fit stability: {predict_stability}")
-            tqdm.write(f'RMSE: {mean_squared_error(list(rating_stability.values()), predict_stability, sample_weight=list(rating_count.values()), squared=False):.4f}')
-            fig = plt.figure()
-            ax = fig.gca()
-            ax.plot(list(rating_stability.keys()), list(rating_stability.values()), label='Exact')
-            ax.plot(np.linspace(1, 4), S0_rating_curve(np.linspace(1, 4), *params), label='Weighted fit')
-            scatter_size = np.array([x/sum(rating_count.values()) for x in rating_count.values()]) * 1000
-            ax.scatter(list(rating_stability.keys()), list(rating_stability.values()), scatter_size, label='Exact', alpha=0.5)
-            ax.legend(loc='upper right', fancybox=True, shadow=False)
-            ax.grid(True)
-            ax.set_xlabel('First rating')
-            ax.set_ylabel('Stability')
-            ax.set_title('Stability for first rating')
-            plots.append(fig)
-
-        for rating in (1, 2, 3, 4):
-            again_extrap = max(min(S0_rating_curve(1, *params), 30), 0.1)
-            # if there isn't enough data to calculate the value for "Again" exactly
+            rating = list(rating_stability.keys())[0]
+            factor = rating_stability[rating] / r_s0_default[str(rating)]
+            init_s0 = list(map(lambda x: x * factor, r_s0_default.values()))
+        elif len(rating_stability) == 2:
+            if 1 not in rating_stability and 2 not in rating_stability:
+                rating_stability[2] = np.power(rating_stability[3], 1/(1-w2)) * np.power(rating_stability[4], 1-1/(1-w2))
+                rating_stability[1] = np.power(rating_stability[2], 1/w1) * np.power(rating_stability[3], 1-1/w1)
+            elif 1 not in rating_stability and 3 not in rating_stability:
+                rating_stability[3] = np.power(rating_stability[2], 1-w2) * np.power(rating_stability[4], w2)
+                rating_stability[1] = np.power(rating_stability[2], 1/w1) * np.power(rating_stability[3], 1-1/w1)
+            elif 1 not in rating_stability and 4 not in rating_stability:
+                rating_stability[4] = np.power(rating_stability[2], 1-1/w2) * np.power(rating_stability[3], 1/w2)
+                rating_stability[1] = np.power(rating_stability[2], 1/w1) * np.power(rating_stability[3], 1-1/w1)
+            elif 2 not in rating_stability and 3 not in rating_stability:
+                rating_stability[2] = np.power(rating_stability[1], w1/(w1+w2-w1*w2)) * np.power(rating_stability[4], 1 - w1/(w1+w2-w1*w2))
+                rating_stability[3] = np.power(rating_stability[1], 1 - w2/(w1+w2-w1*w2)) * np.power(rating_stability[4], w2/(w1+w2-w1*w2))
+            elif 2 not in rating_stability and 4 not in rating_stability:
+                rating_stability[2] = np.power(rating_stability[1], w1) * np.power(rating_stability[3], 1-w1)
+                rating_stability[4] = np.power(rating_stability[2], 1-1/w2) * np.power(rating_stability[3], 1/w2)
+            elif 3 not in rating_stability and 4 not in rating_stability:
+                rating_stability[3] = np.power(rating_stability[1], 1-1/(1-w1)) * np.power(rating_stability[2], 1/(1-w1))
+                rating_stability[4] = np.power(rating_stability[2], 1-1/w2) * np.power(rating_stability[3], 1/w2)
+            init_s0 = [item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0])]
+        elif len(rating_stability) == 3:
             if 1 not in rating_stability:
-                # then check if there exists an exact value for "Hard"
-                if 2 in rating_stability:
-                    # if it exists, then check whether the extrapolation breaks monotonicity
-                    # Again > Hard is possible, but we should allow it only for exact values, otherwise we should assume monotonicity
-                    if again_extrap > rating_stability[2]:
-                        # if it does, then replace the missing "Again" value with the exact "Hard" value
-                        rating_stability[1] = rating_stability[2]
-                    else:
-                        # if it doesn't break monotonicity, then use the extrapolated value
-                        rating_stability[1] = again_extrap
-                # if an exact value for "Hard" doesn't exist, then just use the extrapolation, there's nothing else we can do
-                else:
-                    rating_stability[1] = again_extrap
-            elif rating not in rating_stability:
-                rating_stability[rating] = max(min(S0_rating_curve(rating, *params), 30), 0.1)
-
-        rating_stability = {k: round(v, 2) for k, v in sorted(rating_stability.items(), key=lambda item: item[0])}
-        for rating, stability in rating_stability.items():
-            self.init_w[rating-1] = stability
+                rating_stability[1] = np.power(rating_stability[2], 1/w1) * np.power(rating_stability[3], 1-1/w1)
+            elif 2 not in rating_stability:
+                rating_stability[2] = np.power(rating_stability[1], w1) * np.power(rating_stability[3], 1-w1)
+            elif 3 not in rating_stability:
+                rating_stability[3] = np.power(rating_stability[2], 1-w2) * np.power(rating_stability[4], w2)
+            elif 4 not in rating_stability:
+                rating_stability[4] = np.power(rating_stability[2], 1-1/w2) * np.power(rating_stability[3], 1/w2)
+            init_s0 = [item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0])]
+        elif len(rating_stability) == 4:
+            init_s0 = [item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0])]
+
+        self.init_w[0:4] = init_s0
 
         tqdm.write(f"Pretrain finished!")
         return plots
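Note on the extrapolation added above (not part of the commit): the new branches all follow from treating first-rating stability as log-linear in the rating, i.e. S2 = S1**w1 * S3**(1-w1) and S3 = S2**(1-w2) * S4**w2 with w1 = w2 = 3/5, and solving those two identities for whichever ratings are missing; the single-rating branch instead rescales the r_s0_default values by the observed/default ratio. A minimal standalone sketch of both cases, with made-up stability numbers and an assumed value for the "4" entry of r_s0_default (that line is collapsed in the diff):

import numpy as np

# Interpolation weights used by the new branches above
w1 = 3/5
w2 = 3/5

# Made-up stabilities for "Again" (1), "Good" (3) and "Easy" (4)
s1, s3, s4 = 0.5, 2.5, 6.0

# Identity behind the formulas: "Hard" is a log-linear blend of "Again" and "Good"
s2 = np.power(s1, w1) * np.power(s3, 1 - w1)

# The "1 not in rating_stability" branch is that identity solved for S1, so it round-trips
s1_recovered = np.power(s2, 1/w1) * np.power(s3, 1 - 1/w1)
assert np.isclose(s1_recovered, s1)

# Likewise "Good" blends "Hard" and "Easy"; the "4 not in rating_stability" branch inverts it
s3_blend = np.power(s2, 1 - w2) * np.power(s4, w2)
s4_recovered = np.power(s2, 1 - 1/w2) * np.power(s3_blend, 1/w2)
assert np.isclose(s4_recovered, s4)

# Single known rating: scale the defaults by the observed/default ratio for that rating
r_s0_default = {"1": 0.4, "2": 0.6, "3": 2.4, "4": 5.8}  # the "4" entry is assumed here
observed_rating, observed_s0 = 3, 3.0  # made-up measurement for "Good"
factor = observed_s0 / r_s0_default[str(observed_rating)]
init_s0 = [round(v * factor, 2) for v in r_s0_default.values()]
print(init_s0)  # [0.5, 0.75, 3.0, 7.25]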
