在本教程中,我们将首先了解如何使用 RandAugment
通过 timm
的训练脚本来训练我们的模型。接着,我们还将了解如何在 timm
中调用 rand_augment_transform
函数,并将 RandAugment
添加到自定义训练循环中。
最后,我们将简要介绍 RandAugment
是什么,并详细查看 timm
中 RandAugment
的实现,以理解其中的差异。
可以在此处查阅 RandAugment
的研究论文。
要使用 randaugment
训练您的模型,只需将带有值的 --aa
参数传递给训练脚本。例如
python train.py ../imagenette2-320 --aa rand-m9-mstd0.5
因此,通过传入值为 rand-m9-mstd0.5
的 --aa
参数,意味着我们将使用 RandAugment
,其中增强操作的幅度为 9
。传入幅度标准差意味着幅度将根据 mstd
的值而变化。
magnitude = random.gauss(magnitude, magnitude_std)
因此,这意味着幅度会围绕 magnitude
值以标准差 mstd
的高斯分布变化。
不想使用 timm
的训练脚本,只想在自己的训练脚本中将 RandAugment
方法用作数据增强?
只需像下面所示创建 rand_augment_transform
,但请确保当输入图像是 PIL.Image
而不是 torch.tensor
时,您的数据集将此转换应用于输入。也就是说,此方法仅适用于 PIL.Image
,不适用于 tensor
。
标准化和转换为 tensor 的操作可以在应用 RandAugment
数据增强后执行。
让我们快速看一下 timm
中 rand_augment_transform
函数的实际应用示例!
rand_augment_transform
函数的 config_str
和 hparams
参数。这将在本教程后面解释。from timm.data.auto_augment import rand_augment_transform
from PIL import Image
from matplotlib import pyplot as plt
tfm = rand_augment_transform(
config_str='rand-m9-mstd0.5',
hparams={'translate_const': 117, 'img_mean': (124, 116, 104)}
)
x = Image.open("../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG")
plt.imshow(x)
让我们可视化原始图像 x
。
plt.imshow(x)
太棒了!这是一张“丁鲷鱼”的图像。(如果您不知道“丁鲷鱼”是什么,那您就不是真正的深度学习从业者)
现在让我们可视化图像的转换后版本。
rand_augment_transform
函数实际上期望输入是 PIL.Image
,而不是 torch.Tensor
。plt.imshow(tfm(x))
正如我们所见,上面的 rand_augment_transform
正在对我们的输入图像 x
执行数据增强。
在本节中,我们将首先探讨 RandAugment
是什么,稍后在 1.2
节中,我们将深入研究 timm
中 RandAugment
的实现。请随意跳过,因为它并没有增加更多信息,只是解释了 timm
如何实现 RandAugment
。
根据论文,RandAugment
可以像这样在 numpy 中实现
transforms = [
’Identity’, ’AutoContrast’, ’Equalize’,
’Rotate’, ’Solarize’, ’Color’, ’Posterize’,
’Contrast’, ’Brightness’, ’Sharpness’,
’ShearX’, ’ShearY’, ’TranslateX’, ’TranslateY’]
def randaugment(N, M):
"""Generate a set of distortions.
Args:
N: Number of augmentation transformations to
apply sequentially.
M: Magnitude for all the transformations.
"""
sampled_ops = np.random.choice(transforms, N)
return [(op, M) for op in sampled_ops]
基本思路是,我们有一个 transforms
列表,从该列表中选择 N
个转换。然后,我们将该操作以幅度 M
应用于输入图像。就是这样。这就是 RandAugment
。让我们看看 timm
如何实现它。
在本节中,我们将深入探讨 rand_augment_transform
函数。让我们看一下源代码
def rand_augment_transform(config_str, hparams):
"""
Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, not order sepecific determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
'mstd' - float std deviation of magnitude noise applied
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
:return: A PyTorch compatible Transform
"""
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
transforms = _RAND_TRANSFORMS
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'inc':
if bool(val):
transforms = _RAND_INCREASING_TRANSFORMS
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'w':
weight_idx = int(val)
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
上述函数的基本思想是——“根据传入的配置字符串 str
,更新 hparams
参数,如果传入了 magnitude
的值,则设置其值,否则保持默认值 _MAX_LEVEL
,即 10.0。”
同时将 transforms
变量设置为 _RAND_TRANSFORMS
。_RAND_TRANSFORMS
是一个转换列表,类似于论文中的,看起来像
_RAND_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'Posterize',
'Solarize',
'SolarizeAdd',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately
]
设置好 hparams
、magnitude
和 transforms
变量后,接下来调用 rand_augment_ops
函数来设置变量 ra_ops
的值。最后,我们根据这些变量返回一个 RandAugment
类实例。
所以接下来让我们看看 rand_augment_ops
函数和 RandAugment
类。
此函数的完整源代码如下所示
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
prob=0.5
。基本上,它创建了一个 AugmentOp
类的实例。所以,所有核心逻辑都在 AugmentOp
类里面。我们来看一下。
让我们看一下这个类的源代码。
class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
self.aug_fn = NAME_TO_OP[name]
self.level_fn = LEVEL_TO_ARG[name]
self.prob = prob
self.magnitude = magnitude
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
# If magnitude_std is > 0, we introduce some randomness
# in the usually fixed policy and sample magnitude from a normal distribution
# with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls.
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img):
if self.prob < 1.0 and random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs)
上面我们已经知道 self.prob
的值是 0.5。因此,调用这个类时,有 50% 的时间会返回原始图像 img
,另外 50% 的时间会实际执行 self.aug_fn
。
您可能会问这个 self.aug_fn
是什么?还记得 transforms
是一个 _RAND_TRANSFORMS
列表,如下所示吗?
_RAND_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'Posterize',
'Solarize',
'SolarizeAdd',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately
]
而且我们为 rand_augment_ops
返回的每个 transforms
创建了一个 AugmentOp
实例列表,就像这样 [AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
。
实际上,self.aug_fn
首先使用 NAME_TO_OP
字典将名称转换为相应的操作函数。
timm
中非常常见的一种模式。在很多地方,我们将一个字符串 str
作为函数参数传入,该字符串在函数内部被处理并用于执行某些操作。这个 NAME_TO_OP
不过是一个字典,它将每个 _RAND_TRANSFORMS
名称链接到它们在 timm
中的相应函数实现。
NAME_TO_OP = {
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'PosterizeIncreasing': posterize,
'PosterizeOriginal': posterize,
'Solarize': solarize,
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'ColorIncreasing': color,
'Contrast': contrast,
'ContrastIncreasing': contrast,
'Brightness': brightness,
'BrightnessIncreasing': brightness,
'Sharpness': sharpness,
'SharpnessIncreasing': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}
总之,这个 AugmentOp
不过是 self.aug_fn
的一个包装器,它接受一个图像 img
,并且只有 50% 的时间会在 img
上执行 self.aug_fn
。否则,它只会返回未更改的 img
。
很好,所以 rand_augment_transform
函数中的 ra_ops
变量不过是 AugmentOp
类实例的列表,这只是意味着我们将给定的数据增强函数以 50% 的概率应用于图像。
最后,正如我们在 rand_augment_transform
的源代码中看到的那样,返回的实际上是 RandAugment
类的一个实例,该实例接受 ra_ops
、choice_weights
和 num_layers
作为参数。所以接下来我们来看一下它。
此类的完整源代码如下所示
class RandAugment:
def __init__(self, ops, num_layers=2, choice_weights=None):
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
def __call__(self, img):
# no replacement when using weighted choice
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
img = op(img)
return img
如前所述,传递给 RandAugment 的 ra_ops
不过是一个 AugmentOp
实例列表,这些实例包装了 _RAND_TRANSFORMS
中的各种转换,所以这个 ops
看起来像这样
ops = [<timm.data.auto_augment.AugmentOp object at 0x7f7a03466990>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466c50>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466650>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034666d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466e10>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466490>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466750>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034667d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466410>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466710>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466190>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466450>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034664d0>, <timm.data.auto_augment.AugmentOp object at 0x7f7a03466150>, <timm.data.auto_augment.AugmentOp object at 0x7f7a034661d0>]
可以看出,ops
不过是 AugmentOp
实例的列表。基本上,每个转换都被这个 AugmentOp
类包装起来,这意味着该 transform
只会被应用 50% 的时间。
接下来,对于每张图像 img
,我们选择 num_layers
个随机增强操作,并将其应用于图像,就像此类中的 __call__
方法所示。
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
img = op(img)
最后,我们返回这张增强后的图像。