在本教程中,我们将深入探讨 create_model
函数的源代码。我们还将了解如何将任何给定模型转换为特征提取器。我们已经在这里看到了一个示例。我们将 ResNet-34
架构转换为了一个特征提取器,用于从第 2、3 和 4 层提取特征。
在本教程中,我们将深入研究 create_model
的源代码,看看 timm
是如何将任何模型转换为特征提取器的。
create_model
函数用于在 timm
中创建数百个模型。它还接受一组 **kwargs
,例如 features_only
和 out_indices
,将这两个 **kwargs
传递给 create_model
函数会创建一个特征提取器。让我们看看它是如何做到的?
create_model
函数本身的源代码大约只有 50 行。所以所有的“魔法”一定发生在其他地方。正如你可能已经知道的,timm.list_models()
中的每个模型名称实际上都是一个函数。
例如
%load_ext autoreload
%autoreload 2
import timm
import random
from timm.models import registry
m = timm.list_models()[-1]
registry.is_model(m)
timm
有一个内部字典 _model_entrypoints
,它包含了所有模型名称及其对应的构造函数。例如,我们可以通过 _model_entrypoints
中的 model_entrypoint
函数获取 xception71
模型的构造函数。
constuctor_fn = registry.model_entrypoint(m)
constuctor_fn
如我们所见,在 timm.models.xception_aligned
模块中有一个名为 xception71
的函数。类似地,每个模型在 timm
中都有一个构造函数。实际上,这个内部 _model_entrypoints
字典看起来像这样
_model_entrypoints
> >
{
'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,
}
因此,timm
中的每个模型都在各自的模块中定义了一个构造函数。例如,所有 ResNet 都定义在 timm.models.resnet
模块中。因此,创建 resnet34
模型有两种方式
import timm
from timm.models.resnet import resnet34
# using `create_model`
m = timm.create_model('resnet34')
# directly calling the constructor fn
m = resnet34()
在 timm
中,你通常不希望直接调用构造函数。所有模型都应该使用 create_model
函数本身来创建。
resnet34
构造函数的源代码如下所示
@register_model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet34', pretrained, **model_args)
timm
中的每个模型都有一个 register_model
装饰器。最初,_model_entrypoints
是一个空字典。正是 register_model
装饰器将给定的模型函数构造器及其名称添加到 _model_entrypoints
中。def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
如上所示,register_model
函数执行了一些非常基础的步骤。但我想要强调的主要步骤是这个
_model_entrypoints[model_name] = fn
因此,它将给定的 fn
添加到 _model_entrypoints
中,其中键是 fn.__name__
。
resnet34
函数上使用 @register_model
装饰器有什么作用吗?它在 _model_entrypoints
中创建一个条目,看起来像 {’resnet34’: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}
。此外,仅通过查看 resnet34
构造函数的源代码,我们可以看到在设置了一些 model_args
后,它会调用 create_resnet
函数。让我们看看它的样子
def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
因此,_create_resnet
函数实际上调用了 build_model_with_cfg
函数,并传入了构造器类 ResNet
、变体名称 resnet34
、一个 default_cfg
和一些 **kwargs
。
timm
中的每个模型都有一个默认配置。这包含模型预训练权重的 URL、要分类的类别数量、输入图像大小、池化大小等等。
resnet34
的默认配置如下所示
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}
此默认配置连同构造器类和一些模型参数等其他参数一起传递给 build_model_with_cfg
函数。
build_model_with_cfg
函数负责以下任务:
- 实际实例化模型类以在
timm
中创建模型 - 如果
pruned=True
,则对模型进行剪枝 - 如果
pretrained=True
,则加载预训练权重 - 如果
features=True
,则将模型转换为特征提取器
在检查此函数的源代码后
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
default_cfg: dict,
model_cfg: dict = None,
feature_cfg: dict = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Callable = None,
pretrained_custom_load: bool = False,
**kwargs):
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
model.default_cfg = deepcopy(default_cfg)
if pruned:
model = adapt_model_from_file(model, variant)
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_custom_load:
load_custom_pretrained(model)
else:
load_pretrained(
model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
if features:
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model
就是这样。我们现在已经完全查看了 timm.create_model
函数。主要调用的函数有:
- 模型构造函数,每个模型都不同,用于设置模型特定的参数。
_model_entrypoints
字典包含所有模型名称及其对应的构造函数。 build_with_model_cfg
函数接受一个模型构造器类以及模型构造函数中设置的模型特定参数。load_pretrained
用于加载预训练权重。即使输入通道数不等于 3(如 ImageNet 的情况),此函数也有效。FeatureListNet
类负责将任何模型转换为特征提取器。