import timm
import torch
import torch.nn.functional as F
from timm.loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
让我们创建一个模型output
和我们的labels
的示例。
output = F.one_hot(torch.tensor([0,9,0])).float()
labels=torch.tensor([0,0,0])
labels, output
如果我们将所有参数设置为0,损失就变成了F.cross_entropy
损失。
asl = AsymmetricLossSingleLabel(gamma_pos=0,gamma_neg=0,eps=0.0)
asl(output,labels)
F.cross_entropy(output,labels)
现在让我们看看非对称部分。ASL在处理正例和负例的方式上是非对称的。正例是图像中存在的标签,而负例是图像中不存在的标签。其想法是,一幅图像包含大量容易识别的负例、少量难以识别的负例以及极少的正例。消除容易识别的负例的影响,应该有助于强调正例的梯度。
Image.open(Path()/'images/cat.jpg')
请注意,这张图片包含一只猫,这将是一个正标签。这张图片不包含狗、大象、熊、长颈鹿、斑马、香蕉以及coco数据集中发现的许多其他标签,这些都将是负例。很容易看出这张图片中没有长颈鹿。
output = (2*F.one_hot(torch.tensor([0,9,0]))-1).float()
labels=torch.tensor([0,9,0])
losses=[AsymmetricLossSingleLabel(gamma_neg=i*0.04+1,eps=0.1,reduction='mean')(output,labels) for i in range(int(80))]
plt.plot([ i*0.04+1 for i,l in enumerate(losses)],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Change in gamma_neg')
plt.show()
$$L_- = (p)^{\gamma-}\log(1-p) $$
随着gamma_neg的增加,小的负例的贡献迅速减小,因为$\gamma-$是一个指数,$p$应该是一个接近0的小数。
在下方我们将eps
设置为0,这使得上面的图完全变平了,我们不再应用标签平滑,因此负例最终不会对损失产生贡献。
losses=[AsymmetricLossSingleLabel(gamma_neg=0+i*0.02,eps=0.0,reduction='mean')(output,labels) for i in range(100)]
plt.plot([ i*0.04 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Change in gamma_neg')
plt.show()
AsymmetricLossMultiLabel
允许处理多标签问题。
labels=F.one_hot(torch.LongTensor([0,0,0]),num_classes=10)+F.one_hot(torch.LongTensor([1,9,1]),num_classes=10)
labels
AsymmetricLossMultiLabel()(output,labels)
对于AsymmetricLossMultiLabel
,还存在另一个参数叫做clip
。它将负例的较小输入钳制为0。这被称为非对称概率偏移 (Asymmetric Probability Shifting)。
losses=[AsymmetricLossMultiLabel(clip=i/100)(output,labels) for i in range(100)]
plt.plot([ i/100 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('Loss')
plt.xlabel('Clip')
plt.show()