Skip to content

Commit 3295fee

Browse files
authored
Merge pull request #128 from zerolovesea/inbatchsample
Inbatchsample
2 parents b78447e + 5684437 commit 3295fee

File tree

8 files changed

+440
-173
lines changed

8 files changed

+440
-173
lines changed

tests/test_inbatch_sampling.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import numpy as np
2+
import pandas as pd
3+
import torch
4+
5+
from torch_rechub.basic.features import SequenceFeature, SparseFeature
6+
from torch_rechub.models.matching import DSSM
7+
from torch_rechub.trainers import MatchTrainer
8+
from torch_rechub.utils.data import MatchDataGenerator, df_to_dict
9+
from torch_rechub.utils.match import gather_inbatch_logits, gen_model_input, generate_seq_feature_match, inbatch_negative_sampling
10+
11+
12+
def test_inbatch_negative_sampling_random_and_uniform():
    """In-batch sampling draws `neg_ratio` negatives per row, never picks the
    row's own (positive) index, and produces a different draw per seed."""
    batch_scores = torch.zeros((4, 4))
    sampled_idx = inbatch_negative_sampling(batch_scores, neg_ratio=2, generator=torch.Generator().manual_seed(0))
    gathered_logits = gather_inbatch_logits(batch_scores, sampled_idx)
    # one positive column plus neg_ratio negative columns
    assert gathered_logits.shape == (4, 3)
    assert sampled_idx.shape == (4, 2)
    # the diagonal (positive) index must never appear among a row's negatives
    assert all(row not in picks.tolist() for row, picks in enumerate(sampled_idx))

    # A different generator seed should give a different permutation,
    # demonstrating that the sampling is actually random.
    resampled_idx = inbatch_negative_sampling(batch_scores, neg_ratio=2, generator=torch.Generator().manual_seed(1))
    assert not torch.equal(sampled_idx, resampled_idx)
24+
25+
26+
def test_inbatch_negative_sampling_hard_negative():
    """Hard-negative mode must select the highest-scoring off-diagonal column."""
    batch_scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 0.0]])
    hard_idx = inbatch_negative_sampling(batch_scores, neg_ratio=1, hard_negative=True)
    # row 0 -> col 2 (3.0), row 1 -> col 2 (6.0), row 2 -> col 1 (8.0)
    expected = torch.tensor([2, 2, 1])
    assert torch.equal(hard_idx.squeeze(1), expected)
31+
32+
33+
def _build_small_match_dataloader():
    """Build a tiny random interaction dataset and an untrained DSSM model.

    Returns:
        (train_dataloader, model) suitable for one quick training epoch.
    """
    n_users, n_items, n_samples = 12, 24, 80
    interactions = pd.DataFrame({
        "user_id": np.random.randint(0, n_users, n_samples),
        "item_id": np.random.randint(0, n_items, n_samples),
        "time": np.arange(n_samples),
    })
    user_profile = pd.DataFrame({"user_id": np.arange(n_users)})
    item_profile = pd.DataFrame({"item_id": np.arange(n_items)})

    # mode=0 / neg_ratio=0: point-wise samples with no explicit negatives
    df_train, _ = generate_seq_feature_match(interactions, "user_id", "item_id", "time", mode=0, neg_ratio=0)
    x_train = gen_model_input(df_train, user_profile, "user_id", item_profile, "item_id", seq_max_len=8)
    # labels are unused in in-batch mode; a zero array keeps shapes aligned
    y_train = np.zeros(len(df_train))

    user_features = [
        SparseFeature("user_id", n_users, embed_dim=8),
        SequenceFeature("hist_item_id", n_items, embed_dim=8, pooling="mean", shared_with="item_id"),
    ]
    item_features = [SparseFeature("item_id", n_items, embed_dim=8)]

    generator = MatchDataGenerator(x_train, y_train)
    train_dl, _, _ = generator.generate_dataloader(x_train, df_to_dict(item_profile), batch_size=8, num_workers=0)

    model = DSSM(user_features, item_features, user_params={"dims": [16]}, item_params={"dims": [16]})
    return train_dl, model
69+
70+
71+
def test_match_trainer_inbatch_flow_runs_and_updates():
    """A full in-batch-negative training epoch should run end-to-end and
    actually backpropagate gradients into the model."""
    loader, dssm_model = _build_small_match_dataloader()

    trainer = MatchTrainer(dssm_model, mode=0, in_batch_neg=True, in_batch_neg_ratio=3, sampler_seed=2, n_epoch=1, device="cpu")
    trainer.train_one_epoch(loader, log_interval=100)

    # at least one trainable parameter must have received a gradient
    assert any(p.grad is not None for p in dssm_model.parameters() if p.requires_grad)

torch_rechub/basic/loss_func.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def __init__(self, margin=2, num_items=None):
8181
self.margin = margin
8282
self.n_items = num_items
8383

84-
def forward(self, pos_score, neg_score):
84+
def forward(self, pos_score, neg_score, in_batch_neg=False):
85+
pos_score = pos_score.view(-1)
8586
loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
8687
if self.n_items is not None:
8788
impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
@@ -96,9 +97,14 @@ class BPRLoss(torch.nn.Module):
9697
def __init__(self):
9798
super().__init__()
9899

99-
def forward(self, pos_score, neg_score):
100-
loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
101-
return loss
100+
def forward(self, pos_score, neg_score, in_batch_neg=False):
    """Bayesian Personalized Ranking loss: mean of -log(sigmoid(pos - neg)).

    Args:
        pos_score: positive-pair scores of any shape; flattened to [batch].
        neg_score: negative scores, either [batch] (one negative per example)
            or [batch, n_neg] (each row broadcast against its positive score).
        in_batch_neg: unused here; kept so all loss functions share one
            signature.

    Returns:
        Scalar tensor: the mean BPR loss over all (pos, neg) pairs.
    """
    pos_score = pos_score.view(-1)
    if neg_score.dim() == 1:
        diff = pos_score - neg_score
    else:
        diff = pos_score.view(-1, 1) - neg_score
    # logsigmoid is the numerically stable form of log(sigmoid(diff));
    # -diff.sigmoid().log() underflows to inf when diff is very negative.
    loss = -torch.nn.functional.logsigmoid(diff)
    return loss.mean()
102108

103109

104110
class NCELoss(torch.nn.Module):

torch_rechub/models/matching/narm.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
class NARM(nn.Module):
1919

20-
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
20+
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p, item_feature=None):
2121
super(NARM, self).__init__()
2222

2323
# item embedding layer
2424
self.item_history_feature = item_history_feature
25+
self.item_feature = item_feature # Optional: for in-batch negative sampling
2526
self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
27+
self.mode = None # For inference: "user" or "item"
2628

2729
# embedding dropout layer
2830
self.emb_dropout = Dropout(emb_dropout_p)
@@ -42,41 +44,62 @@ def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_
4244
# bilinear projection matrix
4345
self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
4446

45-
def forward(self, input_dict):
46-
# Eq. 1-4, index item embeddings and pass through gru
47-
# # Fetch the embeddings for items in the session
47+
def _compute_session_repr(self, input_dict):
    """Return the NARM session representation.

    Concatenates the global encoding (final GRU state) with the local,
    attention-weighted encoding of all hidden states, then applies dropout.
    """
    seq = input_dict[self.item_history_feature.name]
    pad_mask = (seq != 0)
    # pack_padded_sequence requires lengths on CPU, detached from the graph
    seq_lens = pad_mask.sum(dim=1, keepdim=False).to("cpu").detach()
    packed = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(seq)), seq_lens, batch_first=True, enforce_sorted=False)

    h, h_t = self.gru(packed)
    h_t = h_t.permute(1, 0, 2)
    h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)

    # global encoder output: the last hidden state of the GRU
    c_g = h_t.squeeze(1)
    # similarity of the final state to each previous hidden state
    q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
    # attention weights, masked over padding and normalized per session
    alpha = torch.exp(q) * pad_mask.unsqueeze(-1)
    alpha /= alpha.sum(dim=1, keepdim=True)
    # local encoder output: attention-weighted sum of hidden states
    c_l = (alpha * h).sum(1)

    return self.session_rep_dropout(torch.hstack((c_g, c_l)))
66+
67+
def user_tower(self, x):
    """Project the session representation into item-embedding space.

    Returns None in item-only inference mode, a [batch, embed_dim] tensor in
    user-only inference mode, otherwise [batch, 1, embed_dim] for scoring.
    """
    if self.mode == "item":
        return None
    session_repr = self._compute_session_repr(x)
    # bilinear projection into the item embedding space
    user_emb = session_repr @ self.b.T
    return user_emb if self.mode == "user" else user_emb.unsqueeze(1)
76+
77+
def item_tower(self, x):
    """Look up raw item embeddings for in-batch negative sampling.

    Returns None in user-only mode or when no item feature is configured,
    [batch, embed_dim] in item-only inference mode, else [batch, 1, embed_dim].
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    item_emb = self.item_emb(x[self.item_feature.name])
    return item_emb if self.mode == "item" else item_emb.unsqueeze(1)
7788

78-
# Eq. 10, compute bilinear similarity between current session and each
79-
# candidate items
89+
def forward(self, input_dict):
    """Route between inference towers, in-batch-negative scoring, and the
    original full-catalog scoring.

    Returns:
        user/item embeddings in inference mode; per-example positive logits
        ([batch]) in in-batch-negative mode; otherwise [batch, n_items]
        bilinear scores over the whole item vocabulary.
    """
    # Inference mode: expose a single tower
    if self.mode == "user":
        return self.user_tower(input_dict)
    if self.mode == "item":
        return self.item_tower(input_dict)

    # In-batch negative sampling: score each user against its paired item.
    if self.item_feature is not None:
        user_emb = self.user_tower(input_dict)  # [batch, 1, embed_dim]
        item_emb = self.item_tower(input_dict)  # [batch, 1, embed_dim]
        # squeeze(1), not squeeze(): a bare squeeze() collapses a batch of
        # size 1 to a 0-dim scalar instead of a 1-element vector
        return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze(1)

    # Original behavior: bilinear scores against every candidate item.
    c = self._compute_session_repr(input_dict)
    s = c @ self.b.T @ self.item_emb.weight.T
    return s

torch_rechub/models/matching/sasrec.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class SASRec(torch.nn.Module):
2121
max_len: The length of the sequence feature.
2222
num_blocks: The number of stacks of attention modules.
2323
num_heads: The number of heads in MultiheadAttention.
24+
item_feature: Optional item feature for in-batch negative sampling mode.
2425
2526
"""
2627

@@ -31,9 +32,15 @@ def __init__(
3132
dropout_rate=0.5,
3233
num_blocks=2,
3334
num_heads=1,
35+
item_feature=None,
3436
):
3537
super(SASRec, self).__init__()
3638

39+
self.features = features
40+
self.item_feature = item_feature # Optional: for in-batch negative sampling
41+
self.mode = None # For inference: "user" or "item"
42+
self.max_len = max_len
43+
3744
self.features = features
3845

3946
self.item_num = self.features[0].vocab_size
@@ -94,17 +101,60 @@ def seq_forward(self, x, embed_x_feature):
94101

95102
return seq_output
96103

104+
def user_tower(self, x):
    """User representation for in-batch negative sampling.

    Uses the transformer output at each sequence's last non-padding position.
    Returns None in item-only mode, [batch, embed_dim] in user-only mode,
    otherwise [batch, 1, embed_dim].

    NOTE(review): an all-padding row clamps to position 0, so its "user
    embedding" comes from a padding step — confirm upstream filtering.
    """
    if self.mode == "item":
        return None
    # run only the sequence feature through embedding + attention stack
    seq_embed = self.item_emb(x, self.features[:1])[:, 0]
    seq_output = self.seq_forward(x, seq_embed)  # [batch, max_len, embed_dim]

    # index of the last valid (non-zero) token in each sequence
    last_pos = ((x['seq'] != 0).sum(dim=1) - 1).clamp(min=0)
    rows = torch.arange(seq_output.size(0), device=seq_output.device)
    user_emb = seq_output[rows, last_pos]  # [batch, embed_dim]

    return user_emb if self.mode == "user" else user_emb.unsqueeze(1)
124+
125+
def item_tower(self, x):
    """Item embeddings for in-batch negative sampling.

    Returns None in user-only mode or when no item feature is configured,
    [batch, embed_dim] in item-only inference mode, else [batch, 1, embed_dim].
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    # look up ids directly in the shared embedding table of the seq feature
    item_emb = self.item_emb.embedding[self.features[0].name](x[self.item_feature.name])
    return item_emb if self.mode == "item" else item_emb.unsqueeze(1)
137+
97138
def forward(self, x):
    """Route between inference towers, in-batch-negative scoring, and the
    original pairwise (pos/neg sequence) training path.

    Returns:
        tower embeddings in inference mode; [batch] positive logits in
        in-batch-negative mode; otherwise a (pos_logits, neg_logits) pair,
        each [batch, max_len].
    """
    # Inference mode: expose a single tower
    if self.mode == "user":
        return self.user_tower(x)
    if self.mode == "item":
        return self.item_tower(x)

    # In-batch negative sampling: score each user against its paired item.
    if self.item_feature is not None:
        user_emb = self.user_tower(x)  # [batch, 1, embed_dim]
        item_emb = self.item_tower(x)  # [batch, 1, embed_dim]
        # squeeze(1), not squeeze(): a bare squeeze() collapses a batch of
        # size 1 to a 0-dim scalar instead of a 1-element vector
        return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze(1)

    # Original behavior: pairwise logits against aligned pos/neg sequences.
    embedding = self.item_emb(x, self.features)  # (batch, 3, max_len, embed_dim)
    seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
    seq_output = self.seq_forward(x, seq_embed)

    pos_logits = (seq_output * pos_embed).sum(dim=-1)
    neg_logits = (seq_output * neg_embed).sum(dim=-1)

    return pos_logits, neg_logits
110160

torch_rechub/models/matching/stamp.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
class STAMP(nn.Module):
1616

17-
def __init__(self, item_history_feature, weight_std, emb_std):
17+
def __init__(self, item_history_feature, weight_std, emb_std, item_feature=None):
1818
super(STAMP, self).__init__()
1919

2020
# item embedding layer
2121
self.item_history_feature = item_history_feature
22+
self.item_feature = item_feature # Optional: for in-batch negative sampling
2223
n_items, item_emb_dim, = item_history_feature.vocab_size, item_history_feature.embed_dim
2324
self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)
25+
self.mode = None # For inference: "user" or "item"
2426

2527
# weights and biases for attention computation
2628
self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
@@ -50,32 +52,58 @@ def _init_module_weights(self, module):
5052
elif isinstance(module, nn.Embedding):
5153
module.weight.data.normal_(std=self.emb_std)
5254

53-
def forward(self, input_dict):
54-
# Index the embeddings for the items in the session
55+
def _compute_user_repr(self, input_dict):
    """Return the STAMP user representation h_s * h_t: the general-interest
    state gated by the short-term interest of the last click."""
    seq = input_dict[self.item_history_feature.name]
    pad_mask = (seq != 0).unsqueeze(-1)
    seq_lens = pad_mask.sum(dim=1, keepdim=True).squeeze(-1)
    session_emb = self.item_emb(seq) * pad_mask

    # embedding of the most recent click in each session
    last_emb = self.item_emb(torch.gather(seq, 1, seq_lens - 1))
    # mean-pooled general interest over the session (Eq. 2)
    mean_emb = (session_emb.sum(1) / seq_lens).unsqueeze(1)

    # attention coefficients (Eq. 7), padding-masked and L1-normalized
    attn = F.normalize(torch.exp(torch.sigmoid(session_emb @ self.w_1_t + last_emb @ self.w_2_t + mean_emb @ self.w_3_t + self.b_a) @ self.w_0) * pad_mask, p=1, dim=1)
    # attention-based interest with residual mean interest (Eq. 8)
    attn_interest = (attn * session_emb).sum(1) + mean_emb.squeeze(1)

    h_s = self.f_s(attn_interest)
    h_t = self.f_t(last_emb).squeeze(1)
    return h_s * h_t  # [batch_size, embed_dim]
71+
72+
def user_tower(self, x):
    """User embedding for in-batch negative sampling.

    Returns None in item-only inference mode, [batch, embed_dim] in user-only
    inference mode, otherwise [batch, 1, embed_dim].
    """
    if self.mode == "item":
        return None
    user_emb = self._compute_user_repr(x)
    return user_emb if self.mode == "user" else user_emb.unsqueeze(1)
80+
81+
def item_tower(self, x):
    """Item embeddings for in-batch negative sampling.

    Returns None in user-only mode or when no item feature is configured,
    [batch, embed_dim] in item-only inference mode, else [batch, 1, embed_dim].
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    item_emb = self.item_emb(x[self.item_feature.name])
    return item_emb if self.mode == "item" else item_emb.unsqueeze(1)
7792

78-
# Eq. 4, compute candidate scores
79-
z = h_s * h_t @ self.item_emb.weight.T
80-
93+
def forward(self, input_dict):
    """Route between inference towers, in-batch-negative scoring, and the
    original all-item scoring path.

    Returns:
        user/item embeddings in inference mode; per-example positive logits
        ([batch]) in in-batch-negative mode; otherwise [batch, n_items]
        scores against the full item vocabulary.
    """
    # Inference mode: expose a single tower
    if self.mode == "user":
        return self.user_tower(input_dict)
    if self.mode == "item":
        return self.item_tower(input_dict)

    # In-batch negative sampling: score each user against its paired item.
    if self.item_feature is not None:
        user_emb = self.user_tower(input_dict)  # [batch, 1, embed_dim]
        item_emb = self.item_tower(input_dict)  # [batch, 1, embed_dim]
        # squeeze(1), not squeeze(): a bare squeeze() collapses a batch of
        # size 1 to a 0-dim scalar instead of a 1-element vector
        return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze(1)

    # Original behavior: score the user representation against every item.
    user_repr = self._compute_user_repr(input_dict)
    z = user_repr @ self.item_emb.weight.T
    return z

0 commit comments

Comments
 (0)