import timm
import torch
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.mixup import mixup_target
与带有标签平滑的 NLL 损失相同。标签平滑在模型正确时增加损失 x
,在模型不正确 x_i
时减少损失。当预期存在错误标签时,可以使用此方法来减轻对模型的惩罚。
x = torch.eye(2)
x_i = 1 - x
y = torch.arange(2)
LabelSmoothingCrossEntropy(0.0)(x,y),LabelSmoothingCrossEntropy(0.0)(x_i,y)
LabelSmoothingCrossEntropy(0.1)(x,y),LabelSmoothingCrossEntropy(0.1)(x_i,y)
与 mixup 一起使用的 log_softmax
系列损失函数。使用 mixup_target 来添加标签平滑并调整目标标签的混合量。
x=torch.tensor([[[0,1.,0,0,1.]],[[1.,1.,1.,1.,1.]]],device='cuda')
y=mixup_target(torch.tensor([1,4],device='cuda'),5, lam=0.7)
x,y
SoftTargetCrossEntropy()(x[0],y),SoftTargetCrossEntropy()(x[1],y)