应用 Mixup
/CutMix
数据增强时需要关注的各种训练参数如下:
--mixup MIXUP mixup alpha, mixup enabled if > 0. (default: 0.)
--cutmix CUTMIX cutmix alpha, cutmix enabled if > 0. (default: 0.)
--cutmix-minmax CUTMIX_MINMAX [CUTMIX_MINMAX ...]
cutmix min/max ratio, overrides alpha and enables
cutmix if set (default: None)
--mixup-prob MIXUP_PROB
Probability of performing mixup or cutmix when
either/both is enabled
--mixup-switch-prob MIXUP_SWITCH_PROB
Probability of switching to cutmix when both mixup and
cutmix enabled
--mixup-mode MIXUP_MODE
How to apply mixup/cutmix params. Per "batch", "pair",
or "elem"
--mixup-off-epoch N Turn off mixup after this epoch, disabled if 0. (default: 0.)
要仅启用 Mixup 训练网络,只需传入 --mixup
参数并设置其值为 Mixup 的 alpha 值。
数据增强的默认概率是 1.0,如果需要更改,请使用 --mixup-prob
参数设置新值。
python train.py ../imagenette2-320 --mixup 0.5
python train.py ../imagenette2-320 --mixup 0.5 --mixup-prob 0.7
要仅启用 CutMix 训练网络,只需传入 --cutmix
参数并设置其值为 Cutmix 的 alpha 值。
数据增强的默认概率是 1.0,如果需要更改,请使用 --mixup-prob
参数设置新值。
python train.py ../imagenette2-320 --cutmix 0.2
python train.py ../imagenette2-320 --cutmix 0.2 --mixup-prob 0.7
要同时启用两者来训练神经网络,
python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5
在 Mixup 和 Cutmix 之间切换的默认概率是 0.5。
要更改此概率,请使用 --mixup-switch-prob
参数。该值表示切换到 Cutmix 的概率。
python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5 --mixup-switch-prob 0.4
在内部,timm
库中有一个名为 Mixup
的类,它能够实现 Mixup 和 Cutmix。
import torch
from timm.data.mixup import Mixup
from timm.data.dataset import ImageDataset
from timm.data.loader import create_loader
def get_dataset_and_loader(mixup_args):
mixup_fn = Mixup(**mixup_args)
dataset = ImageDataset('../../imagenette2-320')
loader = create_loader(dataset,
input_size=(3,224,224),
batch_size=4,
is_training=True,
use_prefetcher=False)
return mixup_fn, dataset, loader
import torchvision
import numpy as np
from matplotlib import pyplot as plt
def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
mixup_args = {
'mixup_alpha': 1.,
'cutmix_alpha': 0.,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])
mixup_args = {
'mixup_alpha': 0.,
'cutmix_alpha': 1.0,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])
def mixup(x, lam):
"""Applies mixup to input batch of images `x`
Args:
x (torch.Tensor): input batch tensor of shape (bs, 3, H, W)
lam (float): Amount of MixUp
"""
x_flipped = x.flip(0).mul_(1-lam)
x.mul_(lam).add_(x_flipped)
return x
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
imshow(
torchvision.utils.make_grid(
mixup(inputs, 0.3)
),
title=[x.item() for x in classes])
在 timm
中也可以进行逐元素(elementwise)的 Mixup/Cutmix。据我所知,这是唯一一个允许进行逐元素 Mixup 和 Cutmix 的库!
到目前为止,所有操作都是按批次(batch-wise)应用的。也就是说,Mixup 是对批次中的所有元素进行的。但是,通过向 Mixup
函数传入参数 mode = 'elem'
,我们可以将其更改为逐元素模式。
在这种情况下,Cutmix
或 Mixup
会根据 mixup_args
应用到批次中的每个项目。
如下所示,Cutmix 应用于批次中的第一、第二和第三个项目,而 Mixup 应用于第四个项目。
mixup_args = {
'mixup_alpha': 0.3,
'cutmix_alpha': 0.3,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.5,
'mode': 'elem',
'label_smoothing': 0,
'num_classes': 1000}
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])
inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])