在本教程中,我们将首先了解如何使用 RandAugment 通过 timm 的训练脚本来训练我们的模型。接着,我们还将了解如何在 timm 中调用 rand_augment_transform 函数,并将 RandAugment 添加到自定义训练循环中。

最后,我们将简要介绍 RandAugment 是什么,并详细查看 timmRandAugment 的实现,以理解其中的差异。

可以在此处查阅 RandAugment 的研究论文。

使用 timm 的训练脚本训练带有 RandAugment 的模型

要使用 randaugment 训练您的模型,只需将带有值的 --aa 参数传递给训练脚本。例如

python train.py ../imagenette2-320 --aa rand-m9-mstd0.5

因此,通过传入值为 rand-m9-mstd0.5--aa 参数,意味着我们将使用 RandAugment,其中增强操作的幅度为 9。传入幅度标准差意味着幅度将根据 mstd 的值而变化。

magnitude = random.gauss(magnitude, magnitude_std)

因此,这意味着幅度会围绕 magnitude 值以标准差 mstd 的高斯分布变化。

在自定义训练脚本中使用 RandAugment

不想使用 timm 的训练脚本,只想在自己的训练脚本中将 RandAugment 方法用作数据增强?

只需像下面所示创建 rand_augment_transform,但请确保当输入图像是 PIL.Image 而不是 torch.tensor 时,您的数据集将此转换应用于输入。也就是说,此方法仅适用于 PIL.Image,不适用于 tensor

标准化和转换为 tensor 的操作可以在应用 RandAugment 数据增强后执行。

让我们快速看一下 timmrand_augment_transform 函数的实际应用示例!

from timm.data.auto_augment import rand_augment_transform
from PIL import Image
from matplotlib import pyplot as plt

tfm = rand_augment_transform(
    config_str='rand-m9-mstd0.5', 
    hparams={'translate_const': 117, 'img_mean': (124, 116, 104)}
)

x   = Image.open("../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG")
plt.imshow(x)

让我们可视化原始图像 x

plt.imshow(x)
<matplotlib.image.AxesImage at 0x7f8f2d7a2520>

太棒了!这是一张“丁鲷鱼”的图像。(如果您不知道“丁鲷鱼”是什么,那您就不是真正的深度学习从业者)

现在让我们可视化图像的转换后版本。

plt.imshow(tfm(x))
<matplotlib.image.AxesImage at 0x7f8f2809f430>

正如我们所见,上面的 rand_augment_transform 正在对我们的输入图像 x 执行数据增强。

什么是 RandAugment

在本节中,我们将首先探讨 RandAugment 是什么,稍后在 1.2 节中,我们将深入研究 timmRandAugment 的实现。请随意跳过,因为它并没有增加更多信息,只是解释了 timm 如何实现 RandAugment

根据论文,RandAugment 可以像这样在 numpy 中实现

transforms = [
    Identity, AutoContrast, Equalize,
    Rotate, Solarize, Color, Posterize,
    Contrast, Brightness, Sharpness,
    ShearX, ShearY, TranslateX, TranslateY]

def randaugment(N, M):
"""Generate a set of distortions.
Args:
N: Number of augmentation transformations to
apply sequentially.
M: Magnitude for all the transformations.
"""
    sampled_ops = np.random.choice(transforms, N)
    return [(op, M) for op in sampled_ops]

基本思路是,我们有一个 transforms 列表,从该列表中选择 N 个转换。然后,我们将该操作以幅度 M 应用于输入图像。就是这样。这就是 RandAugment。让我们看看 timm 如何实现它。

timmRandAugment 的实现

rand_augment_transform

在本节中,我们将深入探讨 rand_augment_transform 函数。让我们看一下源代码

def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order sepecific determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme

    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')
    assert config[0] == 'rand'
    config = config[1:]
    for c in config:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)

上述函数的基本思想是——“根据传入的配置字符串 str,更新 hparams 参数,如果传入了 magnitude 的值,则设置其值,否则保持默认值 _MAX_LEVEL,即 10.0。”

同时将 transforms 变量设置为 _RAND_TRANSFORMS_RAND_TRANSFORMS 是一个转换列表,类似于论文中的,看起来像

_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]

设置好 hparamsmagnitudetransforms 变量后,接下来调用 rand_augment_ops 函数来设置变量 ra_ops 的值。最后,我们根据这些变量返回一个 RandAugment 类实例。

所以接下来让我们看看 rand_augment_ops 函数和 RandAugment 类。

rand_augment_ops

此函数的完整源代码如下所示

def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [AugmentOp(
        name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]

基本上,它创建了一个 AugmentOp 类的实例。所以,所有核心逻辑都在 AugmentOp 类里面。我们来看一下。

AugmentOp

让我们看一下这个类的源代码。

class AugmentOp:

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
        return self.aug_fn(img, *level_args, **self.kwargs)

上面我们已经知道 self.prob 的值是 0.5。因此,调用这个类时,有 50% 的时间会返回原始图像 img,另外 50% 的时间会实际执行 self.aug_fn

您可能会问这个 self.aug_fn 是什么?还记得 transforms 是一个 _RAND_TRANSFORMS 列表,如下所示吗?

_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]

而且我们为 rand_augment_ops 返回的每个 transforms 创建了一个 AugmentOp 实例列表,就像这样 [AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]

实际上,self.aug_fn 首先使用 NAME_TO_OP 字典将名称转换为相应的操作函数。

这个 NAME_TO_OP 不过是一个字典,它将每个 _RAND_TRANSFORMS 名称链接到它们在 timm 中的相应函数实现。

NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
}

总之,这个 AugmentOp 不过是 self.aug_fn 的一个包装器,它接受一个图像 img,并且只有 50% 的时间会在 img 上执行 self.aug_fn。否则,它只会返回未更改的 img

很好,所以 rand_augment_transform 函数中的 ra_ops 变量不过是 AugmentOp 类实例的列表,这只是意味着我们将给定的数据增强函数以 50% 的概率应用于图像。

最后,正如我们在 rand_augment_transform 的源代码中看到的那样,返回的实际上是 RandAugment 类的一个实例,该实例接受 ra_opschoice_weightsnum_layers 作为参数。所以接下来我们来看一下它。

RandAugment

此类的完整源代码如下所示

class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
        for op in ops:
            img = op(img)
        return img

如前所述,传递给 RandAugment 的 ra_ops 不过是一个 AugmentOp 实例列表,这些实例包装了 _RAND_TRANSFORMS 中的各种转换,所以这个 ops 看起来像这样

ops = [<timm.data.auto_augment.AugmentOp object at 0x7f7a03466990>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466c50>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466650>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034666d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466e10>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466490>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466750>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034667d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466410>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466710>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466190>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466450>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034664d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466150>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034661d0>]

可以看出,ops 不过是 AugmentOp 实例的列表。基本上,每个转换都被这个 AugmentOp 类包装起来,这意味着该 transform 只会被应用 50% 的时间。

接下来,对于每张图像 img,我们选择 num_layers 个随机增强操作,并将其应用于图像,就像此类中的 __call__ 方法所示。

ops = np.random.choice(
            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
    img = op(img)

最后,我们返回这张增强后的图像。