import timm
import torch
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.loss import JsdCrossEntropy
from timm.data.mixup import mixup_target
import matplotlib.pyplot as plt
让我们创建一个模型output
和我们的labels
的例子。注意我们有3个输出预测,但只有1个标签。
output = F.one_hot(torch.tensor([0,9,0])).float()
labels=torch.tensor([0])
如果我们我们将标签smoothing
和alpha
设置为0,那么如果我们只看我们的输出和标签的第一个元素,就会得到常规的cross_entropy loss
。
jsd = JsdCrossEntropy(smoothing=0,alpha=0)
jsd(output,labels)
base_loss = F.cross_entropy(output[0,None],labels[0,None])
base_loss
jsd = JsdCrossEntropy(num_splits=1,smoothing=0,alpha=0)
我们也可以改变分割的数量,从而改变每个组的大小。在Augmix
中,这相当于转换混合的数量。
jsd = JsdCrossEntropy(num_splits=2,smoothing=0,alpha=0)
output = F.one_hot(torch.tensor([0,9,1,0])).float()
labels=torch.tensor([0,9])
jsd(output,labels),F.cross_entropy(output[[0,1]],labels)
默认情况下,我们有1个标签对应3个预测,这是一个两部分的损失,同时衡量交叉熵和Jensen-Shannon散度。Jensen-Shannon散度不需要标签,而是衡量3个预测之间的差异显著程度。
jsd = JsdCrossEntropy(smoothing=0)
output = F.one_hot(torch.tensor([0,0,0]),num_classes=10).float()
deltas = torch.cat((torch.zeros([2,10]),torch.tensor([[-1,1,0,0,0,0,0,0,0,0]])))*0.1
deltas[2]
deltas=(torch.arange(-10,11))[...,None,None]*deltas
losses = [jsd((output+delta),labels)-base_loss for delta in deltas]
下面的图表显示了模型输出(预测)在一个组中的变化如何影响Jensen-Shannon散度。
plt.plot([ .1*i-1 for i in range(len(losses))],[loss for loss in losses])
plt.ylabel('JS Divergence')
plt.xlabel('Change in output')
plt.show()