qq_41544501
lyf_
采纳率50%
2019-04-12 13:39 浏览 1.1k

pytorch加载model发现key的值有差异,能不能修改

5

用pytorch加载别人训练的模型时候,
RuntimeError: Error(s) in loading state_dict for PSMNet:
Missing key(s) in state_dict: "feature_extraction.firstconv.0.0.weight", "feature_extraction.firstconv.0.1.weight",。。。。。

Unexpected key(s) in state_dict: "module.feature_extraction.firstconv.0.0.weight", "module.feature_extraction.firstconv.0.1.weight",
图片说明

发现每个在模型里的参数都多个了module.
保存的格式是

 torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'train_loss': total_train_loss / len(TrainImgLoader),
        }, savefilename)

读取的代码是

 state_dict = torch.load(basic.loadmodel)
 model.load_state_dict(state_dict['state_dict'])

请问能不能做到将模型的key值从
module.feature_extraction.firstconv.0.1.weight

修改到
feature_extraction.firstconv.0.1.weight

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 复制链接分享
  • 邀请回答

1条回答 默认 最新

  • 已采纳
    qq_41544501 lyf_ 2019-04-12 14:56

    解决了 , torchload返回的是字典,搞一个新字典,然后把key用‘.’分段后重新组装一个字典就可以读了

    点赞 评论 复制链接分享

相关推荐