单标签非对称损失

import timm
import torch
import torch.nn.functional as F
from timm.loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

让我们创建一个模型output和我们的labels的示例。

output = F.one_hot(torch.tensor([0,9,0])).float()
labels=torch.tensor([0,0,0])
labels, output
(tensor([0, 0, 0]),
 tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

如果我们将所有参数设置为0,损失就变成了F.cross_entropy损失。

asl = AsymmetricLossSingleLabel(gamma_pos=0,gamma_neg=0,eps=0.0)
asl(output,labels)
tensor(1.7945)
F.cross_entropy(output,labels)
tensor(1.7945)

现在让我们看看非对称部分。ASL在处理正例和负例的方式上是非对称的。正例是图像中存在的标签,而负例是图像中不存在的标签。其想法是,一幅图像包含大量容易识别的负例、少量难以识别的负例以及极少的正例。消除容易识别的负例的影响,应该有助于强调正例的梯度。

Image.open(Path()/'images/cat.jpg')

请注意,这张图片包含一只猫,这将是一个正标签。这张图片不包含狗、大象、熊、长颈鹿、斑马、香蕉以及coco数据集中发现的许多其他标签,这些都将是负例。很容易看出这张图片中没有长颈鹿。

output = (2*F.one_hot(torch.tensor([0,9,0]))-1).float()
labels=torch.tensor([0,9,0])
losses=[AsymmetricLossSingleLabel(gamma_neg=i*0.04+1,eps=0.1,reduction='mean')(output,labels) for i in range(int(80))]
plt.plot([ i*0.04+1 for i,l in enumerate(losses)],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Change in gamma_neg')
plt.show()

$$L_- = (p)^{\gamma-}\log(1-p) $$

随着gamma_neg的增加,小的负例的贡献迅速减小,因为$\gamma-$是一个指数,$p$应该是一个接近0的小数。

在下方我们将eps设置为0,这使得上面的图完全变平了,我们不再应用标签平滑,因此负例最终不会对损失产生贡献。

losses=[AsymmetricLossSingleLabel(gamma_neg=0+i*0.02,eps=0.0,reduction='mean')(output,labels) for i in range(100)]
plt.plot([ i*0.04 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Change in gamma_neg')
plt.show()

多标签非对称损失 (AsymmetricLossMultiLabel)

AsymmetricLossMultiLabel允许处理多标签问题。

labels=F.one_hot(torch.LongTensor([0,0,0]),num_classes=10)+F.one_hot(torch.LongTensor([1,9,1]),num_classes=10)
labels
tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
AsymmetricLossMultiLabel()(output,labels)
tensor(3.1466)

对于AsymmetricLossMultiLabel,还存在另一个参数叫做clip。它将负例的较小输入钳制为0。这被称为非对称概率偏移 (Asymmetric Probability Shifting)。

losses=[AsymmetricLossMultiLabel(clip=i/100)(output,labels) for i in range(100)]
plt.plot([ i/100 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Clip')
plt.show()