From the abstract of the paper:
In training, Random Erasing randomly selects a rectangle region in an image and erases its pixels with random values. In this process, training images with various levels of occlusion are generated, which reduces the risk of over-fitting and makes the model robust to occlusion. Random Erasing is parameter learning free, easy to implement, and can be integrated with most of the CNN-based recognition models.

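As the figure above shows, the RandomErase data augmentation randomly selects a region of the input image, erases the existing content in that region, and fills the region with random values.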
To train a model with timm's training script using the RandomErase data augmentation, simply add the --reprob flag with a probability value.
python train.py ../imagenette2-320 --reprob 0.4
Running the command above applies the RandomErase data augmentation to each input image with probability 0.4.
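Section 1.1 gave an example of training a neural network with the RandomErase data augmentation through timm's training script. Often, though, you will just want to use the RandomErase data augmentation in your own custom training loop. This section explains how to do that.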
The RandomErase data augmentation in timm is implemented in the RandomErasing class. The code below simply creates an input image tensor and visualizes it.
Note that RandomErasing is applied to tensors. This is different from RandAugment, whose class expects a PIL.Image as input.
from PIL import Image
from timm.data.random_erasing import RandomErasing
from torchvision import transforms
from matplotlib import pyplot as plt
img = Image.open("../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG")
x = transforms.ToTensor()(img)
plt.imshow(x.permute(1, 2, 0))
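Great. As we can see, it is the image of a tench that is shown almost everywhere in this documentation. Now let's apply the RandomErasing augmentation and visualize the result.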
random_erase = RandomErasing(probability=1, mode='pixel', device='cpu')
plt.imshow(random_erase(x).permute(1, 2, 0))
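As we can see, after applying the RandomErasing data augmentation, a randomly sized rectangle inside the image has been replaced with random values, as described in the paper. The pseudo-code for using RandomErasing in your own custom training script therefore looks something like this: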
from timm.data.random_erasing import RandomErasing
# get input images and convert to `torch.tensor`
X, y = input_training_batch()
X = convert_to_torch_tensor(X)
# perform RandomErase data augmentation
# (device defaults to 'cuda'; pass device='cpu' if X lives on the CPU)
random_erase = RandomErasing(probability=0.5)
# get augmented batch
X_aug = random_erase(X)
# do something here
In this section we look at the source code of the RandomErasing class in timm. The complete source code of the class is shown below:
class RandomErasing:
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
    Args:
         probability: Probability that the Random Erasing operation will be performed.
         min_area: Minimum percentage of erased area wrt input image area.
         max_area: Maximum percentage of erased area wrt input image area.
         min_aspect: Minimum aspect ratio of erased area.
         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
            per-image count is randomly chosen between 1 and this value.
    """

    def __init__(
            self,
            probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        if mode == 'rand':
            self.rand_color = True  # per block random normal
        elif mode == 'pixel':
            self.per_pixel = True  # per pixel random normal
        else:
            assert not mode or mode == 'const'
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = self.min_count if self.min_count == self.max_count else \
            random.randint(self.min_count, self.max_count)
        for _ in range(count):
            for attempt in range(10):
                target_area = random.uniform(self.min_area, self.max_area) * area / count
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top:top + h, left:left + w] = _get_pixels(
                        self.per_pixel, self.rand_color, (chan, h, w),
                        dtype=dtype, device=self.device)
                    break

    def __call__(self, input):
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
            for i in range(batch_start, batch_size):
                self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input
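The batch branch of __call__ skips the first slice of the batch when num_splits is greater than 1, so that this portion stays clean. The small standalone sketch below only reproduces the batch_start expression from the source above; erase_start_index is a hypothetical helper name introduced for illustration and is not part of timm.

```python
def erase_start_index(batch_size, num_splits):
    """Index of the first sample in the batch that may be erased.

    Hypothetical helper mirroring the `batch_start` expression in
    RandomErasing.__call__; samples before this index are left untouched.
    """
    return batch_size // num_splits if num_splits > 1 else 0


# With 12 samples and 3 splits, samples 0..3 stay clean and 4..11 may be erased.
print(erase_start_index(12, 3))  # 4
# With the default num_splits=0 (or 1), every sample is a candidate.
print(erase_start_index(12, 0))  # 0
```

Each sample from that index onward is still erased only with the configured probability, since _erase performs its own random check.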
Now let's take a detailed look at the _erase method to understand where all the "magic" happens.
def _erase(self, img, chan, img_h, img_w, dtype):
    if random.random() > self.probability:
        return
    area = img_h * img_w
    count = self.min_count if self.min_count == self.max_count else \
        random.randint(self.min_count, self.max_count)
    for _ in range(count):
        for attempt in range(10):
            target_area = random.uniform(self.min_area, self.max_area) * area / count
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < img_w and h < img_h:
                top = random.randint(0, img_h - h)
                left = random.randint(0, img_w - w)
                img[:, top:top + h, left:left + w] = _get_pixels(
                    self.per_pixel, self.rand_color, (chan, h, w),
                    dtype=dtype, device=self.device)
                break
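The _erase method above takes an input img (a torch.Tensor), chan, the number of image channels, img_h and img_w, the image height and width, and dtype, the data type of the tensor. With probability 1 - self.probability it returns immediately and leaves the image unchanged.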
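We choose the value of count based on self.min_count and self.max_count. self.min_count is the minimum number of random erase blocks and self.max_count is the maximum. In most cases both keep their default of 1, which means we add only a single random erase block to the input img.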
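Next, we randomly sample the target_area and aspect_ratio of the erase block, and from these derive its height h and width w. Because h = sqrt(target_area * aspect_ratio) and w = sqrt(target_area / aspect_ratio), the block has area close to target_area and a height-to-width ratio close to aspect_ratio. If the block does not fit inside the image, sampling is retried, up to 10 attempts per block.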
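The sampling step can be tried out without torch. The sketch below copies the arithmetic of _erase into a standalone function that uses only the standard library; sample_erase_box and its rng parameter are hypothetical names introduced here, and the defaults are taken from the __init__ signature shown earlier.

```python
import math
import random


def sample_erase_box(img_h, img_w, min_area=0.02, max_area=1/3,
                     min_aspect=0.3, max_aspect=None, count=1,
                     max_attempts=10, rng=random):
    """Sample (top, left, h, w) for one erase block, or None if no fit is found.

    Illustrative re-statement of the sampling logic in `_erase`,
    not a function provided by timm.
    """
    max_aspect = max_aspect or 1 / min_aspect
    log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
    area = img_h * img_w
    for _ in range(max_attempts):
        # fraction of the image area, shared between `count` blocks
        target_area = rng.uniform(min_area, max_area) * area / count
        # sampling in log space makes a ratio r and its inverse 1/r equally likely
        aspect_ratio = math.exp(rng.uniform(*log_aspect_ratio))
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        if w < img_w and h < img_h:
            top = rng.randint(0, img_h - h)   # randint is inclusive on both ends
            left = rng.randint(0, img_w - w)
            return top, left, h, w
    return None


box = sample_erase_box(224, 224, rng=random.Random(0))
print(box)  # a (top, left, h, w) tuple that lies fully inside the 224x224 image
```

Passing a seeded random.Random instance keeps the sketch reproducible; the class itself draws from the module-level random functions.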
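Finally, we replace the pixels of the image inside img[:, top:top + h, left:left + w], where top is a random integer offset along the y-axis and left is a random integer offset along the x-axis. _get_pixels is a function implemented in timm that returns the values used to fill the erase block, depending on the Random Erase mode.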
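If mode == 'pixel', _get_pixels returns values drawn from a normal distribution independently for every pixel. If mode == 'rand', it returns a single normally distributed value per channel, so the whole block has one random color. Otherwise (mode == 'const') the block is filled with the constant value 0.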
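To make the difference between the three modes concrete, here is a pure-Python sketch that builds the fill values as nested lists. It is not timm's _get_pixels, which returns torch tensors; get_fill_values is a hypothetical stand-in written from the mode descriptions in the class docstring, assuming a standard normal distribution for the random modes.

```python
import random


def get_fill_values(mode, chan, h, w, rng=random):
    """Return a chan x h x w nested list of fill values for an erase block.

    Hypothetical illustration of the 'pixel', 'rand' and 'const' modes;
    assumes zero-mean, unit-variance normal samples.
    """
    if mode == 'pixel':
        # an independent normal sample for every channel and every pixel
        return [[[rng.gauss(0.0, 1.0) for _ in range(w)]
                 for _ in range(h)] for _ in range(chan)]
    if mode == 'rand':
        # one normal sample per channel, repeated over the whole block
        colors = [rng.gauss(0.0, 1.0) for _ in range(chan)]
        return [[[c] * w for _ in range(h)] for c in colors]
    if mode == 'const':
        # constant 0 in every channel
        return [[[0.0] * w for _ in range(h)] for _ in range(chan)]
    raise ValueError(f"unknown mode: {mode!r}")


block = get_fill_values('rand', chan=3, h=2, w=4, rng=random.Random(0))
print(len(block), len(block[0]), len(block[0][0]))  # 3 2 4
```

Since the class is meant to run on images already normalized by the dataset mean and std, a fill value of 0 corresponds roughly to the dataset mean color rather than to black.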