[feat] Refactor MultipleNegativesRankingLoss to support improved contrastive loss from GTE paper#3607
Conversation
Hello! Wow, thank you for this! I ran some tests myself on various training scripts, and it does perform rather well. More than expected, even, especially because the idea to use query -> query and document -> document in the random negatives is a bit surprising. Perhaps it helps push elements apart in the embedding space? I reimplemented the core loss somewhat to simplify it, which results in the same loss values (as far as I can tell), but the resulting model does train slightly differently:

```python
def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
    if len(embeddings) < 2:
        raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

    queries = embeddings[0]
    docs = embeddings[1:]
    batch_size = queries.size(0)

    offset = 0
    if self.gather_across_devices:
        # Gather the anchors and candidates across all devices, with gradients. We compute only this device's
        # anchors with all candidates from all devices, and only this device's candidates with all anchors
        # from all devices. We do this in such a way that the backward pass on the embeddings can flow back
        # to the original devices.
        queries = all_gather_with_grad(queries)  # (batch_size * world_size, embedding_dim)
        docs = [all_gather_with_grad(doc) for doc in docs]
        # (1 + num_negatives) tensors of shape (batch_size * world_size, embedding_dim)
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            offset = rank * batch_size

    docs = torch.cat(docs, dim=0)  # (batch_size * world_size * (1 + num_negatives), embedding_dim)

    local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
    local_queries = queries[local_indices]
    local_docs = docs[local_indices]

    sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))
    sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (batch_size, batch_size * world_size)
    if self.gather_across_devices:
        sim_dq = self.similarity_fct(local_docs, queries) * self.scale  # (batch_size, batch_size * world_size)
    else:
        sim_dq = sim_qd.T
    sim_dd = self.similarity_fct(local_docs, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))

    # Remove self-similarity entries q_i -> q_i and d_i -> d_i for local pairs
    row_indices = torch.arange(batch_size, device=queries.device)
    sim_qq[row_indices, local_indices] = -float("inf")
    sim_dd[row_indices, local_indices] = -float("inf")

    scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)  # (batch_size, 2 * batch_size * world_size * (2 + num_negatives))
    log_z = torch.logsumexp(scores, dim=1)
    positive_scores = sim_qd[row_indices, local_indices]
    loss = -(positive_scores - log_z).mean()
    return loss
```

I haven't had time to figure out why yet and I'm wrapping up for the day, but I figured I'd share my version. I also ran a test without the … term. I'll have to follow up more later, but this is looking very cool!
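As a quick sanity check on the loss form above: the `-(positive_scores - log_z)` expression is ordinary InfoNCE cross-entropy over the concatenated score matrix. Here is a small NumPy illustration of that equivalence (my own sketch, not the library code; the names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 12))   # (batch_size, num_candidates), already scaled
positive_idx = np.arange(4)          # positive for row i sits in column i here

# log_z = logsumexp over each row, computed stably by subtracting the row max
m = scores.max(axis=1, keepdims=True)
log_z = (m + np.log(np.exp(scores - m).sum(axis=1, keepdims=True))).squeeze(1)
loss_lse = -(scores[np.arange(4), positive_idx] - log_z).mean()

# Equivalent softmax cross-entropy over the same rows
probs = np.exp(scores - m) / np.exp(scores - m).sum(axis=1, keepdims=True)
loss_ce = -np.log(probs[np.arange(4), positive_idx]).mean()

assert np.isclose(loss_lse, loss_ce)
```

Masked `-inf` entries simply contribute zero to the sum inside the logsumexp, so they drop out of the denominator.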
Hi Tom, thanks so much for taking the time to review the PR and for sharing your refactor and results! Your implementation is indeed simpler and has less computation overall; it looks great. The one issue I found was around hard negatives: in that case the number of docs exceeds the number of queries, so … With that fix applied, I re-ran training on the latest code and the results are basically unchanged. My training uses a large batch size (bs=8192), so it's possible the batch size is a factor in how much the refactor impacts results. Thanks again, I really appreciate the feedback and the careful experiments!
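To illustrate the shape issue with hard negatives (my own minimal sketch, not the PR code): once `docs` contains positives plus hard negatives, it has more rows than `queries`, so reusing the transposed q→d matrix as d→q no longer lines up with the local positives:

```python
import numpy as np

batch_size, dim, num_negs = 4, 8, 1
queries = np.random.randn(batch_size, dim)
docs = np.random.randn(batch_size * (1 + num_negs), dim)  # positives + hard negatives

sim_qd = queries @ docs.T   # (4, 8): every local query vs every doc
print(sim_qd.T.shape)        # (8, 4): rows for all 8 docs, but d->q should cover only the 4 local positives
```

With no hard negatives (and a single device), the matrix is square and the transpose trick works, which is why the bug only shows up in the triplet/hard-negative case.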
Force-pushed from 67e797f to 650a6cf.
Force-pushed from 650a6cf to 8aff720.
Heya, it did not improve the results much on our supervised setup (with hard negatives), which is SOTA on BEIR FWIW, so the results can somewhat be trusted.
Hello, @NohTow. Thank you for the sharp investigation and thoughtful analysis. I also ran several training experiments with this CachedMultipleNegativesBidirectionalRankingLoss (CMNBRL) implementation. With triplets (roughly negs=1), the performance gains were modest, and when training CMNBRL on datasets with hard negatives (negs=7), I observed cases where performance dropped significantly.
I had been wondering why the results were not improving with triplets and hard negatives, and your point hits the nail on the head. Thank you for calling this out. Based on that, I plan to remove hard negatives from the d-d computation in this implementation and evaluate the impact.
@hotchpotch I think I might prefer an implementation like this:

```python
def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
    if len(embeddings) < 2:
        raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

    queries = embeddings[0]
    docs = embeddings[1:]
    batch_size = queries.size(0)

    offset = 0
    if self.gather_across_devices:
        # Gather the anchors and candidates across all devices, with gradients. We compute only this device's
        # anchors with all candidates from all devices, and only this device's candidates with all anchors
        # from all devices. We do this in such a way that the backward pass on the embeddings can flow back
        # to the original devices.
        queries = all_gather_with_grad(queries)
        docs = [all_gather_with_grad(doc) for doc in docs]
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            offset = rank * batch_size

    world_batch_size = queries.size(0)
    docs = torch.cat(docs, dim=0)  # (batch_size * world_size * (1 + num_negatives), embedding_dim)

    local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
    local_queries = queries[local_indices]
    local_docs = docs[local_indices]

    sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (bs, bs * ws * (1 + nn))
    sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (bs, bs * ws)
    sim_dq = (self.similarity_fct(queries, local_docs) * self.scale).T  # (bs, bs * ws)
    sim_dd = (self.similarity_fct(docs, local_docs) * self.scale).T  # (bs, bs * ws * (1 + nn))

    # Remove self-similarity entries q_i -> q_i
    row_indices = torch.arange(batch_size, device=queries.device)
    sim_qq[row_indices, local_indices] = -torch.inf

    # Remove d_i_a -> d_i_b for all documents belonging to the same query
    same_query_doc_mask = torch.eye(world_batch_size, device=queries.device)[local_indices]
    same_query_doc_mask = same_query_doc_mask.repeat(1, len(embeddings) - 1).bool()
    sim_dd.masked_fill_(same_query_doc_mask, float("-inf"))

    scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)
    log_z = torch.logsumexp(scores, dim=1)
    positive_scores = sim_qd[row_indices, local_indices]
    loss = -(positive_scores - log_z).mean()
    return loss
```

This matches the original, intended implementation as far as I can tell, and should prevent d -> d within the same query. It still allows d -> d across queries. I have a local commit with these changes, if you'd like me to push it. I sadly haven't been able to test this variant myself with multiple GPUs, though.
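To see what the same-query document mask looks like (my own NumPy sketch of the masking idea, with `world_size == 1`, `batch_size == 3`, and one hard negative): `docs` is laid out as a positives block followed by a negatives block, and row i masks every column whose document belongs to query i. Torch's `.repeat(1, n)` corresponds to `np.tile` here:

```python
import numpy as np

batch_size, num_negs = 3, 1
local_indices = np.arange(batch_size)   # offset 0 on a single device

# eye[local_indices] marks doc j (indexed by query) for row i == j,
# then tiling repeats that pattern across each (positives / negatives) block
mask = np.eye(batch_size)[local_indices]               # (3, 3)
mask = np.tile(mask, (1, 1 + num_negs)).astype(bool)   # (3, 6)

print(mask.astype(int))
# [[1 0 0 1 0 0]
#  [0 1 0 0 1 0]
#  [0 0 1 0 0 1]]
```

Row 0 masks columns 0 (its own positive) and 3 (its own hard negative), while every other query's documents remain as in-batch negatives.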
Hello. I have been experimenting with the implementation that removes hard negatives from the doc-doc term in commit cd7a9c7.
Overall, my impression is that CMNBRL often works better on pair datasets. With hard negatives, which loss is better seems to depend on the situation. I have also incorporated the improved hard-negative handling proposed by @tomaarsen in commit 8c5680d, and I plan to run training with that next. I also have access to a machine with two GPUs, so I will test the multi-GPU case later as well.
Hello! I've taken your Bidirectional implementation, added a bit of @NohTow's work re. more generalization, and mixed it together to create more generalized (C)MNRL losses. This has a handful of goals:
I added some docs for the somewhat common variants:
Let me know what you think! If you're on board, we can move away from the new MNBRL and add it straight to MNRL instead. Feel free to make these changes if you wish. I think we can then also remove … as well. I'll have to run some tests to make sure that MNRL and MNSRL still perform the same, but I think this is a cool direction. P.S. I'm also open to suggestions re. docstrings, explanations, or parameter names.
Hello! Thanks for the refactor. I am fully on board with moving away from MNBRL and folding the GTE-style variant directly into (C)MNRL. I am also fine with removing CachedMultipleNegativesSymmetricRankingLoss from MatryoshkaLoss if it becomes a subclass of CachedMultipleNegativesRankingLoss. The parameterization is clear, and it seems like future InfoNCE variants can be added without duplicating implementations, which is great. I plan to validate that our existing training code still works with the refactored implementation from commit 7b37bbb. Also, multi-GPU training with … is something I still need to test. If there is anything you would like me to help with, investigate, or implement, please do not hesitate to ask. And if there are any areas where you would prefer me to make the implementation changes myself, please let me know.
Perfect! If you'd like, feel free to try and pull the PR over the finish line by removing the MNBRL variant again. Otherwise I'll take care of it in a few days prior to the v5.3 release. Much appreciated on these big PRs, this is super valuable stuff.
Pull request overview
This PR extends the MultipleNegativesRankingLoss and CachedMultipleNegativesRankingLoss classes to support the GTE paper's "Improved Contrastive Loss" by adding two new parameters: directions (to control which similarity interactions are computed) and partition_mode (to control normalization strategy). The implementation also deprecates MultipleNegativesSymmetricRankingLoss and CachedMultipleNegativesSymmetricRankingLoss by making them thin wrappers that call the base classes with appropriate parameters.
Changes:
- Extended `MultipleNegativesRankingLoss` and `CachedMultipleNegativesRankingLoss` with `directions` and `partition_mode` parameters supporting query-to-doc, doc-to-query, query-to-query, and doc-to-doc similarity terms
- Deprecated `MultipleNegativesSymmetricRankingLoss` and `CachedMultipleNegativesSymmetricRankingLoss` with clear migration guidance
- Updated `SparseMultipleNegativesRankingLoss` to match the new API
- Added dynamic citation property that returns different papers based on the loss configuration
- Updated documentation table to remove deprecated losses and clarify input formats
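To make the parameterization concrete, here is a standalone sketch (my own illustration; the actual parameter values and internals in the PR may differ) of how a `directions` list could gate which similarity blocks enter the shared normalization term:

```python
import numpy as np

def build_scores(sim_qd, sim_qq, sim_dq, sim_dd, directions):
    """Concatenate only the requested similarity blocks into one score matrix."""
    blocks = {"qd": sim_qd, "qq": sim_qq, "dq": sim_dq, "dd": sim_dd}
    return np.concatenate([blocks[d] for d in directions], axis=1)

bs = 2
qd = np.zeros((bs, 4))
qq = np.ones((bs, 2))
dq = np.full((bs, 2), 2.0)
dd = np.full((bs, 4), 3.0)

# Classic MNRL uses only q->d; the GTE-style loss uses all four directions
assert build_scores(qd, qq, dq, dd, ["qd"]).shape == (2, 4)
assert build_scores(qd, qq, dq, dd, ["qd", "qq", "dq", "dd"]).shape == (2, 12)
```

The symmetric variant then falls out of the same mechanism as a particular choice of directions, which is what lets the deprecated classes become thin wrappers.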
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| sentence_transformers/losses/MultipleNegativesRankingLoss.py | Extended core loss with directions and partition_mode parameters, added validation, updated documentation with literature examples, added temperature property and dynamic citations |
| sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py | Extended cached version with same parameters and logic, maintains gradient caching mechanism |
| sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py | Deprecated class, now inherits from base class with specific parameter values |
| sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py | Deprecated cached variant, now inherits from cached base class |
| sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py | Updated to match new API for consistency across dense and sparse implementations |
| docs/sentence_transformer/loss_overview.md | Updated loss table to remove deprecated classes and add anchor-anchor pairs row |
I think we're looking good! Feel free to rerun if you'd like, but the implementation should be identical (apart from perhaps some differences in randomness). I'd like to merge and release this in the coming days!
Great job reviewing, refactoring, fixing, and merging my PR. Thanks!
Gladly! Thanks for this nice work 👏
Hello!
Summary
This PR implements the "Improved Contrastive Loss" proposed in the GTE paper as `MultipleNegativesBidirectionalRankingLoss` (both standard and GradCache versions), along with accompanying documentation, sampler recommendations, and tests.
Motivation
The GTE paper introduces an "Improved Contrastive Loss" that unifies q→d, q→q, d→q, and d→d similarities into a single normalization term, strengthening in-batch signals for contrastive learning. This PR brings that approach to sentence-transformers.
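In equation form (my own transcription of the idea; notation may differ slightly from the paper), for a batch of n query–document pairs (q_i, d_i) with similarity s(·,·) and temperature τ, the positive score is normalized against all four interaction directions at once:

$$
\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \log \frac{e^{s(q_i, d_i)/\tau}}{Z_i}, \qquad
Z_i = \sum_{j} e^{s(q_i, d_j)/\tau} + \sum_{j \neq i} e^{s(q_i, q_j)/\tau} + \sum_{j} e^{s(d_i, q_j)/\tau} + \sum_{j \neq i} e^{s(d_i, d_j)/\tau}
$$

The j ≠ i restrictions on the q→q and d→d sums correspond to the self-similarity masking in the implementations discussed above.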
Benchmark Results (NanoBEIR)
I evaluated models fine-tuned with the new `CachedMultipleNegativesBidirectionalRankingLoss` (CMNBRL) against the existing `CachedMultipleNegativesRankingLoss` (CMNRL). Results show comparable performance on most tasks, with notable improvements on ArguAna, SciFact, and MSMARCO.
Training time was approximately 90 minutes on an RTX 5090 for both losses, showing no significant difference in computational cost.
Changes
New Loss Functions
- `MultipleNegativesBidirectionalRankingLoss`
- `CachedMultipleNegativesBidirectionalRankingLoss`

Hard Negatives Handling
With `(anchor, positive, negative_1, ..., negative_n)` inputs, negatives are included as additional documents in the q→d and d→d computations (not treated as queries).
GradCache Implementation
Follows the same caching approach as `CachedMultipleNegativesRankingLoss`.
Related Updates
- Registered in `losses/__init__.py`
- Compatible with `MatryoshkaLoss`, `AdaptiveLayerLoss`, and `Matryoshka2dLoss`
- Recommended batch sampler: `BatchSamplers.NO_DUPLICATES`

Documentation
Added entries to the loss documentation (`losses.md`)
Tests
References
Testing
```shell
python -m pytest tests/losses/test_multiple_negatives_bidirectional_ranking_loss.py
uv run --with ruff ruff check  # on modified .py files
```

Note: This implementation was developed with AI assistance (Codex), but I take full responsibility for the implementation and have verified it through actual training experiments.