在本教程中,我们将探讨如何利用 AutoAugment 作为数据增强技术来训练神经网络。
我们将了解
- 如何使用
timm
训练脚本应用AutoAugment
。 - 如何将
AutoAugment
作为独立的数据增强技术用于自定义训练循环。 - 深入探讨
AutoAugment
的源代码。
要使用 timm
训练模型并应用自动增强数据策略,只需添加 --aa
标志并将其值设置为 'original'
或 'v1'
,如下所示:
python train.py ../imagenette2-320 --aa original
上面的脚本使用 AutoAugment
作为增强技术来训练神经网络,其策略与论文中提到的相同。
在本节中,我们将看到如何在自己的自定义训练循环中将 AutoAugment
用作独立的数据增强技术。
我们可以简单地使用 timm
中的 auto_augment_transform
函数创建一个名为 tfm
的转换函数。我们将 config_str
和一些 hparams
传递给该函数来创建我们的转换函数。
AutoAugment
,下面创建的转换函数 tfm
需要输入是 PIL.Image
的实例,而不是 torch.tensor
。在 torch.tensor
上调用此函数将导致错误。下面,我们创建了我们的转换函数 tfm
,并创建了一个输入图像 X
,它是一条“丁鲷鱼”的图像,这在本文档的其他地方也使用过。
from timm.data.auto_augment import auto_augment_transform
from PIL import Image
from matplotlib import pyplot as plt
tfm = auto_augment_transform(config_str = 'original', hparams = {'translate_const': 100, 'img_mean': (124, 116, 104)})
X = Image.open('../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG')
plt.imshow(X);
可视化 X
后,我们来应用转换函数,该函数将自动增强策略应用于 X
,然后我们来可视化下面的结果
plt.imshow(tfm(X));
正如我们所见,函数 tfm
将自动增强技术应用于输入图像 X
。
因此,只要确保此函数转换的输入图像是 PIL.Image
类型,我们就可以在自定义训练循环中使用 timm
应用 AutoAugment
。
现在让我们深入探讨以理解 timm
对 AutoAugment
策略的实现。
我们上面使用的 auto_augment_transform
的完整源代码如下所示
def auto_augment_transform(config_str, hparams):
"""
Create a AutoAugment transform
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections, not order sepecific determine
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
:return: A PyTorch compatible Transform
"""
config = config_str.split('-')
policy_name = config[0]
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
else:
assert False, 'Unknown AutoAugment config section'
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
return AutoAugment(aa_policy)
这与 RandAugment 的实现非常相似。基本上,我们传入一个配置字符串,根据该配置字符串,此函数设置一些 hparams
,然后将这些 hparams
传递给 auto_augment_policy
以创建策略。最后,我们将此 aa_policy
包装在 AutoAugment
类周围,该类将被返回以应用于输入数据。
我们来看一下下面的 auto_augment_policy
和 AutoAugment
源代码。
auto_augment_policy
函数的源代码看起来像这样
def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original':
return auto_augment_policy_original(hparams)
基本上,这个函数接受一个策略名称,然后返回相应的增强策略。
我们来看一下下面的 auto_augment_policy_original
函数。
这个函数的源代码看起来像这样
def auto_augment_policy_original(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
auto_augment_policy_original
的构建方式实际上非常简单。正如我们从论文中了解到的,ImageNet 策略包含 25 个子策略。我们迭代上面的 policy
以获得一个子策略 sp
。最后,我们还迭代子策略 sp
中称为 a
的每个操作,并将其包装在 AugmentOp
类周围。
我们已经在这里介绍了 AugmentOp
类。
因此,本质上,子策略中的每个操作都会根据上面策略中提到的概率和幅度值转换为 AugmentOp
类的实例。这成为了在 auto_augment_transform
中用于创建 aa_policy
并被返回的策略。
作为 auto_augment_transform
的最后一步,我们将 aa_policy
包装在 AutoAugment
类中,这就是应用于输入数据的内容。因此,让我们来看一下下面的 AutoAugment
class AutoAugment:
def __init__(self, policy):
self.policy = policy
def __call__(self, img):
sub_policy = random.choice(self.policy)
for op in sub_policy:
img = op(img)
return img
实际上,这是 AutoAugment
最简单的实现之一。与论文类似,我们选择一个随机子策略,它包含两个操作(每个操作包括一个增强函数、应用增强函数的幅度(magnitude)和概率(probability)),最后将这些操作应用于 img
以返回增强后的图像。