在本教程中,我们将介绍 timm 的训练脚本。timm 提供了多种功能,其中一些列在下方:

  1. 自动数据增强 (Auto Augmentation) 论文
  2. Augmix
  3. 在多 GPU 上进行分布式训练
  4. 混合精度训练
  5. 用于 AdvProp 的辅助批归一化 (Auxiliary Batch Norm) 论文
  6. 同步批归一化 (Synchronized Batch Norm)
  7. Mixup 和 Cutmix,能够在这两者之间切换,并且能够在某个 epoch 关闭数据增强

timm 还支持多种优化器和调度器。在本教程中,我们将只关注上述 7 个功能,并介绍如何利用 timm 在您自己的自定义数据集上使用这些功能进行实验。

作为本教程的一部分,我们将首先对训练脚本进行总体介绍,并从高层次了解该脚本内部发生的各个关键步骤。然后,我们将深入了解上述 7 个功能的细节,以进一步理解 train.py

训练 args

timm 中的训练脚本可以接受约 100 个参数。您可以通过运行 python train.py --help 了解更多信息。这些参数用于定义数据集/模型参数、优化器参数、学习率调度器参数、数据增强和正则化、批归一化参数、模型指数移动平均参数以及一些杂项参数,如 --seed--tta 等。

作为本教程的一部分,我们将从高层次介绍训练脚本如何使用这些参数。这对于您使用 timm 在 ImageNet 或任何其他自定义数据集上运行自己的实验可能有所帮助。

必需 args

timm 训练脚本唯一必需的参数是训练数据(如 ImageNet)的路径,其结构如下:

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

因此,要开始在这个 imagenette2-320 数据集上训练,我们可以简单地执行类似 python train.py 的命令。

默认 args

训练脚本中的各种默认 args 已为您设置好,传递给训练脚本的参数大致如下所示:

Namespace(aa=None, amp=False, apex_amp=False, aug_splits=0, batch_size=32, bn_eps=None, bn_momentum=None, bn_tf=False, channels_last=False, clip_grad=None, color_jitter=0.4, cooldown_epochs=10, crop_pct=None, cutmix=0.0, cutmix_minmax=None, data_dir='../imagenette2-320', dataset='', decay_epochs=30, decay_rate=0.1, dist_bn='', drop=0.0, drop_block=None, drop_connect=None, drop_path=None, epochs=200, eval_metric='top1', gp=None, hflip=0.5, img_size=None, initial_checkpoint='', input_size=None, interpolation='', jsd=False, local_rank=0, log_interval=50, lr=0.01, lr_cycle_limit=1, lr_cycle_mul=1.0, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, mean=None, min_lr=1e-05, mixup=0.0, mixup_mode='batch', mixup_off_epoch=0, mixup_prob=1.0, mixup_switch_prob=0.5, model='resnet101', model_ema=False, model_ema_decay=0.9998, model_ema_force_cpu=False, momentum=0.9, native_amp=False, no_aug=False, no_prefetcher=False, no_resume_opt=False, num_classes=None, opt='sgd', opt_betas=None, opt_eps=None, output='', patience_epochs=10, pin_mem=False, pretrained=False, ratio=[0.75, 1.3333333333333333], recount=1, recovery_interval=0, remode='const', reprob=0.0, resplit=False, resume='', save_images=False, scale=[0.08, 1.0], sched='step', seed=42, smoothing=0.1, split_bn=False, start_epoch=None, std=None, sync_bn=False, torchscript=False, train_interpolation='random', train_split='train', tta=0, use_multi_epochs_loader=False, val_split='validation', validation_batch_size_multiplier=1, vflip=0.0, warmup_epochs=3, warmup_lr=0.0001, weight_decay=0.0001, workers=4)

请注意,args 是一个 Namespace,这意味着如果需要,我们可以通过类似 args.new_variable="some_value" 的方式在运行时设置更多参数。

要获取这些各种参数的一行简介,我们可以简单地执行类似 python train.py --help 的命令。

训练脚本的 20 个步骤

在本节中,我们将从高层次了解训练脚本内部发生的各个步骤。这些步骤已按正确顺序概述如下:

  1. 如果 args.distributedTrue,则设置分布式训练参数
  2. 设置手动种子以获得可重现的结果。
  3. 创建模型:使用 timm.create_model 函数创建要训练的模型。
  4. 根据模型的默认配置设置数据配置。一般来说,模型的默认配置看起来像这样:
    {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc'}
    
  5. 设置数据增强批次分割,如果数据增强批次分割的数量大于 1,则将所有模型的 BatchNorm 层转换为 Split Batch Normalization 层。
  6. 如果我们在多 GPU 上进行训练,则设置 apex syncBN 或 PyTorch 原生 SyncBatchNorm 以设置同步批归一化。这意味着我们不对每个单独的 GPU 上的数据进行归一化,而是在多个 GPU 上对整个批次进行统一归一化。

  7. 如果请求,使用 torch.jit 使模型可导出。

  8. 根据传递给训练脚本的参数初始化优化器
  9. 设置混合精度——使用 apex.amp 或使用原生 torch amp - torch.cuda.amp.autocast
  10. 如果从模型检查点恢复,则加载模型权重。
  11. 设置模型权重的指数移动平均。这类似于随机权重平均 (Stochastic Weight Averaging)
  12. 根据步骤 1 的参数设置分布式训练
  13. 设置学习率调度器
  14. 创建训练和验证数据集
  15. 设置Mixup/Cutmix 数据增强。
  16. 如果步骤 5 的数据增强批次分割数量大于 1,则将训练数据集转换为AugmixDataset
  17. 创建训练数据加载器和验证数据加载器18. 设置损失函数
  18. 设置模型检查点和评估指标
  19. 训练并验证模型,并将评估指标存储到输出文件中。

一些关键的 timm 特性

自动数据增强 (Auto-Augment)

要在训练期间启用自动数据增强:

python train.py ./imagenette2-320 --aa 'v0'

Augmix

此处提供了关于 augmix 的简要介绍。要在训练期间启用 augmix,只需这样做:

python train.py ./imagenette2-320 --aug-splits 3 --jsd

timm 还支持与 RandAugmentAutoAugment 结合使用的 augmix,如下所示:

python train.py ./imagenette2-320 --aug-splits 3 --jsd --aa rand-m9-mstd0.5-inc1

在多 GPU 上进行分布式训练

要在多 GPU 上训练模型,只需将 python train.py 替换为 ./distributed_train.sh ,如下所示:

./distributed_train.sh 4 ./imagenette2-320 --aug-splits 3 --jsd

这将使用 AugMix 数据增强在 4 个 GPU 上训练模型。

混合精度训练

要启用混合精度训练,只需添加 --amp 标志。timm 将自动使用 apex 或 PyTorch 原生混合精度训练实现混合精度训练。

python train.py ../imagenette2-320 --aug-splits 3 --jsd --amp

辅助批归一化 / SplitBatchNorm

来自论文:

Batch normalization serves as an essential component for many state-of-the-art computer vision models. Specifically, BN normalizes input features by the mean and variance computed within each mini-batch. **One intrinsic assumption of utilizing BN is that the input features should come from a single or similar distributions.** This normalization behavior could be problematic if the mini-batch contains data from different distributions, there- fore resulting in inaccurate statistics estimation.

To disentangle this mixture distribution into two simpler ones respectively for the clean and adversarial images, we hereby propose an auxiliary BN to guarantee its normalization statistics are exclusively preformed on the adversarial examples.

要启用 split batch norm,

python train.py ./imagenette2-320 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --split-bn

使用上述命令,timm 现在为每个数据增强分割拥有独立的批归一化层。

同步批归一化 (Synchronized Batch Norm)

同步批归一化仅在多 GPU 训练时使用。来自 papers with code

Synchronized Batch Normalization (SyncBN) is a type of batch normalization used for multi-GPU training. Standard batch normalization only normalizes the data within each device (GPU). SyncBN normalizes the input within the whole mini-batch.

要启用,只需添加 --sync-bn 标志,如下所示:

./distributed_train.sh 4 ../imagenette2-320 --aug-splits 3 --jsd --sync-bn

Mixup 和 Cutmix

要启用 mixup 或 cutmix,只需添加 --mixup--cutmix 标志并指定 alpha 值。
应用数据增强的默认概率是 1.0。如果您需要更改,请使用 --mixup-prob 参数并指定新值。

例如,要启用 mixup,

train.py ../imagenette2-320 --mixup 0.5
train.py ../imagenette2-320 --mixup 0.5 --mixup-prob 0.7

或启用 Cutmix,

train.py ../imagenette2-320 --cutmix 0.5
train.py ../imagenette2-320 --cutmix 0.5 --mixup-prob 0.7

也可以同时启用两者,

python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3

上述命令将使用 Mixup 或 Cutmix 作为数据增强技术,并以 50% 的概率应用于批次。它还将以 30% 的概率在这两者之间切换(Mixup - 70%,切换到 Cutmix 30%)。

还有一个参数可以在某个 epoch 关闭 Mixup/Cutmix 数据增强

python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3 --mixup-off-epoch 10

上述命令只在前 10 个 epoch 应用 Mixup/Cutmix 数据增强。