
Commit 4881e21

improve code
1 parent 249cdfc commit 4881e21

5 files changed: +113 -91 lines changed


src/mypackage/baselines/baselines.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def _common(stats_per_item: DataFrame, val: DataFrame):
     val = val.merge(stats_per_item, "inner", "item")
 
     # Compute aucroc and ndcg for different scenarios
-    indexes = torch.tensor(val["user"])
+    indexes = torch.tensor(val["list_id"])
     targets = torch.tensor(val["target"])
     worst_preds = torch.tensor((val["target"] + 1) % 2, dtype=torch.float32)
     random_preds = torch.rand(val.shape[0])
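
Switching indexes from user to list_id changes the grouping key for the retrieval metrics, so AUROC/NDCG are now averaged per recommendation list instead of per user. The metric objects live outside this hunk; the sketch below is a minimal, hypothetical reconstruction of such grouped evaluation, assuming torchmetrics' retrieval API (RetrievalNormalizedDCG is an assumption, not shown in the diff):

import torch
from torchmetrics.retrieval import RetrievalNormalizedDCG

# Toy data: five rows grouped into two lists (ids 0 and 1), binary click targets.
indexes = torch.tensor([0, 0, 0, 1, 1])
targets = torch.tensor([1, 0, 0, 0, 1])
random_preds = torch.rand(5)  # random baseline, as in the hunk above

# The metric is computed within each group defined by `indexes`, then
# averaged over groups, so grouping by list_id scores each displayed list.
ndcg = RetrievalNormalizedDCG()
print(ndcg(random_preds, targets, indexes=indexes))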

src/mypackage/data_processing/data_processing.py

Lines changed: 53 additions & 41 deletions
@@ -47,32 +47,17 @@ def process_simple(
     train = interactions[interactions["timestamp"] < split_date]
     val = interactions[interactions["timestamp"] >= split_date]
 
-    # Keep train users with condition 0 < mean_target < 1
-    train_active_users = (
-        train.groupby("user")
+    # Keep lists with condition 0 < mean_target < 1
+    train_valid_lists = (
+        train.groupby("list_id")
         .agg({"target": "mean"})
         .rename(columns={"target": "mean"})
         .reset_index()
     )
-    train_active_users = train_active_users[
-        (train_active_users["mean"] > 0) & (train_active_users["mean"] < 1)
+    train_valid_lists = train_valid_lists[
+        (train_valid_lists["mean"] > 0) & (train_valid_lists["mean"] < 1)
     ]
-    train = train.merge(train_active_users, "inner", "user")
-
-    # Limit val users to these who occured in train data and also keep val users with 0 < mean_target < 1
-    # This way val users is a subset of train users (no cold users in val!)
-    # and we have train users with number of clicks at least 1 and val users with number of clicks at least 2
-    # val = val.merge(train_active_users, "inner", "user")
-    val_active_users = (
-        val.groupby("user")
-        .agg({"target": "mean"})
-        .rename(columns={"target": "mean"})
-        .reset_index()
-    )
-    val_active_users = val_active_users[
-        (val_active_users["mean"] > 0) & (val_active_users["mean"] < 1)
-    ]
-    val = val.merge(val_active_users, "inner", "user")
+    train = train.merge(train_valid_lists, "inner", "list_id")
 
     # Prepare user/item to idx mappers based on train data
     unique_train_users = train["user"].unique()
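
The 0 < mean_target < 1 filter guarantees that every kept list contains at least one click and at least one non-click; ranking metrics such as AUROC are undefined for lists whose targets are all 0 or all 1. A toy illustration of the same groupby/filter/merge pattern (hypothetical data, not from the repo):

import pandas as pd

train = pd.DataFrame({
    "list_id": [0, 0, 1, 1, 2, 2],
    "target":  [1, 0, 0, 0, 1, 1],
})
stats = (
    train.groupby("list_id")
    .agg({"target": "mean"})
    .rename(columns={"target": "mean"})
    .reset_index()
)
# list 0 (mean 0.5) survives; list 1 (all 0) and list 2 (all 1) are dropped
valid = stats[(stats["mean"] > 0) & (stats["mean"] < 1)]
train = train.merge(valid, "inner", "list_id")
print(train)
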
@@ -84,20 +69,32 @@ def process_simple(
         {"item": unique_train_items, "item_idx": np.arange(unique_train_items.size)}
     )
 
-    # Map user/item to idx
+    # Map user/item to idx - it removes cold users and items from validation
     train = train.merge(user_mapper, on="user", how="inner")
     train = train.merge(item_mapper, on="item", how="inner")
     val = val.merge(user_mapper, on="user", how="inner")
     val = val.merge(item_mapper, on="item", how="inner")
 
+    # Keep lists with condition 0 < mean_target < 1
+    val_valid_lists = (
+        val.groupby("list_id")
+        .agg({"target": "mean"})
+        .rename(columns={"target": "mean"})
+        .reset_index()
+    )
+    val_valid_lists = val_valid_lists[
+        (val_valid_lists["mean"] > 0) & (val_valid_lists["mean"] < 1)
+    ]
+    val = val.merge(val_valid_lists, "inner", "list_id")
+
     train = train.sort_values("timestamp").reset_index(drop=True)
     val = val.sort_values("timestamp").reset_index(drop=True)
 
     # Select valid columns
-    train = train[["timestamp", "user_idx", "item_idx", "target"]]
-    train.columns = ["timestamp", "user", "item", "target"]
-    val = val[["timestamp", "user_idx", "item_idx", "target"]]
-    val.columns = ["timestamp", "user", "item", "target"]
+    train = train[["timestamp", "list_id", "user_idx", "item_idx", "target"]]
+    train.columns = ["timestamp", "list_id", "user", "item", "target"]
+    val = val[["timestamp", "list_id", "user_idx", "item_idx", "target"]]
+    val.columns = ["timestamp", "list_id", "user", "item", "target"]
 
     # Mock test_data
     test = val.copy()  # test set == validation set (should be changed in the future!)
@@ -108,11 +105,13 @@ def process_simple(
     stats = {}
     stats["train_n_users"] = unique_train_users.size
     stats["train_n_items"] = unique_train_items.size
+    stats["train_n_lists"] = train["list_id"].nunique()
     stats["train_n_clicks"] = int(train["target"].sum())
     stats["train_n_impressions"] = len(train) - stats["train_n_clicks"]
     stats["train_ctr"] = stats["train_n_clicks"] / stats["train_n_impressions"]
     stats["val_n_users"] = unique_val_users.size
     stats["val_n_items"] = unique_val_items.size
+    stats["val_n_lists"] = val["list_id"].nunique()
     stats["val_n_clicks"] = int(val["target"].sum())
     stats["val_n_impressions"] = len(val) - stats["val_n_clicks"]
     stats["val_ctr"] = stats["val_n_clicks"] / stats["val_n_impressions"]
@@ -134,17 +133,6 @@ def process_bpr(
     train = tmp0.merge(tmp1, "inner", "user", suffixes=("_neg", "_pos"))
     val = interactions[interactions["timestamp"] >= split_date]
 
-    val_active_users = (
-        val.groupby("user")
-        .agg({"target": "mean"})
-        .rename(columns={"target": "mean"})
-        .reset_index()
-    )
-    val_active_users = val_active_users[
-        (val_active_users["mean"] > 0) & (val_active_users["mean"] < 1)
-    ]
-    val = val.merge(val_active_users, "inner", "user")
-
     # Prepare user/item to idx mappers based on train data
     unique_train_users = train["user"].unique()
     # unique_users = train["user"].unique()
@@ -175,9 +163,21 @@ def process_bpr(
 
     val = val.merge(user_mapper, on="user", how="inner")
     val = val.merge(item_mapper, on="item", how="inner")
-    val = val[["user_idx", "item_idx", "target"]].rename(
-        columns={"user_idx": "user", "item_idx": "item"}
+
+    # Keep lists with condition 0 < mean_target < 1
+    val_valid_lists = (
+        val.groupby("list_id")
+        .agg({"target": "mean"})
+        .rename(columns={"target": "mean"})
+        .reset_index()
     )
+    val_valid_lists = val_valid_lists[
+        (val_valid_lists["mean"] > 0) & (val_valid_lists["mean"] < 1)
+    ]
+    val = val.merge(val_valid_lists, "inner", "list_id")
+
+    val = val[["timestamp", "list_id", "user_idx", "item_idx", "target"]]
+    val = val.rename(columns={"user_idx": "user", "item_idx": "item"})
 
     # Mock test_data
     test = val.copy()  # test set == validation set (to change in the future!)
@@ -190,6 +190,7 @@ def process_bpr(
     stats["train_n_items"] = unique_train_items.size
     stats["val_n_users"] = unique_val_users.size
     stats["val_n_items"] = unique_val_items.size
+    stats["val_n_lists"] = val["list_id"].nunique()
     stats["val_n_clicks"] = int(val["target"].sum())
     stats["val_n_impressions"] = len(val) - stats["val_n_clicks"]
     stats["val_ctr"] = stats["val_n_clicks"] / stats["val_n_impressions"]
@@ -257,6 +258,11 @@ def _common(
     # Join positive interactions (clicks) with negative interactions (impressions)
     interactions = interactions.merge(impressions_dl, "inner", "recommendation_id")
 
+    # Create a unique id per (recommendation_id, user_id) pair
+    interactions["list_id"] = pd.factorize(
+        interactions[["recommendation_id", "user_id"]].apply(tuple, axis=1)
+    )[0]
+
     # Mark positive interactions with 1 and negative with 0
     interactions["target"] = np.where(
         interactions["series_id"] == interactions["recommended_series_list"],
@@ -266,8 +272,14 @@ def _common(
     interactions["target"] = interactions["target"].astype("int32")
 
     interactions = interactions[
-        ["utc_ts_milliseconds", "user_id", "recommended_series_list", "target"]
+        [
+            "utc_ts_milliseconds",
+            "list_id",
+            "user_id",
+            "recommended_series_list",
+            "target",
+        ]
     ]
-    interactions.columns = ["timestamp", "user", "item", "target"]
+    interactions.columns = ["timestamp", "list_id", "user", "item", "target"]
 
     return interactions

src/mypackage/training/datamodules/datamodule.py

Lines changed: 3 additions & 3 deletions
@@ -27,11 +27,11 @@ def setup(self, stage: Optional[str] = None):
         if stage == "fit":
            self.train_dataset = SimpleDataset(self.hparams.train)
            self.val_dataset = SimpleDataset(self.hparams.val)
-            # self.val_dataset = UserGroupedDataset(self.hparams.val)
+            # self.val_dataset = ListGroupedDataset(self.hparams.val)
 
         if stage == "test":
             self.test_dataset = SimpleDataset(self.hparams.test)
-            # self.test_dataset = UserGroupedDataset(self.hparams.test)
+            # self.test_dataset = ListGroupedDataset(self.hparams.test)
 
     def train_dataloader(self):
         return DataLoader(
@@ -52,7 +52,7 @@ def val_dataloader(self):
         )
         # return DataLoader(
         #     dataset=self.val_dataset,
-        #     batch_sampler=UserBatchSampler(self.val_dataset),
+        #     batch_sampler=ListBatchSampler(self.val_dataset),
         #     collate_fn=collate_fn,
         #     num_workers=self.hparams.num_workers,
         #     pin_memory=self.hparams.pin_memory,

src/mypackage/training/datamodules/dataset.py

Lines changed: 24 additions & 22 deletions
@@ -7,6 +7,7 @@
 
 class SimpleDataset(Dataset):
     def __init__(self, data: pd.DataFrame):
+        self.list_ids = torch.tensor(data["list_id"].to_numpy())
         self.users = torch.tensor(data["user"].to_numpy())
         self.items = torch.tensor(data["item"].to_numpy())
         self.targets = torch.tensor(data["target"].to_numpy(), dtype=torch.float32)
@@ -15,47 +16,48 @@ def __len__(self):
         return len(self.users)
 
     def __getitem__(self, idx: int):
-        return self.users[idx], self.items[idx], self.targets[idx]
+        return self.list_ids[idx], self.users[idx], self.items[idx], self.targets[idx]
 
 
-class UserGroupedDataset(Dataset):
+class ListGroupedDataset(Dataset):
     def __init__(self, data: pd.DataFrame):
         self.data = data
-        self.unique_users = list(self.data["user"].unique())
-        self.user_groups = {
-            user: data[data["user"] == user].index.tolist()
-            for user in self.unique_users
-        }  # indices per user
+        self.unique_list_ids = list(self.data["list_id"].unique())
+        self.list_id_groups = {
+            list_id: data[data["list_id"] == list_id].index.tolist()
+            for list_id in self.unique_list_ids
+        }  # indices per list_id
 
     def __len__(self):
-        return len(self.unique_users)
+        return len(self.unique_list_ids)
 
     def __getitem__(self, idx):
-        user = self.unique_users[idx]  # Get user at index
-        user_indices = self.user_groups[user]  # Get all rows for this user
-        user_data = self.data.iloc[user_indices]  # Fetch data for this user
+        list_id = self.unique_list_ids[idx]  # Get list_id at index
+        list_id_indices = self.list_id_groups[list_id]  # Get all rows for this list_id
+        list_id_data = self.data.iloc[list_id_indices]  # Fetch data for this list_id
 
-        users = torch.tensor(user_data["user"].to_numpy())
-        items = torch.tensor(user_data["item"].to_numpy())
-        targets = torch.tensor(user_data["target"].to_numpy(), dtype=torch.float32)
-        return users, items, targets  # Return user-wise batch
+        list_ids = torch.tensor(list_id_data["list_id"].to_numpy())
+        users = torch.tensor(list_id_data["user"].to_numpy())
+        items = torch.tensor(list_id_data["item"].to_numpy())
+        targets = torch.tensor(list_id_data["target"].to_numpy(), dtype=torch.float32)
+        return list_ids, users, items, targets  # Return list-wise batch
 
 
-class UserBatchSampler(Sampler):
+class ListBatchSampler(Sampler):
     def __init__(self, dataset):
-        self.unique_users = dataset.unique_users
+        self.unique_list_ids = dataset.unique_list_ids
 
     def __iter__(self):
-        for i in range(len(self.unique_users)):
-            yield [i]  # Yield dataset indices for each user
+        for i in range(len(self.unique_list_ids)):
+            yield [i]  # Yield dataset indices for each list_id
 
     def __len__(self):
-        return len(self.unique_users)
+        return len(self.unique_list_ids)
 
 
 def collate_fn(batch):
-    users, items, targets = zip(*batch)
-    return users[0], items[0], targets[0]
+    list_ids, users, items, targets = zip(*batch)
+    return list_ids[0], users[0], items[0], targets[0]
 
 
 class CustomIterableDataset(IterableDataset):
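
For context, the pieces above fit together as follows: ListBatchSampler yields one dataset index per list_id, ListGroupedDataset.__getitem__ returns all rows of that list as tensors, and collate_fn unwraps the resulting single-element batch. A hypothetical usage sketch with toy data (the import path mirrors this file's location and is an assumption):

import pandas as pd
from torch.utils.data import DataLoader

from mypackage.training.datamodules.dataset import (
    ListBatchSampler,
    ListGroupedDataset,
    collate_fn,
)

df = pd.DataFrame({
    "list_id": [0, 0, 1],
    "user": [5, 5, 6],
    "item": [10, 11, 12],
    "target": [1, 0, 1],
})
dataset = ListGroupedDataset(df)
loader = DataLoader(
    dataset=dataset,
    batch_sampler=ListBatchSampler(dataset),
    collate_fn=collate_fn,
)
for list_ids, users, items, targets in loader:
    # Each batch holds exactly one recommendation list.
    print(list_ids, users, items, targets)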
