岭南香雪 2024-09-09 10:54 采纳率: 50%
浏览 12

torch_ema库的EMA(ExponentialMovingAverage,,指数滑动平均)使用问题

我从https://github.com/fadel/pytorch_ema下载并使用torch_ema这个库,在调用ExponentialMovingAverage(EMA,指数滑动平均)时遇到了问题。以下是我的详细代码:

from torch_ema import ExponentialMovingAverage
model = ...
for name, para in model.named_parameters():
    if "blocks" in name or "head" in name:
        para.requires_grad_(True)
    else:
        para.requires_grad_(False)
# pg are the parameters to be trained in the model(I freeze some of the parameters of the model during training)
pg = [p for p in model.parameters() if p.requires_grad] 
optimizer = optim.SGD(pg, lr=0.01, momentum=0.9, weight_decay=5e-5)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0)
ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)

best_MAE=10
save_path=...
for epoch in range(args.epochs):# train
    print('epoch:',epoch,'Current learning rate:', optimizer.param_groups[0]['lr'])
    train_loss, train_MAE, tb_writer = train_one_epoch(model=model,
                                            optimizer=optimizer,
                                            data_loader=train_loader,
                                            device=device,
                                            epoch=epoch,
                                            scheduler=scheduler,
                                            csv_filename=args.csv_filename,
                                            tb_writer=tb_writer)
    
    scheduler.step()
    ema_model.update()
    # validate
    with ema_model.average_parameters():
        val_loss, val_MAE = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)
    if val_MAE < best_MAE:
        best_MAE=val_MAE
        torch.save(ema_model.state_dict(), save_path) 

其中,evaluate函数如下所示:

@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    softceloss_function = SoftCrossEntropy()
    model.eval()
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, names, labels = data
        pred = model(images.to(device))
        softlabel = softlabel_function(labels) # a function to convert labels to softlabel
        loss = softceloss_function(pred, softlabel.to(device))

    val_loss,val_MAE = ... # calculate loss and MAE
    return val_loss, val_MAE

在训练过程中的每个epoch结束时,我都会在验证集上执行evaluate函数。如果 val_MAE < best_MAE,我会保存模型的checkpoint。经过 20 个 epoch 的训练后,我会选择在验证集上性能最好的模型,去在测试集上进行测试。

结果如下:

  1. 如果我不使用 pytorch_ema 软件包,即在不使用 ema_model.average_parameters() 这句代码时执行evaluate函数,epoch 1、2、3 的 val_MAE 分别为 6.185、5.779 和 5.529。

  2. 如果我使用上面提供的代码,即 ema_model = ExponentialMovingAverage(parameters=pg,decay=0.9999),其中decay设置为 0.9999,则epoch 1、2、3 的值 val_MAE 分别为 6.269、5.878 和 5.548。这证明了 ema 模式的有效性。

  3. 当我尝试将decay设置为其他值,如 0.999(或任何其他值),即ema_model = ExponentialMovingAverage(parameters=pg, decay=0.999),并重新开始训练时,我发现每轮的验证结果完全一样,epoch 1、2、3 的 val_MAE 还是分别为 6.269、5.878 和 5.548。

总的来说,我有三个问题:
1.我想知道为什么设置不同的 decay 进行训练时,验证结果完全相同。
2.由于 ema_model.state_dict() 与 model.state_dict() 不同,如何保存 ema_model 的 checkpoint 并将其应用于测试集?
3.是否有其他更方便的、兼容torch的,使用EMA功能的方法,如有麻烦提供链接或例程,非常感谢!

  • 写回答

1条回答 默认 最新

  • 码踏云端 Java领域新星创作者 2024-09-15 21:49
    关注

    问题分析与解决方案

    你提出的问题主要集中在以下几个方面:

    1. 不同的decay值导致相同的验证结果
    2. 如何保存和加载ema_model的checkpoint
    3. 是否有其他更方便的、兼容PyTorch的EMA方法

    下面我将逐一解决这些问题,并提供相应的解决方案。

    1. 不同的decay值导致相同的验证结果

    问题分析

    ExponentialMovingAverage中,decay参数控制了指数滑动平均的权重。理论上,不同的decay值应该会导致不同的模型参数滑动平均结果,从而影响验证结果。然而,你发现无论decay值如何变化,验证结果都相同。

    可能原因

    1. 初始化问题:可能是因为ema_model的初始化方式导致不同的decay值没有产生预期的效果。
    2. 更新机制问题:可能是因为ema_model.update()的实现方式没有正确处理不同的decay值。

    解决方案

    1. 检查ExponentialMovingAverage的实现:确保ExponentialMovingAverage的实现正确处理了decay参数。
    2. 手动调试:在训练过程中打印ema_model的参数,观察其变化情况。
    ema_model = ExponentialMovingAverage(parameters=pg, decay=0.9999)
    
    for epoch in range(args.epochs):
        # 训练代码...
        ema_model.update()
        print(f"Epoch {epoch}: EMA parameters updated.")
        # 打印EMA模型的参数
        for name, param in ema_model.named_parameters():
            print(f"{name}: {param.data}")
    
    1. 使用其他EMA实现:如果torch_ema库的实现有问题,可以尝试其他EMA实现,如torchcontrib中的EMA实现。

    2. 如何保存和加载ema_model的checkpoint

    问题分析

    由于ema_modelstate_dict()modelstate_dict()不同,直接保存ema_model.state_dict()会导致加载时出现问题。

    解决方案

    1. **保存ema_modelstate_dict()**:
    torch.save(ema_model.state_dict(), save_path)
    
    1. **加载ema_modelstate_dict()**:
    # 加载EMA模型的state_dict
    ema_model.load_state_dict(torch.load(save_path))
    
    1. 在测试时使用EMA模型
    # 在测试时使用EMA模型
    with ema_model.average_parameters():
        test_loss, test_MAE = evaluate(model=model, data_loader=test_loader, device=device, epoch=epoch)
    

    3. 是否有其他更方便的、兼容PyTorch的EMA方法

    解决方案

    1. 使用torchcontribtorchcontrib库中提供了EMA的实现,使用起来更加方便。
    from torchcontrib.optim import SWA
    
    # 初始化EMA模型
    ema_model = SWA(model, swa_start=10, swa_freq=5, swa_lr=0.05)
    
    # 训练代码...
    ema_model.update_swa()
    
    # 保存EMA模型的state_dict
    torch.save(ema_model.state_dict(), save_path)
    
    # 加载EMA模型的state_dict
    ema_model.load_state_dict(torch.load(save_path))
    
    # 在测试时使用EMA模型
    with ema_model.average_parameters():
        test_loss, test_MAE = evaluate(model=model, data_loader=test_loader, device=device, epoch=epoch)
    
    1. 自定义EMA实现:如果你对EMA的实现有特殊需求,可以自定义EMA实现。
    class ExponentialMovingAverage:
        def __init__(self, model, decay):
            self.model = model
            self.decay = decay
            self.shadow = {}
            self.backup = {}
            self.register()
    
        def register(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.shadow[name] = param.data.clone()
    
        def update(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    assert name in self.shadow
                    new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                    self.shadow[name] = new_average.clone()
    
        def apply_shadow(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    assert name in self.shadow
                    self.backup[name] = param.data
                    param.data = self.shadow[name]
    
        def restore(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    assert name in self.backup
                    param.data = self.backup[name]
            self.backup = {}
    
    # 使用自定义EMA实现
    ema_model = ExponentialMovingAverage(model, decay=0.9999)
    
    for epoch in range(args.epochs):
        # 训练代码...
        ema_model.update()
    
        # 在验证时使用EMA模型
        with ema_model.apply_shadow():
            val_loss, val_MAE = evaluate(model=model, data_loader=val_loader, device=device, epoch=epoch)
    
        if val_MAE < best_MAE:
            best_MAE = val_MAE
            torch.save(ema_model.shadow, save_path)  # 保存EMA模型的参数
    
    # 加载EMA模型的参数
    ema_model.shadow = torch.load(save_path)
    
    # 在测试时使用EMA模型
    with ema_model.apply_shadow():
        test_loss, test_MAE = evaluate(model=model, data_loader=test_loader, device=device, epoch=epoch)
    

    总结

    通过以上解决方案,你可以解决不同decay值导致相同验证结果的问题,并且能够正确保存和加载ema_model的checkpoint。此外,你还可以尝试其他更方便的EMA实现方法,如torchcontrib库中的EMA实现。

    评论

报告相同问题?

问题事件

  • 修改了问题 9月10日
  • 创建了问题 9月9日

悬赏问题

  • ¥15 如何在vue.config.js中读取到public文件夹下window.APP_CONFIG.API_BASE_URL的值
  • ¥50 浦育平台scratch图形化编程
  • ¥20 求这个的原理图 只要原理图
  • ¥15 vue2项目中,如何配置环境,可以在打完包之后修改请求的服务器地址
  • ¥20 微信的店铺小程序如何修改背景图
  • ¥15 UE5.1局部变量对蓝图不可见
  • ¥15 一共有五道问题关于整数幂的运算还有房间号码 还有网络密码的解答?(语言-python)
  • ¥20 sentry如何捕获上传Android ndk 崩溃
  • ¥15 在做logistic回归模型限制性立方条图时候,不能出完整图的困难
  • ¥15 G0系列单片机HAL库中景园gc9307液晶驱动芯片无法使用硬件SPI+DMA驱动,如何解决?