在本教程中,我们将深入探讨 create_model 函数的源代码。我们还将了解如何将任何给定模型转换为特征提取器。我们已经在这里看到了一个示例。我们将 ResNet-34 架构转换为了一个特征提取器,用于从第 2、3 和 4 层提取特征。

在本教程中,我们将深入研究 create_model 的源代码,看看 timm 是如何将任何模型转换为特征提取器的。

create_model 函数

create_model 函数用于在 timm 中创建数百个模型。它还接受一组 **kwargs,例如 features_onlyout_indices,将这两个 **kwargs 传递给 create_model 函数会创建一个特征提取器。让我们看看它是如何做到的?

create_model 函数本身的源代码大约只有 50 行。所以所有的“魔法”一定发生在其他地方。正如你可能已经知道的,timm.list_models() 中的每个模型名称实际上都是一个函数。

例如

%load_ext autoreload
%autoreload 2
import timm
import random 
from timm.models import registry

m = timm.list_models()[-1]
registry.is_model(m)
True

timm 有一个内部字典 _model_entrypoints,它包含了所有模型名称及其对应的构造函数。例如,我们可以通过 _model_entrypoints 中的 model_entrypoint 函数获取 xception71 模型的构造函数。

constuctor_fn = registry.model_entrypoint(m)
constuctor_fn
<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>

如我们所见,在 timm.models.xception_aligned 模块中有一个名为 xception71 的函数。类似地,每个模型在 timm 中都有一个构造函数。实际上,这个内部 _model_entrypoints 字典看起来像这样

_model_entrypoints
> > 
{
'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,

}

因此,timm 中的每个模型都在各自的模块中定义了一个构造函数。例如,所有 ResNet 都定义在 timm.models.resnet 模块中。因此,创建 resnet34 模型有两种方式

import timm
from timm.models.resnet import resnet34

# using `create_model`
m = timm.create_model('resnet34')

# directly calling the constructor fn
m = resnet34()

timm 中,你通常不希望直接调用构造函数。所有模型都应该使用 create_model 函数本身来创建。

注册模型

resnet34 构造函数的源代码如下所示

@register_model
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet34', pretrained, **model_args)

def register_model(fn):
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn

如上所示,register_model 函数执行了一些非常基础的步骤。但我想要强调的主要步骤是这个

_model_entrypoints[model_name] = fn

因此,它将给定的 fn 添加到 _model_entrypoints 中,其中键是 fn.__name__

此外,仅通过查看 resnet34 构造函数的源代码,我们可以看到在设置了一些 model_args 后,它会调用 create_resnet 函数。让我们看看它的样子

def _create_resnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)

因此,_create_resnet 函数实际上调用了 build_model_with_cfg 函数,并传入了构造器类 ResNet、变体名称 resnet34、一个 default_cfg 和一些 **kwargs

默认配置

timm 中的每个模型都有一个默认配置。这包含模型预训练权重的 URL、要分类的类别数量、输入图像大小、池化大小等等。

resnet34 的默认配置如下所示

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}

此默认配置连同构造器类和一些模型参数等其他参数一起传递给 build_model_with_cfg 函数。

使用配置构建模型

build_model_with_cfg 函数负责以下任务:

  1. 实际实例化模型类以在 timm 中创建模型
  2. 如果 pruned=True,则对模型进行剪枝
  3. 如果 pretrained=True,则加载预训练权重
  4. 如果 features=True,则将模型转换为特征提取器

在检查此函数的源代码后

def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        pretrained_custom_load: bool = False,
        **kwargs):
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)

    if pruned:
        model = adapt_model_from_file(model, variant)

    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model

可以看出模型是在此点 model = model_cls(**kwargs) 创建的。

此外,作为本教程的一部分,我们不会查看 prunedadapt_model_from_file 函数的内部实现。

我们已经理解并查看了 load_pretrained 函数的内部实现,参见此处

我们还深入探讨了 FeatureListNet 类(此处),它负责将我们的深度学习模型转换为特征提取器。

总结

就是这样。我们现在已经完全查看了 timm.create_model 函数。主要调用的函数有:

  • 模型构造函数,每个模型都不同,用于设置模型特定的参数。_model_entrypoints 字典包含所有模型名称及其对应的构造函数。
  • build_with_model_cfg 函数接受一个模型构造器类以及模型构造函数中设置的模型特定参数。
  • load_pretrained 用于加载预训练权重。即使输入通道数不等于 3(如 ImageNet 的情况),此函数也有效。
  • FeatureListNet 类负责将任何模型转换为特征提取器。