分批归一化(Split Batch Normalization)首次在《Split Batch Normalization: Improving Semi-Supervised Learning under Domain Shift》中提出。
论文摘要中提到
Recent work has shown that using unlabeled data in semisupervised learning is not always beneficial and can even hurt generalization, especially when there is a class mismatch between the unlabeled and labeled examples. We investigate this phenomenon for image classification on the CIFAR-10 and the ImageNet datasets, and with many other forms of domain shifts applied (e.g. salt-and-pepper noise). Our main contribution is Split Batch Normalization (Split-BN), a technique to improve SSL when the additional unlabeled data comes from a shifted distribution. We achieve it by using separate batch normalization statistics for unlabeled examples. Due to its simplicity, we recommend it as a standard practice. Finally, we analyse how domain shift affects the SSL training process. In particular, we find that during training the statistics of hidden activations in late layers become markedly different between the unlabeled and the labeled examples.
简单来说,他们提出为无监督和有监督数据集分别计算批归一化统计量。也就是说,对于整个批次,不是使用一个 BN 层,而是使用单独的 BN 层。
你可能会说,这说起来容易,但如何在代码中实现呢?
好吧,在 timm
训练中,你只需这样做
python train.py ../imagenette2-320 --aug-splits 3 --split-bn --aa rand-m9-mstd0.5-inc1 --resplit
就这样。但是这个命令是什么意思呢?
运行上述命令会-
- 创建 3 组训练批次
- 第一组被称为原始批次(具有最小或无数据增强)
- 第二组是在第一组基础上应用了随机数据增强。
- 第三组也是在第一组基础上应用了随机数据增强。注意:随机数据增强是随机性的。因此,第二组和第三组批次彼此不同。2. 将模型内部的每个批归一化层转换为分批归一化层(Split Batch Normalization Layer)。
- 不将随机擦除(Random Erase)应用于第一组批次,也称为第一个增强分割。
SplitBatchNorm2d
本身只有几行代码
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, num_splits=2):
super().__init__(num_features, eps, momentum, affine, track_running_stats)
assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
self.num_splits = num_splits
self.aux_bn = nn.ModuleList([
nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
def forward(self, input: torch.Tensor):
if self.training: # aux BN only relevant while training
split_size = input.shape[0] // self.num_splits
assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
split_input = input.split(split_size)
x = [super().forward(split_input[0])]
for i, a in enumerate(self.aux_bn):
x.append(a(split_input[i + 1]))
return torch.cat(x, dim=0)
else:
return super().forward(input)
基本上,在《Adversarial Examples Improve Image Recognition》这篇论文中,作者将这种分批归一化称为辅助批归一化。因此,如代码所示,self.aux_bn
是一个长度为 num_splits-1
的列表。
基本上,由于我们继承了 torch.nn.BatchNorm2d
,因此这个 SplitBatchNorm2d
本身就是一个批归一化的实例,所以第一个批归一化层就是 nn.BatchNorm2d
本身,可用于归一化第一个增强分割或干净的批次。
然后,我们创建 num_splits-1
个辅助批归一化层,用于归一化输入批次中的剩余分割。
通过这种方式,我们根据分割的数量分别归一化输入批次 X
。这在以下几行代码中实现:
split_input = input.split(split_size)
x = [super().forward(split_input[0])]
for i, a in enumerate(self.aux_bn):
x.append(a(split_input[i + 1]))
return torch.cat(x, dim=0)
这就是 timm
如何在 PyTorch 中实现 SplitBatchNorm2d
的 :)。