在训练模型时,保持训练参数的移动平均通常是有益的。使用平均参数进行评估有时会产生比最终训练值显著更好的结果。
timm
支持类似于 tensorflow 的 EMA。
要使用 EMA 训练模型,只需添加 --model-ema
标志和带有值的 --model-ema-decay
标志来定义 EMA 的衰减率。
为了防止 EMA 使用 GPU 资源,请设置 device='cpu'。这将节省一些内存,但会禁用对 EMA 权重的验证。验证必须在单独的进程中手动完成,或在训练停止收敛后进行。
python train.py ../imagenette2-320 --model resnet34
python train.py ../imagenette2-320 --model resnet34 --model-ema --model-ema-decay 0.99
上面的训练脚本意味着在更新模型权重时,我们在每次迭代中保留 99.99% 的旧模型权重,只更新 0.01% 的新权重。
python"
model_weights = decay * model_weights + (1 - decay) * new_model_weights
在 timm
内部,当我们传递 --model-ema
标志时,timm
会将模型类封装到 ModelEmaV2
类中,其结构如下
class ModelEmaV2(nn.Module):
def __init__(self, model, decay=0.9999, device=None):
super(ModelEmaV2, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
基本上,我们通过传入现有的 model
和衰减率来初始化 ModeEmaV2
,在本例中是 decay=0.9999
。
这看起来有点像 model_ema = ModelEmaV2(model)
。这里,model
可以是任何现有模型,只要它是使用 timm.create_model
函数创建的。
接下来,在训练期间,特别是在 train_one_epoch
内部,我们像这样调用 model_ema
的 update
方法
if model_ema is not None:
model_ema.update(model)
所有基于 loss
的参数更新都发生在 model
上。当我们调用 optimizer.step()
时,更新的是 model
的权重,而不是 model_ema
的权重。
因此,当我们调用 model_ema.update
方法时,正如所见,它会调用带有 update_fn = lambda e, m: self.decay * e + (1. - self.decay) * m)
的 _update
方法。
e
指的是 model_ema
,而 m
指的是在训练过程中权重得到更新的 model
。update_fn
指定我们保留 model_ema
的 self.decay
倍和 model
的 1-self.decay
倍。_update
函数时,它会遍历 model
和 model_ema
内部的每个参数,并更新 model_ema
的状态,以保留 99.99% 的现有状态和 0.01% 的新状态。model
和 model_ema
在 state_dict
内部具有相同的键。