在训练模型时,保持训练参数的移动平均通常是有益的。使用平均参数进行评估有时会产生比最终训练值显著更好的结果。

timm 支持类似于 tensorflow 的 EMA。

要使用 EMA 训练模型,只需添加 --model-ema 标志和带有值的 --model-ema-decay 标志来定义 EMA 的衰减率。

为了防止 EMA 使用 GPU 资源,请设置 device='cpu'。这将节省一些内存,但会禁用对 EMA 权重的验证。验证必须在单独的进程中手动完成,或在训练停止收敛后进行。

不使用 EMA 进行训练

python train.py ../imagenette2-320 --model resnet34

使用 EMA 进行训练

python train.py ../imagenette2-320 --model resnet34 --model-ema --model-ema-decay 0.99

上面的训练脚本意味着在更新模型权重时,我们在每次迭代中保留 99.99% 的旧模型权重,只更新 0.01% 的新权重。

python"
model_weights = decay * model_weights + (1 - decay) * new_model_weights

timm 中模型 EMA 的内部机制

timm 内部,当我们传递 --model-ema 标志时,timm 会将模型类封装到 ModelEmaV2 类中,其结构如下

class ModelEmaV2(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

基本上,我们通过传入现有的 model 和衰减率来初始化 ModeEmaV2,在本例中是 decay=0.9999

这看起来有点像 model_ema = ModelEmaV2(model)。这里,model 可以是任何现有模型,只要它是使用 timm.create_model 函数创建的。

接下来,在训练期间,特别是在 train_one_epoch 内部,我们像这样调用 model_emaupdate 方法

if model_ema is not None:
    model_ema.update(model)

所有基于 loss 的参数更新都发生在 model 上。当我们调用 optimizer.step() 时,更新的是 model 的权重,而不是 model_ema 的权重。

因此,当我们调用 model_ema.update 方法时,正如所见,它会调用带有 update_fn = lambda e, m: self.decay * e + (1. - self.decay) * m)_update 方法。

因此,当我们调用 _update 函数时,它会遍历 modelmodel_ema 内部的每个参数,并更新 model_ema 的状态,以保留 99.99% 的现有状态和 0.01% 的新状态。