Multi-label contrastive loss implementation #13

Open
Beichenqwq opened this issue Jan 29, 2024 · 1 comment

Comments

@Beichenqwq

Hello, thank you for your great work. Could you please tell me where the multi-label contrastive loss from your paper is implemented? I can't find it in the repository. Thank you very much.

@HAL-42
Owner

HAL-42 commented Jan 29, 2024

It's at src/libs/loss/multi_cls/cl_loss.py:

import torch
from torch import nn


class MultiLabelCLLoss(nn.Module):

    def __init__(self,
                 gamma: float | None = None,
                 reduce: str = 'pos_mean'):
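        """Compute a contrastive loss for multi-label classification.

        Positive classes are treated as positive samples and negative classes as negative samples.

        Args:
            gamma: Scaling factor applied to the similarities.
            reduce: Either 'pos_mean', which averages over every positive-vs-negatives term, or
                'sample_mean', which first averages the terms within each sample and then averages
                across all samples.
        """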
        super().__init__()
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, S: torch.Tensor, cls_lb: torch.Tensor) -> torch.Tensor:
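        """Compute a contrastive loss for multi-label classification.

        Positive classes are treated as positive samples and negative classes as negative samples.

        Args:
            S: Similarity map of shape (N, G).
            cls_lb: Class labels of shape (N, G).

        Returns:
            The multi-label contrastive loss.
        """
        # * Precompute the indices and counts needed below.
        # ** Count the valid anchors. A valid anchor has at least one positive and one negative sample.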
        valid_anchor_mask = torch.any(cls_lb == 0, dim=1) & torch.any(cls_lb == 1, dim=1)
        valid_anchor_num = valid_anchor_mask.sum(dtype=torch.int32)
        if valid_anchor_num == 0:
            return S.mean() * 0
        # ** Compute the positive and negative sample masks.
        neg_mask = (cls_lb == 0) * valid_anchor_mask[:, None]
        pos_mask = (cls_lb == 1) * valid_anchor_mask[:, None]
        # ** Compute the anchor index of each positive sample.
        pi_a = torch.nonzero(pos_mask, as_tuple=True)[0]
        # ** For each positive sample, count the positive samples that share its anchor.
        pi_I = pos_mask.sum(dim=1, dtype=torch.int32)[pi_a]

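        # * Compute the CL loss.
        # ** Compute the scaled similarities.
        S = self.gamma * S if self.gamma is not None else S
        # ** Exponentiate all similarities.
        e_S = torch.exp(S)
        # * Compute Σ_j e^(S_nj), the sum over the negatives of each anchor.
        Sigma_e_Sn_j = (e_S * neg_mask).sum(dim=1)
        # * Gather S_pi and e^(S_pi).
        Spi = S[pos_mask]
        e_Spi = e_S[pos_mask]
        # * Gather the Σ_j e^(S_nj) that corresponds to each positive sample.
        pi_Sigma_e_Sn_j = Sigma_e_Sn_j[pi_a]

        # * Compute the contrastive loss for each positive sample.
        pi_cl_loss = -(Spi - torch.log(e_Spi + pi_Sigma_e_Sn_j))

        # * Average the loss according to the reduction mode.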
        match self.reduce:
            case 'pos_mean':
                return pi_cl_loss.mean()
            case 'sample_mean':
                pi_cl_loss = pi_cl_loss / pi_I
                return pi_cl_loss.sum() / valid_anchor_num
            case _:
                raise ValueError(f"Unsupported {self.reduce=}.")
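
To make the two reduction modes concrete, here is a minimal pure-Python sketch of the same per-positive term, `-(S_p - log(e^{S_p} + Σ_j e^{S_nj}))`, written with plain loops. It is not part of the repository: the function name `multi_label_cl_loss` and the nested-list inputs are assumptions made for illustration, and it is only meant as a readable cross-check of what the tensor code above appears to compute.

```python
import math


def multi_label_cl_loss(S, cls_lb, gamma=None, reduce="pos_mean"):
    """Loop-based reference for the loss above (hypothetical helper).

    S and cls_lb are nested lists of shape (N, G); labels are 0 or 1.
    """
    per_anchor = []  # one list of per-positive losses for each valid anchor
    for sims, labels in zip(S, cls_lb):
        if gamma is not None:
            sims = [gamma * s for s in sims]
        pos = [s for s, y in zip(sims, labels) if y == 1]
        neg = [s for s, y in zip(sims, labels) if y == 0]
        if not pos or not neg:
            continue  # not a valid anchor: it lacks a positive or a negative
        neg_sum = sum(math.exp(s) for s in neg)
        per_anchor.append(
            [-(s - math.log(math.exp(s) + neg_sum)) for s in pos]
        )

    if not per_anchor:
        return 0.0
    if reduce == "pos_mean":
        # Every positive counts equally, regardless of its anchor.
        flat = [term for terms in per_anchor for term in terms]
        return sum(flat) / len(flat)
    if reduce == "sample_mean":
        # Average within each anchor first, then across anchors.
        return sum(sum(t) / len(t) for t in per_anchor) / len(per_anchor)
    raise ValueError(f"Unsupported reduce={reduce!r}.")


# Anchor 0 has two positives, anchor 1 has one, anchor 2 is skipped (no negatives).
S = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
cls_lb = [[1, 1, 0], [1, 0, 0], [1, 1, 1]]
print(multi_label_cl_loss(S, cls_lb, reduce="pos_mean"))
print(multi_label_cl_loss(S, cls_lb, reduce="sample_mean"))
```

With all similarities at zero, each positive of anchor 0 contributes log 2 and the single positive of anchor 1 contributes log 3, so `pos_mean` gives (2 log 2 + log 3) / 3 while `sample_mean` gives (log 2 + log 3) / 2. The difference shows that `pos_mean` weights anchors by their number of positives, whereas `sample_mean` weights every valid anchor equally.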
