`timm` 是由 Ross Wightman 创建的深度学习库,它包含了 SOTA 计算机视觉模型、层、工具、优化器、学习率调度器、数据加载器、数据增强以及训练/验证脚本,并能够重现 ImageNet 训练结果。

安装

pip install timm

或者进行可编辑安装,

git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

如何使用

创建模型

import timm 
import torch

model = timm.create_model('resnet34')
x     = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])

使用 timm 创建模型就是如此简单。create_model 函数是一个工厂方法,可用于创建 timm 库中包含的 300 多个模型。

要创建预训练模型,只需传入 pretrained=True

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth

要创建具有自定义类别数的模型,只需传入 num_classes=<number_of_classes>

import timm 
import torch

model = timm.create_model('resnet34', num_classes=10)
x     = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])

列出带有预训练权重的模型

timm.list_models() 返回 timm 中可用模型的完整列表。要查看预训练模型的完整列表,请在 list_models 中传入 pretrained=True

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(592,
 ['adv_inception_v3',
  'bat_resnext26ts',
  'beit_base_patch16_224',
  'beit_base_patch16_224_in22k',
  'beit_base_patch16_384'])

目前 timm 中总共有 271 个带有预训练权重的模型!

使用通配符搜索模型架构

也可以如下使用通配符搜索模型架构

all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'tv_densenet121']

在 fastai 中微调 timm 模型

fastai 库支持微调来自 timm 的模型

from fastai.vision.all import *

path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
    
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)

learn.fine_tune(1)
epoch 训练损失 验证损失 错误率 时间
0 0.201583 0.024980 0.006766 00:08
epoch 训练损失 验证损失 错误率 时间
0 0.040622 0.024036 0.005413 00:10