
[feat] Refactor MultipleNegativesRankingLoss to support improved contrastive loss from GTE paper#3607

Merged
tomaarsen merged 28 commits into huggingface:main from hotchpotch:gte_info_nce
Feb 5, 2026
Conversation

@hotchpotch
Contributor

@hotchpotch hotchpotch commented Jan 5, 2026

Hello!

Summary

This PR implements the "Improved Contrastive Loss" proposed in the GTE paper as MultipleNegativesBidirectionalRankingLoss (both standard and GradCache versions), along with accompanying documentation, sampler recommendations, and tests.

Motivation

The GTE paper introduces an "Improved Contrastive Loss" that unifies q→d, q→q, d→q, and d→d similarities into a single normalization term, strengthening in-batch signals for contrastive learning. This PR brings that approach to sentence-transformers.
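To make the unified normalization concrete, here is a minimal single-GPU sketch of the idea for plain (query, positive) pairs, without hard negatives or distributed gathering. Function and variable names are illustrative, not the PR's internals:

```python
import torch
import torch.nn.functional as F


def improved_contrastive_loss(q: torch.Tensor, d: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """Sketch of the GTE 'Improved Contrastive Loss' for a batch of
    (query, positive document) pairs, without hard negatives.

    The positive score q_i . d_i is normalized over all four similarity
    blocks: q->d, q->q, d->q, and d->d (with self-similarities removed).
    """
    q = F.normalize(q, dim=-1)
    d = F.normalize(d, dim=-1)
    b = q.size(0)
    eye = torch.eye(b, dtype=torch.bool, device=q.device)

    sim_qd = q @ d.T * scale                                     # (b, b)
    sim_qq = (q @ q.T * scale).masked_fill(eye, float("-inf"))   # drop q_i -> q_i
    sim_dq = sim_qd.T                                            # (b, b)
    sim_dd = (d @ d.T * scale).masked_fill(eye, float("-inf"))   # drop d_i -> d_i

    # Single normalization term spanning all four blocks
    scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)  # (b, 4b)
    log_z = torch.logsumexp(scores, dim=1)
    positive = sim_qd.diagonal()
    return (log_z - positive).mean()
```

Compared to standard InfoNCE, the only change is that the denominator also sums over the q→q, d→q, and d→d similarities, which strengthens the in-batch repulsion signal.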

Benchmark Results (NanoBEIR)

I evaluated models fine-tuned with the new CachedMultipleNegativesBidirectionalRankingLoss (CMNBRL) against the existing CachedMultipleNegativesRankingLoss (CMNRL):

Results show comparable performance on most tasks, with notable improvements on ArguAna, SciFact, and MSMARCO:

| NanoBEIR nDCG@10 | CMNBRL | CMNRL |
|---|---|---|
| Avg | 0.5571 | 0.5484 |
| NanoClimateFEVER | 0.3204 | 0.3249 |
| NanoDBPedia | 0.5013 | 0.5073 |
| NanoFEVER | 0.7971 | 0.8029 |
| NanoFiQA2018 | 0.4595 | 0.4646 |
| NanoHotpotQA | 0.6496 | 0.6343 |
| NanoMSMARCO | 0.5915 | 0.5555 |
| NanoNFCorpus | 0.2981 | 0.3071 |
| NanoNQ | 0.6279 | 0.6336 |
| NanoQuoraRetrieval | 0.9387 | 0.9391 |
| NanoSCIDOCS | 0.3413 | 0.3220 |
| NanoArguAna | 0.5898 | 0.5328 |
| NanoSciFact | 0.6514 | 0.6297 |
| NanoTouche2020 | 0.4762 | 0.4754 |

Training time was approximately 90 minutes on an RTX 5090 for both losses, showing no significant difference in computational cost.

Changes

New Loss Functions

  • MultipleNegativesBidirectionalRankingLoss
  • CachedMultipleNegativesBidirectionalRankingLoss

Hard Negatives Handling

  • For inputs (anchor, positive, negative_1, ..., negative_n), negatives are included as additional documents in q→d and d→d computations (not treated as queries).
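The resulting shapes can be illustrated with a small sketch (names and dimensions are mine, not the PR's internals):

```python
import torch

# Illustrative shapes for a batch of (anchor, positive, negative_1, ..., negative_n)
b, dim, num_neg = 4, 16, 2
queries = torch.randn(b, dim)
columns = [torch.randn(b, dim) for _ in range(1 + num_neg)]  # positive + negatives

# Negatives are concatenated into the document pool, so they appear as extra
# candidates in q->d (and d->d), but are never used as queries.
docs = torch.cat(columns, dim=0)          # (b * (1 + num_neg), dim)
sim_qd = queries @ docs.T                 # (b, b * (1 + num_neg))
assert sim_qd.shape == (b, b * (1 + num_neg))
```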

GradCache Implementation

  • Adopts the same gradient caching mechanism as the existing CachedMultipleNegativesRankingLoss.

Related Updates

  • Added exports to losses/__init__.py
  • Documented compatibility with MatryoshkaLoss, AdaptiveLayerLoss, and Matryoshka2dLoss
  • Added bidirectional losses to recommended losses for BatchSamplers.NO_DUPLICATES

Documentation

  • Added bidirectional loss to API reference (losses.md)
  • Updated loss overview table and descriptions (including hard negatives behavior)

Tests

  • Manual calculation verification tests
  • Gradient consistency tests between cached and non-cached versions
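A gradient-consistency check of the kind listed above can be sketched as follows. This is an illustration of the testing technique, not the PR's actual test: it compares two mathematically equivalent InfoNCE formulations and asserts that their gradients on the embeddings match.

```python
import torch
import torch.nn.functional as F

# Two equivalent InfoNCE formulations should produce the same gradients
# on the embeddings (names and shapes are illustrative).
torch.manual_seed(0)
q = torch.randn(6, 12, requires_grad=True)
d = torch.randn(6, 12, requires_grad=True)
scores = F.normalize(q, dim=-1) @ F.normalize(d, dim=-1).T * 20.0
labels = torch.arange(6)

# Formulation A: cross-entropy over in-batch candidates
loss_a = F.cross_entropy(scores, labels)
grad_a = torch.autograd.grad(loss_a, (q, d), retain_graph=True)

# Formulation B: explicit logsumexp minus the positive score
loss_b = (torch.logsumexp(scores, dim=1) - scores.diagonal()).mean()
grad_b = torch.autograd.grad(loss_b, (q, d))

for ga, gb in zip(grad_a, grad_b):
    assert torch.allclose(ga, gb, atol=1e-5)
```

The same pattern applies to cached vs. non-cached losses: run both on identical inputs and compare the gradients that reach the embeddings.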

References

Testing

python -m pytest tests/losses/test_multiple_negatives_bidirectional_ranking_loss.py
uv run --with ruff ruff check  # on modified .py files

Note: This implementation was developed with AI assistance (Codex), but I take full responsibility for the implementation and have verified it through actual training experiments.

@hotchpotch hotchpotch changed the title feat: implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) [feat] implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) Jan 5, 2026
@hotchpotch hotchpotch marked this pull request as ready for review January 5, 2026 23:44
@tomaarsen
Member

Hello!

Wow, thank you for this! I ran some tests myself on various training scripts, and it does perform rather well. More than expected, even, especially because the idea to use query -> query and document -> document in the random negatives is a bit surprising. Perhaps it helps push elements apart in the embedding space?

I reimplemented the core loss somewhat to simplify it, which results in the same loss values (as far as I can tell), but the resulting model does train slightly differently:

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
        if len(embeddings) < 2:
            raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

        queries = embeddings[0]
        docs = embeddings[1:]
        batch_size = queries.size(0)
        offset = 0

        if self.gather_across_devices:
            # Gather the anchors and candidates across all devices, with gradients. We compute only this device's anchors
            # with all candidates from all devices, and only this device's candidates with all anchors from all devices.
            # We do this in such a way that the backward pass on the embeddings can flow back to the original devices.
            queries = all_gather_with_grad(queries)  # (batch_size * world_size, embedding_dim)
            docs = [all_gather_with_grad(doc) for doc in docs]
            # (1 + num_negatives) tensors of shape (batch_size * world_size, embedding_dim)

            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
                offset = rank * batch_size

        docs = torch.cat(docs, dim=0)
        # (batch_size * world_size * (1 + num_negatives), embedding_dim)

        local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
        local_queries = queries[local_indices]
        local_docs = docs[local_indices]

        sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))
        sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (batch_size, batch_size * world_size)
        if self.gather_across_devices:
            sim_dq = self.similarity_fct(local_docs, queries) * self.scale  # (batch_size, batch_size * world_size)
        else:
            sim_dq = sim_qd.T
        sim_dd = self.similarity_fct(local_docs, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))

        # Remove self-similarity entries q_i -> q_i and d_i -> d_i for local pairs
        row_indices = torch.arange(batch_size, device=queries.device)
        sim_qq[row_indices, local_indices] = -float("inf")
        sim_dd[row_indices, local_indices] = -float("inf")

        scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)  # (batch_size, 2 * batch_size * world_size * (2 + num_negatives))
        log_z = torch.logsumexp(scores, dim=1)

        positive_scores = sim_qd[row_indices, local_indices]
        loss = -(positive_scores - log_z).mean()
        return loss

See:

I haven't had time to figure out why yet and I'm wrapping up for the day, but I figured I'd share my version. I also ran a test without the sim_qq_rows and sim_dd_cols from your implementation, and the model was quite a bit worse: mostly ahead of pure InfoNCE/MNRL during training, but behind at the end (0.8513 NDCG@10), and quite a bit behind your MNBRL implementation.

I'll have to follow up more later, but this is looking very cool!

  • Tom Aarsen

@hotchpotch
Contributor Author

Hi Tom — thanks so much for taking the time to review the PR and for sharing your refactor and results!

Your implementation is indeed simpler and has less computation overall — it looks great. The one issue I found was around hard negatives: in that case the number of docs exceeds the number of queries, so sim_qd becomes shape (B, B*(1+neg)). Using sim_dq = sim_qd.T then produces (B*(1+neg), B), which no longer aligns with the other score blocks (and breaks concatenation / changes semantics). I fixed this by computing sim_dq as similarity(queries, local_docs).T (and similarly sim_dd as similarity(docs, local_docs).T), which keeps the shapes consistent and matches the cached implementation.
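The shape issue can be demonstrated in a few lines (names are illustrative):

```python
import torch

# With hard negatives, the document pool is larger than the query batch.
b, dim, num_neg = 4, 8, 3
queries = torch.randn(b, dim)
docs = torch.randn(b * (1 + num_neg), dim)  # positives + hard negatives
local_docs = docs[:b]                        # the b positive documents

sim_qd = queries @ docs.T                    # (b, b * (1 + num_neg))
# The naive transpose no longer matches the other (b, ...) score blocks:
assert sim_qd.T.shape == (b * (1 + num_neg), b)
# Computing the d->q block from the positives keeps b rows, as needed:
sim_dq = (queries @ local_docs.T).T          # (b, b)
assert sim_dq.shape == (b, b)
```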

With that fix applied, I re-ran training on the latest code and the results are basically unchanged. My training uses a large batch size (bs=8192), so it’s possible the batch size is a factor in how much the refactor impacts results.

Thanks again — I really appreciate the feedback and the careful experiments!

@NohTow
Contributor

NohTow commented Jan 21, 2026

Heya,
As @tomaarsen said, we went down a rabbit hole and ran a bunch of experiments, so I figured I should share some results here.
We essentially tested a bunch of setups: dividing the partition function, including q_i in the denominator zero, one, or two times, etc.
I think the main conclusions are:
This full loss seems better than traditional contrastive loss when you do not have negative documents. When you do, it seems to be pretty much the same as the original contrastive loss.
Also, an important part is that, in the d-d computation part, the hard negatives must be excluded from the computation.
Using one or two q_i in the denominator seems to be roughly the same, as long as you include at least one (despite the paper from LeCun, ahah).
So for now I think it's mostly useful for unsupervised pre-training, but maybe we need even more tests.

I did not improve the results much on our supervised setup (with hard negatives), which is SOTA on BEIR, FWIW, so the results can somewhat be trusted.
Tom had a bit more success on his own data, which does not include hard negatives (and is not as strong a baseline). When he tried adding the negatives, things started to look roughly the same.
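The exclusion NohTow describes amounts to masking, in the d→d block, every document that belongs to the same query (shapes and names below are illustrative):

```python
import torch

# Local positives vs. the full document pool, laid out block-wise as
# [positives | negatives_1 | negatives_2], so column j belongs to query j % b.
b, num_neg = 4, 2
docs_per_query = 1 + num_neg
sim_dd = torch.randn(b, b * docs_per_query)

# Mask every document (positive and hard negatives) of row i's own query.
cols = torch.arange(b * docs_per_query)
mask = (cols % b).unsqueeze(0) == torch.arange(b).unsqueeze(1)
sim_dd = sim_dd.masked_fill(mask, float("-inf"))
assert int(torch.isinf(sim_dd).sum()) == b * docs_per_query
```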

@hotchpotch
Contributor Author

Hello, @NohTow. Thank you for the sharp investigation and thoughtful analysis.

I also ran several training experiments with this CachedMultipleNegativesBidirectionalRankingLoss (CMNBRL) implementation. With triplets (roughly negs=1), the performance gains were modest, and when training CMNBRL on datasets with hard negatives (negs=7), I observed cases where performance dropped significantly.

Also, an important part is that, in the d-d computation part, the hard negatives must be excluded from the computation.

I had been wondering why the results were not improving with triplets and hard negatives, and your point hits the nail on the head. Thank you for calling this out.

Based on that, I plan to remove hard negatives from the d-d computation in this implementation and evaluate the impact.

@tomaarsen
Member

tomaarsen commented Jan 26, 2026

@hotchpotch I think I might prefer an implementation like this:

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
        if len(embeddings) < 2:
            raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

        queries = embeddings[0]
        docs = embeddings[1:]
        batch_size = queries.size(0)
        offset = 0

        if self.gather_across_devices:
            # Gather the anchors and candidates across all devices, with gradients. We compute only this device's anchors
            # with all candidates from all devices, and only this device's candidates with all anchors from all devices.
            # We do this in such a way that the backward pass on the embeddings can flow back to the original devices.
            queries = all_gather_with_grad(queries)
            docs = [all_gather_with_grad(doc) for doc in docs]
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
                offset = rank * batch_size

+       world_batch_size = queries.size(0)
        docs = torch.cat(docs, dim=0)
        # (batch_size * world_size * (1 + num_negatives), embedding_dim)
        local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
        local_queries = queries[local_indices]
        local_docs = docs[local_indices]

        sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (bs, bs * ws * (1 + nn))
        sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (bs, bs * ws)
        sim_dq = (self.similarity_fct(queries, local_docs) * self.scale).T  # (bs, bs * ws)
        sim_dd = (self.similarity_fct(docs, local_docs) * self.scale).T  # (bs, bs * ws * (1 + nn))

~       # Remove self-similarity entries q_i -> q_i
        row_indices = torch.arange(batch_size, device=queries.device)
        sim_qq[row_indices, local_indices] = -torch.inf

+       # Remove d_i_a -> d_i_b for all documents belonging to the same query
+       same_query_doc_mask = torch.eye(world_batch_size, device=queries.device)[local_indices]
+       same_query_doc_mask = same_query_doc_mask.repeat(1, len(embeddings) - 1).bool()
+       sim_dd.masked_fill_(same_query_doc_mask, float('-inf'))

        scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)
        log_z = torch.logsumexp(scores, dim=1)

        positive_scores = sim_qd[row_indices, local_indices]
        loss = -(positive_scores - log_z).mean()
        return loss

This matches the original, intended implementation as far as I can tell, and should prevent d -> d for the same query. This still allows d -> d across queries. I have a local commit with these changes, if you'd like me to push it.
In tests by @NohTow and myself, this should roughly match standard InfoNCE when there's multiple hard negatives.

I sadly haven't been able to test this variant myself with multiple GPUs, though.

  • Tom Aarsen

@hotchpotch
Contributor Author

Hello.

I have been experimenting with the implementation that removes hard negatives from the doc-doc term in commit cd7a9c7.

  • Training with pair data:
    • CMNBRL produced better results than CMNRL.
  • Training with triplet (negs=1) and hard negatives (negs=7):
    • Which loss performs better varies by dataset.
    • In commits before cd7a9c7, CMNBRL tended to be worse than CMNRL on hard negatives, so this is an improvement.

Overall, my impression is that CMNBRL often works better on pair datasets. With hard negatives, which loss is better seems to depend on the situation.

I have also incorporated the improved hard-negative handling proposed by @tomaarsen in commit 8c5680d, and I plan to run training with that next.

I also have access to a machine with two GPUs, so I will test the multi-GPU case later as well.

@tomaarsen
Member

Hello!

I've taken your Bidirectional implementation, added a bit of @NohTow 's work re. more generalization, and mixed it together to create more generalized (C)MNRL losses. This has a handful of goals:

  1. Make it clearer what the differences between the various InfoNCE variants are
  2. By implementing the features directly in MNRL, the non-default variants should become more accessible
  3. Avoiding loss growth from variants. There's a ton of different possible options, and a separate loss for each will be infeasible, especially considering a lot of the names become quite vague

I added some docs for the somewhat common variants:

  1. Standard InfoNCE (default)
  2. Symmetric (e.g. like Jina)
  3. GTE-style

Let me know what you think! If you're on board, we can move away from the new MNBRL and add it straight to MNRL instead. Feel free to make these changes if you wish. I think we can then also remove CachedMultipleNegativesSymmetricRankingLoss from MatryoshkaLoss as it should be a subclass of CachedMultipleNegativesRankingLoss.

I'll have to run some tests to make sure that MNRL and MNSRL still perform the same, but I think this is a cool direction.

P.s. I'm also open to suggestions re. docstrings, explanations, or parameter names.

  • Tom Aarsen

@hotchpotch
Contributor Author

Hello!

Thanks for the refactor. I am fully on board with moving away from MNBRL and folding the GTE-style variant directly into (C)MNRL.

I am also fine with removing CachedMultipleNegativesSymmetricRankingLoss from MatryoshkaLoss if it becomes a subclass of CachedMultipleNegativesRankingLoss.

The parameterization is clear, and it seems like future InfoNCE variants can be added without duplicating implementations, which is great.

I plan to validate that our existing training code still works with the refactored implementation from commit 7b37bbb.

Also, multi-GPU training with torchrun --nproc_per_node=2 completed successfully on commit 8c5680d.

If there is anything you would like me to help with, investigate, or implement, please do not hesitate to ask. And if there are any areas where you would prefer me to make the implementation changes myself, please let me know.

@tomaarsen
Member

Perfect! If you'd like, feel free to try and pull the PR over the finish line by removing the MNBRL variant again. Otherwise I'll take care of it in a few days prior to the v5.3 release.

Much appreciated on these big PRs, this is super valuable stuff.

  • Tom Aarsen

Contributor

Copilot AI left a comment


Pull request overview

This PR extends the MultipleNegativesRankingLoss and CachedMultipleNegativesRankingLoss classes to support the GTE paper's "Improved Contrastive Loss" by adding two new parameters: directions (to control which similarity interactions are computed) and partition_mode (to control normalization strategy). The implementation also deprecates MultipleNegativesSymmetricRankingLoss and CachedMultipleNegativesSymmetricRankingLoss by making them thin wrappers that call the base classes with appropriate parameters.

Changes:

  • Extended MultipleNegativesRankingLoss and CachedMultipleNegativesRankingLoss with directions and partition_mode parameters supporting query-to-doc, doc-to-query, query-to-query, and doc-to-doc similarity terms
  • Deprecated MultipleNegativesSymmetricRankingLoss and CachedMultipleNegativesSymmetricRankingLoss with clear migration guidance
  • Updated SparseMultipleNegativesRankingLoss to match the new API
  • Added dynamic citation property that returns different papers based on the loss configuration
  • Updated documentation table to remove deprecated losses and clarify input formats

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
sentence_transformers/losses/MultipleNegativesRankingLoss.py Extended core loss with directions and partition_mode parameters, added validation, updated documentation with literature examples, added temperature property and dynamic citations
sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py Extended cached version with same parameters and logic, maintains gradient caching mechanism
sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py Deprecated class, now inherits from base class with specific parameter values
sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py Deprecated cached variant, now inherits from cached base class
sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py Updated to match new API for consistency across dense and sparse implementations
docs/sentence_transformer/loss_overview.md Updated loss table to remove deprecated classes and add anchor-anchor pairs row


@tomaarsen
Member

I think we're looking good! Feel free to rerun if you'd like, but the implementation should be identical (apart from perhaps some differences in randomness). I'd like to merge and release this in the coming days!

  • Tom Aarsen

@tomaarsen tomaarsen changed the title [feat] implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) [feat] Refactor MultipleNegativesRankingLoss to support improved contrastive loss from GTE paper Feb 5, 2026
@tomaarsen tomaarsen merged commit bf47a97 into huggingface:main Feb 5, 2026
17 checks passed
@hotchpotch
Contributor Author

Great job reviewing, refactoring, fixing, and merging my PR. Thanks!

@tomaarsen
Member

Gladly! Thanks for this nice work 👏

  • Tom Aarsen
