timm 支持多种数据增强方法,其中一种是 MixupCutMix 紧随 Mixup 之后提出,大多数深度学习从业者在训练流程中会使用 Mixup 或 CutMix 来提高性能。

但是,使用 timm 可以选择同时使用两者!在本教程中,我们将专门探讨实现训练期间 MixUpCutMix 数据增强的各种训练参数,并深入研究库的内部以了解 timm 中是如何实现这一点的。

使用 Mixup/Cutmix 数据增强训练神经网络

应用 Mixup/CutMix 数据增强时需要关注的各种训练参数如下:

--mixup MIXUP         mixup alpha, mixup enabled if > 0. (default: 0.)
--cutmix CUTMIX       cutmix alpha, cutmix enabled if > 0. (default: 0.)
--cutmix-minmax CUTMIX_MINMAX [CUTMIX_MINMAX ...]
                    cutmix min/max ratio, overrides alpha and enables
                    cutmix if set (default: None)
--mixup-prob MIXUP_PROB
                    Probability of performing mixup or cutmix when
                    either/both is enabled
--mixup-switch-prob MIXUP_SWITCH_PROB
                    Probability of switching to cutmix when both mixup and
                    cutmix enabled
--mixup-mode MIXUP_MODE
                    How to apply mixup/cutmix params. Per "batch", "pair",
                    or "elem"
--mixup-off-epoch N   Turn off mixup after this epoch, disabled if 0. (default: 0.)

仅使用 Mixup

要仅启用 Mixup 训练网络,只需传入 --mixup 参数并设置其值为 Mixup 的 alpha 值。
数据增强的默认概率是 1.0,如果需要更改,请使用 --mixup-prob 参数设置新值。

python train.py ../imagenette2-320 --mixup 0.5
python train.py ../imagenette2-320 --mixup 0.5 --mixup-prob 0.7

仅使用 CutMix

要仅启用 CutMix 训练网络,只需传入 --cutmix 参数并设置其值为 Cutmix 的 alpha 值。
数据增强的默认概率是 1.0,如果需要更改,请使用 --mixup-prob 参数设置新值。

python train.py ../imagenette2-320 --cutmix 0.2
python train.py ../imagenette2-320 --cutmix 0.2 --mixup-prob 0.7

同时使用 Mixup 和 Cutmix

要同时启用两者来训练神经网络,

python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5

在 Mixup 和 Cutmix 之间切换的默认概率是 0.5。
要更改此概率,请使用 --mixup-switch-prob 参数。该值表示切换到 Cutmix 的概率。

python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5 --mixup-switch-prob 0.4

可视化 Mixup & Cutmix

在内部,timm 库中有一个名为 Mixup 的类,它能够实现 Mixup 和 Cutmix。

import torch
from timm.data.mixup import Mixup
from timm.data.dataset import ImageDataset
from timm.data.loader import create_loader
def get_dataset_and_loader(mixup_args):
    mixup_fn = Mixup(**mixup_args)
    dataset = ImageDataset('../../imagenette2-320')
    loader = create_loader(dataset, 
                           input_size=(3,224,224), 
                           batch_size=4, 
                           is_training=True, 
                           use_prefetcher=False)
    return mixup_fn, dataset, loader

可视化一些应用了 Mixup 的图像

import torchvision
import numpy as np
from matplotlib import pyplot as plt
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])

可视化一些应用了 Cutmix 的图像

mixup_args = {
    'mixup_alpha': 0.,
    'cutmix_alpha': 1.0,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])

Mixup 内部是如何工作的?

def mixup(x, lam):
    """Applies mixup to input batch of images `x`
    
    Args:
    x (torch.Tensor): input batch tensor of shape (bs, 3, H, W)
    lam (float): Amount of MixUp
    """
    x_flipped = x.flip(0).mul_(1-lam)
    x.mul_(lam).add_(x_flipped)
    return x
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
imshow(
    torchvision.utils.make_grid(
        mixup(inputs, 0.3)
    ), 
    title=[x.item() for x in classes])

逐元素 Mixup/Cutmix

timm 中也可以进行逐元素(elementwise)的 Mixup/Cutmix。据我所知,这是唯一一个允许进行逐元素 Mixup 和 Cutmix 的库!

到目前为止,所有操作都是按批次(batch-wise)应用的。也就是说,Mixup 是对批次中的所有元素进行的。但是,通过向 Mixup 函数传入参数 mode = 'elem',我们可以将其更改为逐元素模式。

在这种情况下,CutmixMixup 会根据 mixup_args 应用到批次中的每个项目。

如下所示,Cutmix 应用于批次中的第一、第二和第三个项目,而 Mixup 应用于第四个项目。

mixup_args = {
    'mixup_alpha': 0.3,
    'cutmix_alpha': 0.3,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.5,
    'mode': 'elem',
    'label_smoothing': 0,
    'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])