正如你可能从标题中猜到的那样,在本教程中,我们将深入探讨 timm 中的 create_model 函数,并查看可以传递给此函数的所有 **kwargs

create_model 函数有什么作用?

timm 中,create_model 函数负责创建超过 300 种深度学习模型的架构!要创建模型,只需将 model_name 传递给 create_model

import timm 
# creates resnet-34 architecture
model = timm.create_model('resnet34')
# creates efficientnet-b0 architecture
model = timm.create_model('efficientnet_b0')
# creates densenet architecture
model = timm.create_model('densenet121')

依此类推……可以使用 timm.list_models() 函数找到可用模型的完整列表。

创建预训练模型

要创建预训练模型,只需将 pretrained=True 关键字参数以及模型名称传递给 timm.create_model 函数。

import timm 
# creates pretrained resnet-34 architecture
model = timm.create_model('resnet34', pretrained=True)
# creates pretrained efficientnet-b0 architecture
model = timm.create_model('efficientnet_b0', pretrained=True)
# creates pretrained densenet architecture
model = timm.create_model('densenet121', pretrained=True)

要获取 timm 中可用预训练模型的完整列表,请将 pretrained=True 传递给 timm.list_models() 函数。

all_pretrained_models_available = timm.list_models(pretrained=True)

将任何模型转换为特征提取器

所有模型都支持为 create_model 调用传递 features_only=True 参数,以返回一个网络,该网络从每个步幅的最深层提取特征图。还可以使用 out_indices=[...] 参数指定要从中提取特征的层的索引。

import timm 
import torch 

# input batch with batch size of 1 and 3-channel image of size 224x224
x = torch.randn(1,3,224,224)
model = timm.create_model('resnet34')
model(x).shape
torch.Size([1, 1000])
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[2,3,4])
out = feature_extractor(x)

如果我告诉你 out 是一个 Tensor 列表,你能猜出它的长度吗?

Resnet34

我们知道 resnet-34 架构如上所示。如果开头的 7x7 卷积层被视为第 0 层,你能猜出从第 1 层、第 2 层、第 3 层和第 4 层(每层用不同颜色表示)输出的特征的形状吗?

import torch.nn as nn
import torch 

# input batch
x = torch.randn(1, 3, 224, 224)

pool  = nn.MaxPool2d(3, 2, 1, 1)
conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
conv2 = nn.Conv2d(64, 64, 3, 1, 1)
conv3 = nn.Conv2d(64, 128, 3, 2, 1)

# feature map from Layer-0
conv1(x).shape
# feature map from Layer-1
conv2(pool(conv1(x))).shape
# and so on..

正如你现在可能猜到的那样,从第 2 层、第 3 层和第 4 层输出的特征图形状应分别为 [1, 128, 28, 28][[1, 256, 14, 14], [1, 512, 7, 7]]

让我们看看结果是否符合我们的预期。

[x.shape for x in out]
[torch.Size([1, 128, 28, 28]),
 torch.Size([1, 256, 14, 14]),
 torch.Size([1, 512, 7, 7])]

特征图的输出形状符合我们的预期。通过这种方式,我们可以将 timm 中的任何模型转换为特征提取器。