Skip to content

Commit

Permalink
Fix/export remove outliers (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jan 5, 2024
1 parent eb6a812 commit 9519811
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.20.6"
version = "4.20.7"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
87 changes: 42 additions & 45 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,48 @@ def batch_predict(self, dataset):
return stabilities.tolist(), difficulties.tolist()


"""Used to store all the results from FSRS related functions"""
def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
# threshold = np.mean(group['delta_t']) * 1.5
# threshold = group['delta_t'].quantile(0.95)
# Q1 = group['delta_t'].quantile(0.25)
# Q3 = group['delta_t'].quantile(0.75)
# IQR = Q3 - Q1
# threshold = Q3 + 1.5 * IQR
# group = group[group['delta_t'] <= threshold]
grouped_group = (
group.groupby(by=["r_history", "delta_t"], group_keys=False)
.agg({"y": ["mean", "count"]})
.reset_index()
)
sort_index = grouped_group.sort_values(
by=[("y", "count"), "delta_t"], ascending=[True, False]
).index

total = sum(grouped_group[("y", "count")])
if total <= 20:
return pd.DataFrame()
has_been_removed = 0
for i in sort_index:
count = grouped_group.loc[i, ("y", "count")]
if has_been_removed + count >= max(total * 0.05, 20):
break
has_been_removed += count
group = group[
group["delta_t"].isin(
grouped_group[(grouped_group[("y", "count")] >= max(count, 6))]["delta_t"]
)
& (group["delta_t"] <= (100 if group.name[0] != "4" else 365))
]
return group


def remove_non_continuous_rows(group):
discontinuity = group["i"].diff().fillna(1).ne(1)
if not discontinuity.any():
return group
else:
first_non_continuous_index = discontinuity.idxmax()
return group.loc[: first_non_continuous_index - 1]


class Optimizer:
Expand Down Expand Up @@ -629,57 +670,13 @@ def cum_concat(x):
].copy()
df["y"] = df["review_rating"].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])

def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
# threshold = np.mean(group['delta_t']) * 1.5
# threshold = group['delta_t'].quantile(0.95)
# Q1 = group['delta_t'].quantile(0.25)
# Q3 = group['delta_t'].quantile(0.75)
# IQR = Q3 - Q1
# threshold = Q3 + 1.5 * IQR
# group = group[group['delta_t'] <= threshold]
grouped_group = (
group.groupby(by=["r_history", "delta_t"], group_keys=False)
.agg({"y": ["mean", "count"]})
.reset_index()
)
sort_index = grouped_group.sort_values(
by=[("y", "count"), "delta_t"], ascending=[True, False]
).index

total = sum(grouped_group[("y", "count")])
if total <= 20:
return pd.DataFrame()
has_been_removed = 0
for i in sort_index:
count = grouped_group.loc[i, ("y", "count")]
if has_been_removed + count >= max(total * 0.05, 20):
break
has_been_removed += count
group = group[
group["delta_t"].isin(
grouped_group[(grouped_group[("y", "count")] >= max(count, 6))][
"delta_t"
]
)
& (group["delta_t"] <= (100 if group.name[0] != "4" else 365))
]
return group

df[df["i"] == 2] = (
df[df["i"] == 2]
.groupby(by=["r_history", "t_history"], as_index=False, group_keys=False)
.apply(remove_outliers)
)
df.dropna(inplace=True)

def remove_non_continuous_rows(group):
discontinuity = group["i"].diff().fillna(1).ne(1)
if not discontinuity.any():
return group
else:
first_non_continuous_index = discontinuity.idxmax()
return group.loc[: first_non_continuous_index - 1]

df = df.groupby("card_id", as_index=False, group_keys=False).progress_apply(
remove_non_continuous_rows
)
Expand Down

0 comments on commit 9519811

Please sign in to comment.