Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DebiasedMultipleNegativesRankingLoss to the losses #3148

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor, nn

from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class DebiasedMultipleNegativesRankingLoss(nn.Module):
def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim, tau_plus: float = 0.01) -> None:
"""
This loss is a debiased version of the `MultipleNegativesRankingLoss` loss that addresses the inherent sampling bias in the negative examples.

In standard contrastive loss, negative samples are drawn randomly from the dataset, leading to potential false negatives.

This debiased loss adjusts for this sampling bias by reweighting the contributions of positive and negative terms in the denominator.

For each ``a_i``, it uses all other ``p_j`` as negative samples, i.e., for ``a_i``, we have 1 positive example
(``p_i``) and ``n-1`` negative examples (``p_j``). Unlike the standard implementation, this loss applies a correction
term to account for the sampling bias introduced by in-batch negatives. Specifically, it adjusts the influence of
negatives based on a prior probability ``tau_plus``.

It then minimizes the negative log-likelihood for softmax-normalized scores while reweighting the contributions of
positive and negative terms in the denominator.

This loss function works great to train embeddings for retrieval setups where you have positive pairs
(e.g., (query, relevant_doc)) as it will sample in each batch ``n-1`` negative docs randomly and incorporate a bias
correction for improved robustness.

The performance usually increases with increasing batch sizes.

You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
``(a_1, p_1, n_1), (a_2, p_2, n_2)``. Then, ``n_1`` is a hard negative for ``(a_1, p_1)``. The loss will use for
the pair ``(a_i, p_i)`` all ``p_j`` for ``j != i``, all ``n_j`` as negatives, and apply the bias correction.

Args:
model: SentenceTransformer model
scale: Output of similarity function is multiplied by scale
value
similarity_fct: similarity function between sentence
embeddings. By default, cos_sim. Can also be set to dot
product (and then set scale to 1)
tau_plus: Prior probability.

References:
- Chuang et al. (2020). Debiased Contrastive Learning. NeurIPS 2020. https://arxiv.org/pdf/2007.00224.pdf

Requirements:
1. The input batch should consist of (anchor, positive) pairs or (anchor, positive, negative) triplets.

Inputs:
+-------------------------------------------------+--------+
| Texts | Labels |
+=================================================+========+
| (anchor, positive) pairs | none |
+-------------------------------------------------+--------+
| (anchor, positive, negative) triplets | none |
+-------------------------------------------------+--------+
| (anchor, positive, negative_1, ..., negative_n) | none |
+-------------------------------------------------+--------+

Recommendations:
- Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
ensure that no in-batch negatives are duplicates of the anchor or positive samples.

Relations:
- Extends :class:`MultipleNegativesRankingLoss` by incorporating a bias correction term.

Example:
::

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"anchor": ["It's nice weather outside today.", "He drove to work."],
"positive": ["It's so sunny.", "He took the car to the office."],
})
loss = losses.DebiasedMultipleNegativesRankingLoss(model, tau_plus=0.02)

trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
"""
super().__init__()
self.model = model
self.scale = scale
self.similarity_fct = similarity_fct
self.tau_plus = tau_plus
self.cross_entropy_loss = nn.CrossEntropyLoss()

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
# Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
anchors = embeddings[0] # (batch_size, embedding_dim)
candidates = torch.cat(embeddings[1:]) # (batch_size * (1 + num_negatives), embedding_dim)

# For every anchor, we compute the similarity to all other candidates (positives and negatives),
# also from other anchors. This gives us a lot of in-batch negatives.
scores: Tensor = self.similarity_fct(anchors, candidates) * self.scale
# (batch_size, batch_size * (1 + num_negatives))

# Compute the mask to remove the similarity of the anchor to the positive candidate.
batch_size = scores.size(0)
mask = torch.ones_like(scores, dtype=torch.bool) # (batch_size, batch_size * (1 + num_negatives))
positive_indices = torch.arange(0, batch_size, device=scores.device)
mask[positive_indices, positive_indices] = False

# Get the similarity of the anchor to the negative candidates.
neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1) # (batch_size,)
Copy link
Author

@ilanaliouchouche ilanaliouchouche Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, there is an error in my code. The masked_fill method replaces all the True locations in the mask, so ultimately we have pos_exp = neg_exp when computing the positives. This results in $g$ being negative (because tau_plus << 1 - tau_plus). For each $a_i$, $g$ is replaced by $\exp(-\text{scale})$ (hence no learning).

Changes made:

  • Line 113: replaced the ones_like function with zeros_like.
  • Line 115: replaced mask[positive_indices, positive_indices] = False with mask[positive_indices, positive_indices] = True.

Tested, and gradient descent is now active; the loss is no longer stuck at 0. I still need to run some additional tests (the loss behaves differently depending on whether cos_sim or dot_score is used).

# Get the similarity of the anchor to the positive candidate.
pos_exp = torch.exp(torch.gather(scores, -1, positive_indices.unsqueeze(1)).squeeze())
# (batch_size,)

# Compute the g estimator with the exponential of the similarities.
N_neg = scores.size(1) - 1 # Number of negatives
g = torch.clamp((1 / (1 - self.tau_plus)) * ((neg_exp / N_neg) - (self.tau_plus * pos_exp)),
min=torch.exp(-torch.tensor(self.scale)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
min=torch.exp(-torch.tensor(self.scale)))
min=self.scale)

The torch.tensor call results in device mismatches when training on GPUs. Perhaps this can be simplified to just the above.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, torch.exp only accepts torch.tensor as input and indeed the device needs to be specified. I think it's better to use np.exp(-scale) s.t we avoid creating a tensor and thus avoid device issues.
I’m keeping the exponential for consistency with the paper.
Capture d’écran 2025-01-06 à 21 04 24

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better, I totally forgot about keeping the exp - that's quite important, my bad.

# (batch_size,)

# Compute the final debiased loss.
loss = - torch.log(pos_exp / (pos_exp + N_neg * g)).mean()

return loss

def get_config_dict(self) -> dict[str, Any]:
return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

@property
def citation(self) -> str:
return """
@inproceedings{chuang2020debiased,
title={Debiased Contrastive Learning},
author={Ching-Yao Chuang and Joshua Robinson and Lin Yen-Chen and Antonio Torralba and Stefanie Jegelka},
booktitle={Advances in Neural Information Processing Systems},
year={2020},
url={https://arxiv.org/pdf/2007.00224}
}
"""
Loading