`timm` 是由 Ross Wightman 创建的深度学习库,它包含了 SOTA 计算机视觉模型、层、工具、优化器、学习率调度器、数据加载器、数据增强以及训练/验证脚本,并能够重现 ImageNet 训练结果。
pip install timm
或者进行可编辑安装,
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
import timm
import torch
model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
model(x).shape
使用 timm
创建模型就是如此简单。create_model
函数是一个工厂方法,可用于创建 timm
库中包含的 300 多个模型。
要创建预训练模型,只需传入 pretrained=True
。
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
要创建具有自定义类别数的模型,只需传入 num_classes=<number_of_classes>
。
import timm
import torch
model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
model(x).shape
timm.list_models()
返回 timm
中可用模型的完整列表。要查看预训练模型的完整列表,请在 list_models
中传入 pretrained=True
。
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
目前 timm
中总共有 271 个带有预训练权重的模型!
也可以如下使用通配符搜索模型架构
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
fastai 库支持微调来自 timm 的模型
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2,
label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn.fine_tune(1)