[feat] implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) #3607
Conversation
Hello! Wow, thank you for this! I ran some tests myself on various training scripts, and it does perform rather well. More than expected, even, especially because the idea to use query -> query and document -> document similarities as random negatives is a bit surprising. Perhaps it helps push elements apart in the embedding space?

I reimplemented the core loss somewhat to simplify it, which results in the same loss values (as far as I can tell), but the resulting model does train slightly differently:

```python
def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
    if len(embeddings) < 2:
        raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

    queries = embeddings[0]
    docs = embeddings[1:]
    batch_size = queries.size(0)

    offset = 0
    if self.gather_across_devices:
        # Gather the anchors and candidates across all devices, with gradients. We compute only this device's
        # anchors with all candidates from all devices, and only this device's candidates with all anchors
        # from all devices. We do this in such a way that the backward pass on the embeddings can flow back
        # to the original devices.
        queries = all_gather_with_grad(queries)  # (batch_size * world_size, embedding_dim)
        docs = [all_gather_with_grad(doc) for doc in docs]
        # (1 + num_negatives) tensors of shape (batch_size * world_size, embedding_dim)

        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            offset = rank * batch_size

    docs = torch.cat(docs, dim=0)  # (batch_size * world_size * (1 + num_negatives), embedding_dim)

    local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
    local_queries = queries[local_indices]
    local_docs = docs[local_indices]

    sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))
    sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (batch_size, batch_size * world_size)
    if self.gather_across_devices:
        sim_dq = self.similarity_fct(local_docs, queries) * self.scale  # (batch_size, batch_size * world_size)
    else:
        sim_dq = sim_qd.T
    sim_dd = self.similarity_fct(local_docs, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))

    # Remove self-similarity entries q_i -> q_i and d_i -> d_i for local pairs
    row_indices = torch.arange(batch_size, device=queries.device)
    sim_qq[row_indices, local_indices] = -float("inf")
    sim_dd[row_indices, local_indices] = -float("inf")

    scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)  # (batch_size, 2 * batch_size * world_size * (2 + num_negatives))
    log_z = torch.logsumexp(scores, dim=1)
    positive_scores = sim_qd[row_indices, local_indices]
    loss = -(positive_scores - log_z).mean()
    return loss
```
I haven't had time to figure out why yet and I'm wrapping up for the day, but I figured I'd share my version. I also ran a test without the … I'll have to follow up more later, but this is looking very cool!
Hi Tom — thanks so much for taking the time to review the PR and for sharing your refactor and results! Your implementation is indeed simpler and has less computation overall — it looks great.

The one issue I found was around hard negatives: in that case the number of docs exceeds the number of queries, so `sim_dq` can no longer simply be `sim_qd.T` (the shapes stop lining up when the score blocks are concatenated). With that fix applied, I re-ran training on the latest code and the results are basically unchanged. My training uses a large batch size (bs=8192), so it's possible the batch size is a factor in how much the refactor impacts results.

Thanks again — I really appreciate the feedback and the careful experiments!
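For illustration, a minimal sketch of one way that shape mismatch could be avoided (this is only my reading of the discussion, not necessarily the exact fix applied in the PR): compute `sim_dq` explicitly from the local positive documents in both branches instead of transposing `sim_qd`.

```python
# Hypothetical sketch, reusing the names from the refactor above. With hard negatives, `docs`
# has batch_size * world_size * (1 + num_negatives) rows, so `sim_qd.T` no longer has shape
# (batch_size, batch_size * world_size). Computing sim_dq directly from the local positive
# docs keeps the shapes consistent whether or not embeddings are gathered across devices.
sim_dq = self.similarity_fct(local_docs, queries) * self.scale  # (batch_size, batch_size * world_size)
```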
Hello!
Summary
This PR implements the "Improved Contrastive Loss" proposed in the GTE paper as `MultipleNegativesBidirectionalRankingLoss` (both standard and GradCache versions), along with accompanying documentation, sampler recommendations, and tests.

Motivation
The GTE paper introduces an "Improved Contrastive Loss" that unifies q→d, q→q, d→q, and d→d similarities into a single normalization term, strengthening in-batch signals for contrastive learning. This PR brings that approach to sentence-transformers.
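For reference, the objective as I read it from the paper and the loss code above (notation mine: $s$ is the configured similarity function, $\tau$ the temperature, i.e. the inverse of the scale, over a batch of $n$ pairs $(q_i, d_i)$; hard negatives only add extra documents $d_j$ to the sums):

$$
\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \log \frac{\exp\big(s(q_i, d_i)/\tau\big)}{Z_i},
\qquad
Z_i = \sum_{j} e^{s(q_i, d_j)/\tau} + \sum_{j \ne i} e^{s(q_i, q_j)/\tau} + \sum_{j} e^{s(d_i, q_j)/\tau} + \sum_{j \ne i} e^{s(d_i, d_j)/\tau}.
$$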
Benchmark Results (NanoBEIR)
I evaluated models fine-tuned with the new `CachedMultipleNegativesBidirectionalRankingLoss` (CMNBRL) against the existing `CachedMultipleNegativesRankingLoss` (CMNRL). Results show comparable performance on most tasks, with notable improvements on ArguAna, SciFact, and MSMARCO:
Training time was approximately 90 minutes on an RTX 5090 for both losses, showing no significant difference in computational cost.
Changes
New Loss Functions
- `MultipleNegativesBidirectionalRankingLoss`
- `CachedMultipleNegativesBidirectionalRankingLoss`
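A minimal usage sketch for context. The loss class is from this PR (I assume its constructor mirrors `CachedMultipleNegativesRankingLoss`); everything else is the standard sentence-transformers training API, and the model, dataset, and output path are just placeholders.

```python
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesBidirectionalRankingLoss
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("microsoft/mpnet-base")  # placeholder base model
# Any (anchor, positive) or (anchor, positive, negative_1, ..., negative_n) dataset works.
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")

# mini_batch_size is assumed to behave as in CachedMultipleNegativesRankingLoss:
# it trades GPU memory for speed without changing the effective batch size.
loss = CachedMultipleNegativesBidirectionalRankingLoss(model, mini_batch_size=32)

args = SentenceTransformerTrainingArguments(
    output_dir="models/mpnet-base-nq-cmnbrl",  # placeholder
    per_device_train_batch_size=1024,  # large batches give more in-batch negatives
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts acting as false negatives
)

trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
trainer.train()
```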
Hard Negatives Handling
For inputs of the form (anchor, positive, negative_1, ..., negative_n), negatives are included as additional documents in the q→d and d→d computations (not treated as queries).
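To make the shapes concrete, a tiny standalone illustration (toy sizes, plain dot products standing in for the configured similarity function; this mirrors how the loss pools documents, not the library code itself):

```python
import torch

batch_size, num_negatives, dim = 4, 2, 8
queries = torch.randn(batch_size, dim)
# Positives first, then each block of hard negatives, concatenated into one document pool.
docs = torch.cat([torch.randn(batch_size, dim) for _ in range(1 + num_negatives)], dim=0)

sim_qd = queries @ docs.T            # (4, 12): every query scored against positives and negatives
sim_dd = docs[:batch_size] @ docs.T  # (4, 12): only positives act as "queries" on the document side
print(sim_qd.shape, sim_dd.shape)    # torch.Size([4, 12]) torch.Size([4, 12])
```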
GradCache Implementation
The cached version follows the same GradCache approach as the existing `CachedMultipleNegativesRankingLoss`.
Related Updates
- Registered the new losses in `losses/__init__.py`
- Compatible with `MatryoshkaLoss`, `AdaptiveLayerLoss`, and `Matryoshka2dLoss`
- Recommended batch sampler: `BatchSamplers.NO_DUPLICATES`
Documentation
- Added entries for the new losses to the loss overview (`losses.md`)

Tests
- `tests/losses/test_multiple_negatives_bidirectional_ranking_loss.py`
References
- Li et al., "Towards General Text Embeddings with Multi-stage Contrastive Learning" (GTE), arXiv:2308.03281
Testing
```
python -m pytest tests/losses/test_multiple_negatives_bidirectional_ranking_loss.py
uv run --with ruff ruff check  # on modified .py files
```

Note: This implementation was developed with AI assistance (Codex), but I take full responsibility for the implementation and have verified it through actual training experiments.