名字不能取太长 2024-02-20 16:28 采纳率: 75.6%
浏览 9
已结题

多GPU和单GPU状体字典不一样


    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev(),strict=False
            )
            self.opt.load_state_dict(state_dict)

以上这段代码是用于从检查点中获取优化器参数,我发现在多卡并行的情况下得到的state_dict的内容和单卡情况下的内容不一样,导致我在多卡情况下恢复训练会出现错误:KeyError: 'param_groups',这是为什么?如何解决?

我尝试了各位的方案,发现还是不行,在单GPU时state_dict内是['state','param_groups'],多GPU就变成了如下:

['t_embedder.mlp.0.weight', 't_embedder.mlp.0.bias', 't_embedder.mlp.2.weight', 't_embedder.mlp.2.bias', 'encoder.pos_embedding', 'encoder.to_patch_embedding.1.weight', 'encoder.to_patch_embedding.1.bias', 'encoder.to_patch_embedding.2.weight', 'encoder.to_patch_embedding.2.bias', 'encoder.to_patch_embedding.3.weight', 'encoder.to_patch_embedding.3.bias', 'encoder.transformer.norm.weight', 'encoder.transformer.norm.bias', 'encoder.transformer.layers.0.0.norm.weight', 'encoder.transformer.layers.0.0.norm.bias', 'encoder.transformer.layers.0.0.to_qkv.weight', 'encoder.transformer.layers.0.0.to_out.0.weight', 'encoder.transformer.layers.0.0.to_out.0.bias', 'encoder.transformer.layers.0.1.net.0.weight', 'encoder.transformer.layers.0.1.net.0.bias', 'encoder.transformer.layers.0.1.net.1.weight', 'encoder.transformer.layers.0.1.net.1.bias', 'encoder.transformer.layers.0.1.net.4.weight', 'encoder.transformer.layers.0.1.net.4.bias', 'encoder.transformer.layers.1.0.norm.weight', 'encoder.transformer.layers.1.0.norm.bias', 'encoder.transformer.layers.1.0.to_qkv.weight', 'encoder.transformer.layers.1.0.to_out.0.weight', 'encoder.transformer.layers.1.0.to_out.0.bias', 'encoder.transformer.layers.1.1.net.0.weight', 'encoder.transformer.layers.1.1.net.0.bias', 'encoder.transformer.layers.1.1.net.1.weight', 'encoder.transformer.layers.1.1.net.1.bias', 'encoder.transformer.layers.1.1.net.4.weight', 'encoder.transformer.layers.1.1.net.4.bias', 'encoder.transformer.layers.2.0.norm.weight', 'encoder.transformer.layers.2.0.norm.bias', 'encoder.transformer.layers.2.0.to_qkv.weight', 'encoder.transformer.layers.2.0.to_out.0.weight', 'encoder.transformer.layers.2.0.to_out.0.bias', 'encoder.transformer.layers.2.1.net.0.weight', 'encoder.transformer.layers.2.1.net.0.bias', 'encoder.transformer.layers.2.1.net.1.weight', 'encoder.transformer.layers.2.1.net.1.bias', 'encoder.transformer.layers.2.1.net.4.weight', 'encoder.transformer.layers.2.1.net.4.bias', 'encoder.transformer.layers.3.0.norm.weight', 'encoder.transformer.layers.3.0.norm.bias', 'encoder.transformer.layers.3.0.to_qkv.weight', 'encoder.transformer.layers.3.0.to_out.0.weight', 'encoder.transformer.layers.3.0.to_out.0.bias', 'encoder.transformer.layers.3.1.net.0.weight', 'encoder.transformer.layers.3.1.net.0.bias', 'encoder.transformer.layers.3.1.net.1.weight', 'encoder.transformer.layers.3.1.net.1.bias', 'encoder.transformer.layers.3.1.net.4.weight', 'encoder.transformer.layers.3.1.net.4.bias', 'encoder.transformer.layers.4.0.norm.weight', 'encoder.transformer.layers.4.0.norm.bias', 'encoder.transformer.layers.4.0.to_qkv.weight', 'encoder.transformer.layers.4.0.to_out.0.weight', 'encoder.transformer.layers.4.0.to_out.0.bias', 'encoder.transformer.layers.4.1.net.0.weight', 'encoder.transformer.layers.4.1.net.0.bias', 'encoder.transformer.layers.4.1.net.1.weight', 'encoder.transformer.layers.4.1.net.1.bias', 'encoder.transformer.layers.4.1.net.4.weight', 'encoder.transformer.layers.4.1.net.4.bias', 'encoder.transformer.layers.5.0.norm.weight', 'encoder.transformer.layers.5.0.norm.bias', 'encoder.transformer.layers.5.0.to_qkv.weight', 'encoder.transformer.layers.5.0.to_out.0.weight', 'encoder.transformer.layers.5.0.to_out.0.bias', 'encoder.transformer.layers.5.1.net.0.weight', 'encoder.transformer.layers.5.1.net.0.bias', 'encoder.transformer.layers.5.1.net.1.weight', 'encoder.transformer.layers.5.1.net.1.bias', 'encoder.transformer.layers.5.1.net.4.weight', 'encoder.transformer.layers.5.1.net.4.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.2.running_mean', 'decoder.2.running_var', 'decoder.2.num_batches_tracked', 'decoder.3.weight', 'decoder.3.bias', 'decoder.5.weight', 'decoder.5.bias', 'decoder.5.running_mean', 'decoder.5.running_var', 'decoder.5.num_batches_tracked', 'decoder.6.weight', 'decoder.6.bias', 'decoder.8.weight', 'decoder.8.bias', 'decoder.8.running_mean', 'decoder.8.running_var', 'decoder.8.num_batches_tracked', 'decoder.9.weight', 'decoder.9.bias', 'decoder.11.weight', 'decoder.11.bias', 'decoder.11.running_mean', 'decoder.11.running_var', 'decoder.11.num_batches_tracked', 'decoder.12.weight', 'decoder.12.bias', 'decoder.14.weight', 'decoder.14.bias', 'decoder.14.running_mean', 'decoder.14.running_var', 'decoder.14.num_batches_tracked', 'decoder.15.weight', 'decoder.15.bias', 'decoder.17.weight', 'decoder.17.bias', 'decoder.17.running_mean', 'decoder.17.running_var', 'decoder.17.num_batches_tracked', 'decoder.18.weight', 'decoder.18.bias', 'decoder.20.weight', 'decoder.20.bias', 'decoder.20.running_mean', 'decoder.20.running_var', 'decoder.20.num_batches_tracked', 'decoder.21.weight', 'decoder.21.bias'])

所以多卡的state_dict因为没有param_groups导致无法从检查点恢复训练,但是我不知道为什么会这样,如何解决?

  • 写回答

17条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-02-20 16:29
    关注

    引自免费微信小程序:皆我百晓生

    在多GPU并行训练中,优化器的state_dict确实可能会与单GPU训练时有所不同。PyTorch中的DataParallel或多进程并行(如DistributedDataParallel, DDP)会处理优化器的状态以适应分布式环境。在DDP中,优化器的state_dict包含了所有参与计算的多个GPU上的参数梯度平均值,以及可能与分布相关的额外信息。

    出现KeyError: 'param_groups'错误通常是因为加载的state_dict不包含优化器的基本结构——param_groups,这是优化器内部用来组织不同参数组和它们的学习率、权重衰减等属性的关键部分。

    解决这个问题,你需要确保:

    1. 在保存优化器状态时,正确地使用了适用于分布式训练的方式保存整个优化器的state_dict。
    2. 加载优化器状态时,考虑到了是否需要通过map_location映射到当前设备上,并且在加载后正确应用到多GPU设置下的优化器实例上。

    针对您的代码片段,可以尝试如下修改:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    # 假设self.opt是已经在分布式环境中初始化好的优化器
    if bf.exists(opt_checkpoint):
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        
        # 确保在加载时映射到正确的设备(主卡)
        map_location = lambda storage, loc: storage.cuda(self.rank)  # 假设self.rank是当前进程的全局rank
        state_dict = torch.load(opt_checkpoint, map_location=map_location)
        
        # 如果使用了DDP,需要先获取优化器原始的state_dict
        if isinstance(self.opt, DDP):
            self.opt.module.optimizer.load_state_dict(state_dict)
        else:
            self.opt.load_state_dict(state_dict)
    

    这里需要注意的是,如果你使用了DDP对模型和优化器进行封装,那么在加载优化器状态时,应加载到.module.optimizer而不是直接加载到.opt。这是因为DDP会对优化器进行一层封装,真实的优化器位于.module.optimizer下。

    另外,还需确认你的检查点文件是在与当前运行环境相匹配的并行模式下生成的,即如果当前是多GPU并行,那么加载的checkpoint也应该是来自同样配置下的训练过程。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(16条)

报告相同问题?

问题事件

  • 系统已结题 3月5日
  • 已采纳回答 2月26日
  • 修改了问题 2月20日
  • 创建了问题 2月20日

悬赏问题

  • ¥30 模拟电路 logisim
  • ¥15 PVE8.2.7无法成功使用a5000的vGPU,什么原因
  • ¥15 is not in the mmseg::model registry。报错,模型注册表找不到自定义模块。
  • ¥15 安装quartus II18.1时弹出此error,怎么解决?
  • ¥15 keil官网下载psn序列号在哪
  • ¥15 想用adb命令做一个通话软件,播放录音
  • ¥30 Pytorch深度学习服务器跑不通问题解决?
  • ¥15 部分客户订单定位有误的问题
  • ¥15 如何在maya程序中利用python编写领子和褶裥的模型的方法
  • ¥15 Bug traq 数据包 大概什么价