import timm
import torch
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.loss import JsdCrossEntropy
from timm.data.mixup import mixup_target
import matplotlib.pyplot as plt

让我们创建一个模型output和我们的labels的例子。注意我们有3个输出预测,但只有1个标签。

output = F.one_hot(torch.tensor([0,9,0])).float()
labels=torch.tensor([0])

如果我们我们将标签smoothingalpha设置为0,那么如果我们只看我们的输出和标签的第一个元素,就会得到常规的cross_entropy loss

jsd = JsdCrossEntropy(smoothing=0,alpha=0)
jsd(output,labels)
tensor(1.4612)
base_loss = F.cross_entropy(output[0,None],labels[0,None])
base_loss
tensor(1.4612)
jsd = JsdCrossEntropy(num_splits=1,smoothing=0,alpha=0)

我们也可以改变分割的数量,从而改变每个组的大小。在Augmix中,这相当于转换混合的数量。

jsd = JsdCrossEntropy(num_splits=2,smoothing=0,alpha=0)
output = F.one_hot(torch.tensor([0,9,1,0])).float()
labels=torch.tensor([0,9])
jsd(output,labels),F.cross_entropy(output[[0,1]],labels)
(tensor(1.4612), tensor(1.4612))

默认情况下,我们有1个标签对应3个预测,这是一个两部分的损失,同时衡量交叉熵和Jensen-Shannon散度。Jensen-Shannon散度不需要标签,而是衡量3个预测之间的差异显著程度。

jsd = JsdCrossEntropy(smoothing=0)
output = F.one_hot(torch.tensor([0,0,0]),num_classes=10).float()
deltas = torch.cat((torch.zeros([2,10]),torch.tensor([[-1,1,0,0,0,0,0,0,0,0]])))*0.1
deltas[2]
tensor([-0.1000,  0.1000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])
deltas=(torch.arange(-10,11))[...,None,None]*deltas
losses = [jsd((output+delta),labels)-base_loss for delta in deltas]

下面的图表显示了模型输出(预测)在一个组中的变化如何影响Jensen-Shannon散度。

plt.plot([ .1*i-1 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('JS Divergence')
plt.xlabel('Change in output')
plt.show()