From the paper's abstract:

In training, Random Erasing randomly selects a rectangle region in an image and erases its pixels with random values. In this process, training images with various levels of occlusion are generated, which reduces the risk of over-fitting and makes the model robust to occlusion. Random Erasing is parameter learning free, easy to implement, and can be integrated with most of the CNN-based recognition models.

Random Erase

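As the figure above shows, RandomErase data augmentation randomly selects a region of the input image, erases the existing image content in that region, and fills the region with random values.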
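A minimal NumPy sketch of that idea, assuming an H×W×C float image in [0, 1]. It is not timm's code: the helper name `random_erase_np` is hypothetical, and the ranges (erased area between 2% and 40% of the image, aspect ratio between 0.3 and 1/0.3) are illustrative choices.

```python
import math
import numpy as np

def random_erase_np(img, rng, min_area=0.02, max_area=0.4, min_aspect=0.3, attempts=10):
    """Return a copy of an HxWxC image with one random rectangle filled with noise.

    Hypothetical helper for illustration only, not part of timm.
    """
    out = img.copy()
    img_h, img_w = img.shape[:2]
    area = img_h * img_w
    for _ in range(attempts):
        # sample the erased area and a log-uniform aspect ratio (height / width)
        target_area = rng.uniform(min_area, max_area) * area
        aspect = math.exp(rng.uniform(math.log(min_aspect), math.log(1 / min_aspect)))
        h = int(round(math.sqrt(target_area * aspect)))
        w = int(round(math.sqrt(target_area / aspect)))
        if 0 < h < img_h and 0 < w < img_w:
            # integers() excludes the upper bound, so the block stays inside the image
            top = int(rng.integers(0, img_h - h + 1))
            left = int(rng.integers(0, img_w - w + 1))
            # replace the region with uniform random values in [0, 1)
            out[top:top + h, left:left + w] = rng.random((h, w, img.shape[2]))
            return out
    return out  # no rectangle fitted: image returned unchanged

rng = np.random.default_rng(0)
img = np.zeros((64, 64, 3), dtype=np.float32)
erased = random_erase_np(img, rng)
print((erased != 0).any(axis=2).sum(), "pixels erased")
```

Working on a copy keeps the sketch side-effect free; timm's own class, discussed below, writes into the tensor in place instead.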

Training a model with RandomErase using the timm training script

To train a model with RandomErase data augmentation using timm's training script, simply add the --reprob flag followed by a probability value.

python train.py ../imagenette2-320 --reprob 0.4

Running the command above applies RandomErase data augmentation to each input image with a probability of 0.4.
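As a rough intuition for what a probability of 0.4 means: each image independently has a 40% chance of being erased, so over many images about 40% are augmented and the rest pass through untouched. The sketch below simulates only that coin flip with the standard library; `erase_decisions` is a hypothetical helper that mirrors the `random.random() > self.probability` early return in timm's `_erase` method.

```python
import random

def erase_decisions(num_images, probability, seed=0):
    """Simulate the per-image coin flip; True means the image would be erased."""
    rng = random.Random(seed)
    # timm's _erase returns early when random.random() > probability,
    # so an image is erased exactly when the draw is <= probability
    return [rng.random() <= probability for _ in range(num_images)]

decisions = erase_decisions(10_000, probability=0.4)
print(f"erased fraction: {sum(decisions) / len(decisions):.3f}")  # roughly 0.4
```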

Using RandomErase data augmentation in a custom training script

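The previous section (1.1) showed how to train a neural network with RandomErase data augmentation using the timm training script. Often, however, you may want to use RandomErase on its own inside your custom training loop. This section explains how to do that.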

In timm, RandomErase data augmentation is implemented in the RandomErasing class. The code below simply creates an input image tensor and visualizes it first.

from PIL import Image
from timm.data.random_erasing import RandomErasing
from torchvision import transforms
from matplotlib import pyplot as plt

img = Image.open("../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG")
x   = transforms.ToTensor()(img)
plt.imshow(x.permute(1, 2, 0))

As we can see, this is the "tench" image shown almost everywhere in this documentation. Now let's apply the RandomErasing augmentation and visualize the result.

random_erase = RandomErasing(probability=1, mode='pixel', device='cpu')
plt.imshow(random_erase(x).permute(1, 2, 0))

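As we can see, after applying RandomErasing, a randomly sized rectangle inside the image has been replaced with random values, as described in the paper. (In 'pixel' mode the fill values are drawn from a standard normal distribution, so some fall outside [0, 1] and imshow clips them for display.) The usage pattern in a custom training script therefore looks like this: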

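import torch
from timm.data.random_erasing import RandomErasing

# stand-in for a real training batch: 8 normalized RGB images (NCHW) and their labels
X = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 10, (8,))

# perform RandomErase data augmentation; `device` must match the batch's device
# (the class defaults to device='cuda')
random_erase = RandomErasing(probability=0.5, device='cpu')

# get augmented batch (RandomErasing modifies X in place and returns it)
X_aug = random_erase(X)

# do something here, e.g. forward pass, loss, backward pass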

Implementation of RandomErase in timm

In this section we look at the source code of timm's RandomErasing class. The complete source of the class is shown below:

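import math
import random

from timm.data.random_erasing import _get_pixels  # fill-value helper defined alongside this class in timm

class RandomErasing: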
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
    Args:
         probability: Probability that the Random Erasing operation will be performed.
         min_area: Minimum percentage of erased area wrt input image area.
         max_area: Maximum percentage of erased area wrt input image area.
         min_aspect: Minimum aspect ratio of erased area.
         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
            per-image count is randomly chosen between 1 and this value.
    """

    def __init__(
            self,
            probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        if mode == 'rand':
            self.rand_color = True  # per block random normal
        elif mode == 'pixel':
            self.per_pixel = True  # per pixel random normal
        else:
            assert not mode or mode == 'const'
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = self.min_count if self.min_count == self.max_count else \
            random.randint(self.min_count, self.max_count)
        for _ in range(count):
            for attempt in range(10):
                target_area = random.uniform(self.min_area, self.max_area) * area / count
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top:top + h, left:left + w] = _get_pixels(
                        self.per_pixel, self.rand_color, (chan, h, w),
                        dtype=dtype, device=self.device)
                    break

    def __call__(self, input):
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
            for i in range(batch_start, batch_size):
                self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input

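All the interesting work happens inside the _erase method, which we examine next. In short, the class is called with either a 3-D tensor of shape CHW (a single image) or a 4-D input batch of shape NCHW. For a batch that is not split as in AugMix (num_splits of 0 or 1), every image in the batch is passed to _erase; otherwise the first batch_size // num_splits images are left untouched as the clean split. Each image passed to _erase is still erased only with probability self.probability. This way of splitting the batch has been explained here and here.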
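To make the batch-splitting rule concrete, here is the index arithmetic from `__call__` pulled out into a hypothetical helper, `erased_indices`. With `num_splits > 1`, the first `batch_size // num_splits` images form the clean split and are skipped; only the remaining indices are handed to `_erase`.

```python
def erased_indices(batch_size, num_splits):
    """Indices of batch items that RandomErasing.__call__ passes to _erase."""
    # same rule as timm: skip the first slice only when the batch is split
    batch_start = batch_size // num_splits if num_splits > 1 else 0
    return list(range(batch_start, batch_size))

print(erased_indices(8, num_splits=0))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(erased_indices(8, num_splits=2))  # [4, 5, 6, 7]
print(erased_indices(9, num_splits=3))  # [3, 4, 5, 6, 7, 8]
```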

Now let's look at the _erase method in detail to understand all the "magic" inside it.

    def _erase(self, img, chan, img_h, img_w, dtype):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = self.min_count if self.min_count == self.max_count else \
            random.randint(self.min_count, self.max_count)
        for _ in range(count):
            for attempt in range(10):
                target_area = random.uniform(self.min_area, self.max_area) * area / count
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top:top + h, left:left + w] = _get_pixels(
                        self.per_pixel, self.rand_color, (chan, h, w),
                        dtype=dtype, device=self.device)
                    break

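The _erase method above takes the input img (a torch.tensor), chan, the number of image channels, img_h and img_w, the image height and width, and dtype, the tensor's data type. It first draws a random number and returns without erasing anything if that number exceeds self.probability, so each image is erased with probability self.probability.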

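We then choose count, the number of erase blocks, from self.min_count and self.max_count: self.min_count is the minimum and self.max_count the maximum number of random erase blocks. If the two differ, count is drawn uniformly from that range, both ends inclusive. By default both are 1 (max_count=None falls back to min_count), so only one random erase block is added to the input img.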

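Next, for each block we randomly choose a target_area and an aspect_ratio and derive the block's height h and width w from them. The target area is a random fraction between min_area and max_area of the image area, divided by count so that several blocks share the budget, and the aspect ratio is sampled log-uniformly between min_aspect and max_aspect (which defaults to 1 / min_aspect). If the resulting block does not fit inside the image, a new size is sampled, for up to 10 attempts; if none fits, that block is skipped.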
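The size computation can be reproduced in plain Python. This sketch (the function name `sample_block_size` is hypothetical) uses the same formulas as `_erase`. Because h = sqrt(A·r) and w = sqrt(A/r), the block satisfies h·w ≈ A and h/w ≈ r up to rounding, as the worked numbers at the end show.

```python
import math
import random

def sample_block_size(img_h, img_w, count=1, min_area=0.02, max_area=1/3,
                      min_aspect=0.3, max_aspect=None, rng=random):
    """Return (h, w) of one erase block, or None if 10 attempts fail to fit."""
    max_aspect = max_aspect or 1 / min_aspect
    log_ratio = (math.log(min_aspect), math.log(max_aspect))
    area = img_h * img_w
    for _ in range(10):
        # each of `count` blocks gets a share of the sampled area fraction
        target_area = rng.uniform(min_area, max_area) * area / count
        # log-uniform aspect ratio: r and 1/r are equally likely
        aspect_ratio = math.exp(rng.uniform(*log_ratio))
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        if w < img_w and h < img_h:
            return h, w
    return None

# worked numbers: 224x224 image, target area 0.2 * 50176 = 10035.2, aspect ratio 2.0
h = int(round(math.sqrt(10035.2 * 2.0)))  # 142
w = int(round(math.sqrt(10035.2 / 2.0)))  # 71, so h * w = 10082 and h / w = 2.0
```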

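Finally, we replace the pixels inside img[:, top:top + h, left:left + w], where top is a random integer offset along the y-axis (between 0 and img_h - h) and left is a random integer offset along the x-axis (between 0 and img_w - w), so the block always lies inside the image. The assignment modifies img in place. _get_pixels is a function implemented in timm that returns the values used to fill the random erase block, depending on the Random Erase mode.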
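The write itself is a plain slice assignment on a C×H×W array, so it changes the image in place and touches only rows `top:top + h` and columns `left:left + w` of every channel. Below, a NumPy array stands in for the torch tensor; the sizes and the constant fill value 0.0 are illustrative assumptions, while `top` and `left` are drawn the same way as in `_erase`.

```python
import random
import numpy as np

chan, img_h, img_w = 3, 8, 10
img = np.ones((chan, img_h, img_w), dtype=np.float32)
h, w = 3, 4  # block size, assumed already sampled

# random.randint is inclusive at both ends, so the block never crosses the border
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)

# same indexing as timm: all channels, a rectangle of rows and columns
img[:, top:top + h, left:left + w] = 0.0
print(top, left, int((img == 0).sum()))  # last value is chan * h * w = 36
```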

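If mode=='pixel', _get_pixels returns per-pixel values drawn from a standard normal distribution; if mode=='rand', it returns one normal random value per channel, shared by every pixel of the block; otherwise (mode=='const', the default) the block is filled with the constant value 0. Because the class is meant to run on images already normalized by the dataset mean and std, 0 corresponds to the mean color rather than black.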
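The fill modes differ mainly in the shape of the values returned. The NumPy sketch below follows our reading of timm's `_get_pixels`; the name `get_fill` and the NumPy backend are assumptions, not timm's API. 'pixel' draws a fresh normal value for every element of the (chan, h, w) block, 'rand' draws one normal value per channel with shape (chan, 1, 1), and 'const' returns zeros of that same shape. The (chan, 1, 1) results broadcast across the rectangle.

```python
import numpy as np

def get_fill(mode, chan, h, w, rng):
    """Fill values for one erase block (hypothetical NumPy analogue of _get_pixels)."""
    if mode == 'pixel':
        return rng.standard_normal((chan, h, w)).astype(np.float32)  # per-pixel noise
    if mode == 'rand':
        return rng.standard_normal((chan, 1, 1)).astype(np.float32)  # one color per channel
    return np.zeros((chan, 1, 1), dtype=np.float32)                   # 'const': zeros

rng = np.random.default_rng(0)
img = np.ones((3, 6, 6), dtype=np.float32)
for mode in ('pixel', 'rand', 'const'):
    fill = get_fill(mode, 3, 2, 3, rng)
    out = img.copy()
    out[:, 1:3, 2:5] = fill  # (3, 1, 1) fills broadcast over the 2x3 rectangle
    print(mode, fill.shape)
# pixel (3, 2, 3) / rand (3, 1, 1) / const (3, 1, 1)
```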