本教程介绍了 timm 中可用的各种优化器。我们将了解如何使用 timm 训练脚本使用它们,以及如何将它们作为独立优化器用于自定义 PyTorch 训练脚本。

timm 中可用的各种优化器如下:

  1. SGD
  2. Adam
  3. AdamW
  4. Nadam
  5. Radam
  6. AdamP
  7. SGDP
  8. Adadelta
  9. Adafactor
  10. ADAHESSIAN
  11. RMSprop
  12. NovoGrad

以及来自 apex 的一些其他优化器,例如:

  1. FusedSGD
  2. FusedAdam
  3. FusedLAMB
  4. FusedNovoGrad

它们是仅支持 GPU 的。

timm 还支持 lookahead 优化器。

按照 timm 的惯例,创建优化器的最佳方法是使用 create_optimizer 工厂方法。在本教程中,我们将首先了解如何使用 timm 训练脚本以及如何将这些优化器作为独立优化器用于自定义训练脚本来训练这些模型。

使用 timm 训练脚本

要使用任何优化器进行训练,只需在训练脚本中通过 --opt 参数传递优化器名称即可。

python train.py ../imagenette-320/ --opt adam

由于可以将 Lookahead 技术添加到任何优化器中,因此我们可以在 timm 中使用 Lookahead 训练模型,只需在优化器名称前加上 lookahead_ 前缀即可。例如,对于 adam,训练脚本看起来像

python train.py ../imagenette-320/ --opt lookahead_adam

就这样。通过这种方式,我们可以使用 timm 中所有可用的优化器在 ImageNetImagenette 上训练模型。

作为自定义训练脚本的独立 optimizers

很多时候,我们可能只希望将 timm 中的优化器用于自己的训练脚本。使用 timm 创建优化器的最佳方法是使用 create_optimizer 工厂方法。

create_optimizer 的参数如下所示:

def create_optimizer(args, model, filter_bias_and_bn=True) -> Union[Optimizer, Lookahead]:
    """
    Here, `args` are the arguments parsed by `ArgumentParser` in `timm` training script. 
    If we want to create an optimizer using this function, we should make sure that `args` has the 
    following attributes set: 

    args: Arguments from `ArgumentParser`:
    - `opt`: Optimizer name
    - `weight_decay`: Weight decay if any 
    - `lr`: Learning rate 
    - `momentum`: Decay rate for momentum if passed and not 0 

    model: Model that we want to train 
    """

下面我们来看看如何模拟 args

from types import SimpleNamespace
from timm.optim.optim_factory import create_optimizer
from timm import create_model 

model = create_model('resnet34')

args = SimpleNamespace()
args.weight_decay = 0
args.lr = 1e-4
args.opt = 'adam' #'lookahead_adam' to use `lookahead`
args.momentum = 0.9

optimizer = create_optimizer(args, model)
optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)

使用 timm 优化器训练神经网络 (NN)

在本节中,我们将尝试实验一些可用的优化器,并在我们自己的自定义训练脚本中使用它们。

我们将存储每个优化器的损失,并在最后可视化损失曲线,以比较使用 timm 创建的 resnet-34 模型在 Imagenette 数据集上的性能。

import torch
import torch.optim as optim 
import timm
from timm.data import create_dataset, create_loader
import numpy as np 
from matplotlib import pyplot as plt
import torchvision
import torch.nn as nn 
from tqdm import tqdm
import logging 
from timm.optim import optim_factory
from types import SimpleNamespace
logging.getLogger().setLevel(logging.INFO)
DATA_DIR = '../imagenette2-320/'

数据目录

数据目录的结构如下所示:

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

训练集和验证集

现在让我们使用 timm 创建训练集和验证集以及数据加载器。有关数据集的更多文档,请参阅此处

train_dataset = create_dataset("train", DATA_DIR, "train")
train_loader  = create_loader(train_dataset, input_size=(3, 320, 320), batch_size=8, use_prefetcher=False, 
                              is_training=True, no_aug=True)
len(train_dataset)
9469
val_dataset = create_dataset("val", DATA_DIR, "val")
val_loader  = create_loader(val_dataset, input_size=(3, 320, 320), batch_size=64, use_prefetcher=False)
len(val_dataset)
3925

这些是 Imagenette 中的类名。我们在下面列出它们以便于可视化:

class_names = ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 
               'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']

现在让我们可视化数据集中包含的一些图像和类别。

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) 

    
inputs, classes = next(iter(train_loader))[:8]
out = torchvision.utils.make_grid(inputs, nrow=4)
imshow(out, title=[class_names[x.item()] for x in classes])

直接从 train_loader 中可视化图像是一个很好的做法,可以检查是否存在任何错误。

训练一个周期 (epoch)

在本节中,我们将创建自定义训练循环。

loss_fn = nn.CrossEntropyLoss()
model   = timm.create_model('resnet34', pretrained=False, num_classes=10)
model(inputs).shape
torch.Size([8, 10])

下面的 AverageMeter 类用于平均损失,以便于可视化。如果我们不计算移动平均,则损失曲线会非常崎岖不平,难以可视化。

class AverageMeter:
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

以下函数定义了我们的自定义训练循环。本质上,我们从 train_loader 中获取 inputstargets。通过将 inputs 传递给模型来获得预测。计算损失函数,使用 PyTorch 进行反向传播以计算梯度。最后,我们使用优化器迈出一步来更新参数并清零梯度。

此外,请注意,我们将每个 mini batch 的损失移动平均值存储在名为 losses 的列表中,通过 losses.append(loss_avg.avg) 实现。最后,我们返回一个字典,其中包含优化器名称和 losses 列表。

def train_one_epoch(args, loader, model, loss_fn = nn.CrossEntropyLoss(), **optim_kwargs):
    model   = timm.create_model('resnet34', pretrained=False, num_classes=10)
    logging.info(f"\ncreated model: {model.__class__.__name__}")
    
    optimizer = optim_factory.create_optimizer(args, model, **optim_kwargs)
    logging.info(f"created optimizer: {optimizer.__class__.__name__}")
    
    losses = []
    loss_avg = AverageMeter()
    model = model.cuda()
    tk0 = tqdm(enumerate(loader), total=len(loader))
    for i, (inputs, targets) in tk0:
        inputs = inputs.cuda()
        targets = targets.cuda()
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_avg.update(loss.item(), loader.batch_size)
        losses.append(loss_avg.avg)
        tk0.set_postfix(loss=loss.item())
    
    return {args.opt: losses}

请注意,此 train_one_epoch 函数接受 args。这些是我们之前看过的模拟 args。此 args 参数被传递给 optim_factory.create_optimizer 以创建优化器。

losses_dict = {}
args = SimpleNamespace()
args.weight_decay = 0 
args.lr = 1e-4
args.momentum = 0.9

现在让我们传入各种优化器。我们创建的训练循环应该负责使用 create_optimizer 函数实例化 Optimizer

我们将学习率设置为 1e-4,权重衰减和动量都设置为 0。

我们还传入 lookahead_adam 以演示如何在 timm 中使用 Lookahead 类进行训练。

for opt in ['SGD', 'Adam', 'AdamW', 'Nadam', 'Radam', 'AdamP', 'Lookahead_Adam']:
    args.opt = opt
    loss_dict = train_one_epoch(args, train_loader, model)
    losses_dict.update(loss_dict)
INFO:root:
created model: ResNet
INFO:root:created optimizer: SGD
100%|██████████| 147/147 [00:30<00:00,  4.82it/s, loss=2.19]
INFO:root:
created model: ResNet
INFO:root:created optimizer: Adam
100%|██████████| 147/147 [00:30<00:00,  4.79it/s, loss=1.68]
INFO:root:
created model: ResNet
INFO:root:created optimizer: AdamW
100%|██████████| 147/147 [00:30<00:00,  4.77it/s, loss=1.77]
INFO:root:
created model: ResNet
INFO:root:created optimizer: Nadam
  0%|          | 0/147 [00:00<?, ?it/s]/home/aman_arora/git/experiments/pytorch-image-models/timm/optim/nadam.py:80: UserWarning: This overload of add_ is deprecated:
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg.mul_(beta1).add_(1. - beta1, grad)
100%|██████████| 147/147 [00:30<00:00,  4.78it/s, loss=1.74]
INFO:root:
created model: ResNet
INFO:root:created optimizer: RAdam
100%|██████████| 147/147 [00:30<00:00,  4.81it/s, loss=2.1] 
INFO:root:
created model: ResNet
INFO:root:created optimizer: AdamP
100%|██████████| 147/147 [00:39<00:00,  3.77it/s, loss=1.72]
INFO:root:
created model: ResNet
INFO:root:created optimizer: Lookahead
100%|██████████| 147/147 [00:30<00:00,  4.77it/s, loss=1.82]

比较不同优化器的性能

最后,让我们可视化结果以比较性能。所有损失以及传入的优化器都存储在 losses_dict 中。

fig, ax = plt.subplots(figsize=(15,8))
for k, v in losses_dict.items():
    ax.plot(range(1, len(v) + 1), v, '.-', label=k)
    
ax.legend()  
ax.grid()

我们可以看到,在我们训练模型的一个周期中,AdamAdamPImagenette 上表现最佳。之后,请随意运行您自己的实验! :)