song小屁虫 2021-07-30 16:55 采纳率: 50%
浏览 78
已结题

深度学习测试时,加载模型出问题。

代码:

 concat_mask = True if 'MST_shanghaitech' in args.PATH else False
    model = MST(config, concat_mask)
    model.load()
    model.inference(args.image_path, args.mask_path, config.valid_th, config.mask_th,
                    not_obj_remove=args.not_obj_remove)

报错结果为:

Traceback (most recent call last):
  File "test_single.py", line 52, in <module>
    model.load()
  File "E:\code\MST_inpainting-main\src\MST_model.py", line 102, in load
    self.inpaint_decoder.generator.load_state_dict(
  File "D:\Anaconda3\envs\torch18\lib\site-packages\torch\nn\modules\module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for InpaintGateGenerator:
        size mismatch for encoder.1.gate_conv.weight: copying a param with shape torch.Size([128, 6, 7, 7]) from checkpoint, the shape in current model is torch.Size([128, 7, 7, 7]).

这个该怎么去改它的参数呢?

  • 写回答

2条回答 默认 最新

  • 爱晚乏客游 2021-07-30 17:15
    关注

    ckp和模型的维度数目不匹配,具体的你可以看看这个看下能不能改
    https://blog.csdn.net/qq_45128278/article/details/116588153

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 9月27日
  • 已采纳回答 9月19日
  • 创建了问题 7月30日

悬赏问题

  • ¥15 #MATLAB仿真#车辆换道路径规划
  • ¥15 java 操作 elasticsearch 8.1 实现 索引的重建
  • ¥15 数据可视化Python
  • ¥15 要给毕业设计添加扫码登录的功能!!有偿
  • ¥15 kafka 分区副本增加会导致消息丢失或者不可用吗?
  • ¥15 微信公众号自制会员卡没有收款渠道啊
  • ¥15 stable diffusion
  • ¥100 Jenkins自动化部署—悬赏100元
  • ¥15 关于#python#的问题:求帮写python代码
  • ¥20 MATLAB画图图形出现上下震荡的线条