
@hotchpotch hotchpotch commented Jan 5, 2026

Hello!

Summary

This PR implements the "Improved Contrastive Loss" proposed in the GTE paper as MultipleNegativesBidirectionalRankingLoss (both standard and GradCache versions), along with accompanying documentation, sampler recommendations, and tests.

Motivation

The GTE paper introduces an "Improved Contrastive Loss" that unifies q→d, q→q, d→q, and d→d similarities into a single normalization term, strengthening in-batch signals for contrastive learning. This PR brings that approach to sentence-transformers.
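
For reference, the objective has roughly this form (my notation, not copied from the PR, with s(·,·) the similarity function and τ the temperature; the self-similarities q_i→q_i and d_i→d_i are excluded from the sums):

    \mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \log \frac{\exp\big(s(q_i, d_i)/\tau\big)}{Z_i},
    \qquad
    Z_i = \sum_{j} e^{s(q_i, d_j)/\tau} + \sum_{j \neq i} e^{s(q_i, q_j)/\tau} + \sum_{j} e^{s(d_i, q_j)/\tau} + \sum_{j \neq i} e^{s(d_i, d_j)/\tau}

In other words, each positive pair competes not only against the usual q→d in-batch negatives (as in MultipleNegativesRankingLoss), but also against q→q, d→q, and d→d pairs from the same batch.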

Benchmark Results (NanoBEIR)

I evaluated models fine-tuned with the new CachedMultipleNegativesBidirectionalRankingLoss (CMNBRL) against the existing CachedMultipleNegativesRankingLoss (CMNRL):

Results show comparable performance on most tasks, with notable improvements on ArguAna, SciFact, and MSMARCO:

NanoBEIR (nDCG@10)    CMNBRL    CMNRL
Avg                   0.5571    0.5484
NanoClimateFEVER      0.3204    0.3249
NanoDBPedia           0.5013    0.5073
NanoFEVER             0.7971    0.8029
NanoFiQA2018          0.4595    0.4646
NanoHotpotQA          0.6496    0.6343
NanoMSMARCO           0.5915    0.5555
NanoNFCorpus          0.2981    0.3071
NanoNQ                0.6279    0.6336
NanoQuoraRetrieval    0.9387    0.9391
NanoSCIDOCS           0.3413    0.3220
NanoArguAna           0.5898    0.5328
NanoSciFact           0.6514    0.6297
NanoTouche2020        0.4762    0.4754

Training time was approximately 90 minutes on an RTX 5090 for both losses, showing no significant difference in computational cost.

Changes

New Loss Functions

  • MultipleNegativesBidirectionalRankingLoss
  • CachedMultipleNegativesBidirectionalRankingLoss
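
A minimal usage sketch (assuming the new losses follow the same constructor pattern as MultipleNegativesRankingLoss, i.e. the model plus optional scale / similarity_fct arguments; the base model below is just an example):

    from sentence_transformers import SentenceTransformer, losses

    model = SentenceTransformer("microsoft/mpnet-base")

    # In-batch bidirectional loss: the first column is treated as queries,
    # the remaining columns as documents (positive + optional hard negatives).
    loss = losses.MultipleNegativesBidirectionalRankingLoss(model)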

Hard Negatives Handling

  • For inputs (anchor, positive, negative_1, ..., negative_n), negatives are included as additional documents in q→d and d→d computations (not treated as queries).
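
For example, with a triplet-style dataset (column names are illustrative; sentence-transformers pairs columns by position, not by name):

    from datasets import Dataset

    train_dataset = Dataset.from_dict({
        "anchor": ["What is the capital of France?"],
        "positive": ["Paris is the capital of France."],
        "negative_1": ["Berlin is the capital of Germany."],
    })
    # "anchor" embeddings act as queries (q->d and q->q terms);
    # "positive" and "negative_1" embeddings act as documents (q->d and d->d terms).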

GradCache Implementation

  • Adopts the same gradient caching mechanism as the existing CachedMultipleNegativesRankingLoss.
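
GradCache processes the batch in mini-batches and caches the representation gradients, so the effective batch size is decoupled from GPU memory. A sketch, assuming the mini_batch_size parameter carries over from CachedMultipleNegativesRankingLoss (model as in the sketch above):

    # Large effective batch size, memory bounded by the mini-batch size
    cached_loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(
        model,
        mini_batch_size=64,  # assumed to mirror CachedMultipleNegativesRankingLoss
    )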

Related Updates

  • Added exports to losses/__init__.py
  • Documented compatibility with MatryoshkaLoss, AdaptiveLayerLoss, and Matryoshka2dLoss
  • Added bidirectional losses to recommended losses for BatchSamplers.NO_DUPLICATES
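
For reference, that sampler is selected via the training arguments (standard sentence-transformers v3+ API; values below are placeholders):

    from sentence_transformers import SentenceTransformerTrainingArguments
    from sentence_transformers.training_args import BatchSamplers

    args = SentenceTransformerTrainingArguments(
        output_dir="output",                        # placeholder
        per_device_train_batch_size=256,            # in-batch losses benefit from larger batches
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts acting as false negatives
    )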

Documentation

  • Added bidirectional loss to API reference (losses.md)
  • Updated loss overview table and descriptions (including hard negatives behavior)

Tests

  • Manual calculation verification tests
  • Gradient consistency tests between cached and non-cached versions
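
The manual-calculation check boils down to recomputing the loss with plain tensor ops; a standalone sketch of that reference computation (not the PR's actual test code, and using a positives-only batch for brevity):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, D, scale = 4, 8, 20.0
    q = F.normalize(torch.randn(B, D), dim=-1)  # query embeddings
    d = F.normalize(torch.randn(B, D), dim=-1)  # positive document embeddings

    eye = torch.eye(B, dtype=torch.bool)
    sim_qd = q @ d.T * scale
    sim_qq = (q @ q.T * scale).masked_fill(eye, float("-inf"))  # drop q_i -> q_i
    sim_dq = d @ q.T * scale
    sim_dd = (d @ d.T * scale).masked_fill(eye, float("-inf"))  # drop d_i -> d_i

    # Unified normalization over all four similarity blocks
    log_z = torch.logsumexp(torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1), dim=1)
    reference_loss = -(sim_qd.diagonal() - log_z).mean()
    print(reference_loss)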

References

  • GTE paper: Li et al., "Towards General Text Embeddings with Multi-stage Contrastive Learning" (arXiv:2308.03281)

Testing

python -m pytest tests/losses/test_multiple_negatives_bidirectional_ranking_loss.py
uv run --with ruff ruff check  # on modified .py files

Note: This implementation was developed with AI assistance (Codex), but I take full responsibility for the implementation and have verified it through actual training experiments.

@hotchpotch hotchpotch changed the title feat: implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) [feat] implement GTE Improved Contrastive Loss (MultipleNegativesBidirectionalRankingLoss) Jan 5, 2026
@hotchpotch hotchpotch marked this pull request as ready for review January 5, 2026 23:44
@tomaarsen (Member)

Hello!

Wow, thank you for this! I ran some tests myself on various training scripts, and it does perform rather well. More than expected, even, especially because the idea of using query -> query and document -> document similarities as additional random in-batch negatives is a bit surprising. Perhaps it helps push elements apart in the embedding space?

I reimplemented the core loss somewhat to simplify it, which results in the same loss values (as far as I can tell), but the resulting model does train slightly differently:

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
        if len(embeddings) < 2:
            raise ValueError(f"Expected at least 2 embeddings, got {len(embeddings)}")

        queries = embeddings[0]
        docs = embeddings[1:]
        batch_size = queries.size(0)
        offset = 0

        if self.gather_across_devices:
            # Gather the anchors and candidates across all devices, with gradients. We compute only this device's anchors
            # with all candidates from all devices, and only this device's candidates with all anchors from all devices.
            # We do this in such a way that the backward pass on the embeddings can flow back to the original devices.
            queries = all_gather_with_grad(queries)  # (batch_size * world_size, embedding_dim)
            docs = [all_gather_with_grad(doc) for doc in docs]
            # (1 + num_negatives) tensors of shape (batch_size * world_size, embedding_dim)

            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
                offset = rank * batch_size

        docs = torch.cat(docs, dim=0)
        # (batch_size * world_size * (1 + num_negatives), embedding_dim)

        local_indices = torch.arange(offset, offset + batch_size, device=queries.device)
        local_queries = queries[local_indices]
        local_docs = docs[local_indices]

        sim_qd = self.similarity_fct(local_queries, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))
        sim_qq = self.similarity_fct(local_queries, queries) * self.scale  # (batch_size, batch_size * world_size)
        if self.gather_across_devices:
            sim_dq = self.similarity_fct(local_docs, queries) * self.scale  # (batch_size, batch_size * world_size)
        else:
            sim_dq = sim_qd.T
        sim_dd = self.similarity_fct(local_docs, docs) * self.scale  # (batch_size, batch_size * world_size * (1 + num_negatives))

        # Remove self-similarity entries q_i -> q_i and d_i -> d_i for local pairs
        row_indices = torch.arange(batch_size, device=queries.device)
        sim_qq[row_indices, local_indices] = -float("inf")
        sim_dd[row_indices, local_indices] = -float("inf")

        scores = torch.cat([sim_qd, sim_qq, sim_dq, sim_dd], dim=1)  # (batch_size, 2 * batch_size * world_size * (2 + num_negatives))
        log_z = torch.logsumexp(scores, dim=1)

        positive_scores = sim_qd[row_indices, local_indices]
        loss = -(positive_scores - log_z).mean()
        return loss

I haven't had time to figure out why yet and I'm wrapping up for the day, but I figured I'd share my version. I also ran a test without the sim_qq_rows and sim_dd_cols from your implementation, and the model was quite a bit worse: mostly ahead of pure InfoNCE/MNRL during training, but behind at the end (0.8513 nDCG@10), and quite a bit behind your MNBRL implementation.

I'll have to follow up more later, but this is looking very cool!

  • Tom Aarsen

@hotchpotch (Author)

Hi Tom — thanks so much for taking the time to review the PR and for sharing your refactor and results!

Your implementation is indeed simpler and has less computation overall — it looks great. The one issue I found was around hard negatives: in that case the number of docs exceeds the number of queries, so sim_qd becomes shape (B, B*(1+neg)). Using sim_dq = sim_qd.T then produces (B*(1+neg), B), which no longer aligns with the other score blocks (and breaks concatenation / changes semantics). I fixed this by computing sim_dq as similarity(queries, local_docs).T (and similarly sim_dd as similarity(docs, local_docs).T), which keeps the shapes consistent and matches the cached implementation.
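
To make the shape issue concrete, a small single-device sketch with one hard negative per query (plain dot products stand in for the similarity function; only the shapes matter here):

    import torch

    B, D = 4, 16
    queries = torch.randn(B, D)      # local queries (world_size = 1)
    docs = torch.randn(2 * B, D)     # positives + one hard negative per query
    local_docs = docs[:B]

    sim_qd = queries @ docs.T        # (B, 2B)
    assert sim_qd.T.shape == (2 * B, B)  # transpose no longer has B rows -> cat along dim=1 fails
    sim_dq = (queries @ local_docs.T).T  # (B, B): aligns with the sim_qq block
    sim_dd = (docs @ local_docs.T).T     # (B, 2B): aligns with the sim_qd block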

With that fix applied, I re-ran training on the latest code and the results are basically unchanged. My training uses a large batch size (bs=8192), so it’s possible the batch size is a factor in how much the refactor impacts results.

Thanks again — I really appreciate the feedback and the careful experiments!
