yangtuomao 2024-10-21 22:30 采纳率: 20%
浏览 9
问题最晚将于10月29日00:00点结题

苹果MGIE项目部署缺少emb权重

mgie项目(https://github.com/apple/ml-mgie

KeyError                                  Traceback (most recent call last)
Cell In[3], line 31
     28 image_token_len = (vision_config.image_size//vision_config.patch_size)**2
     30 _ = model.eval()
---> 31 EMB = ckpt['emb'].cuda()
     32 with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
     33 print('NULL:', NULL.shape)

KeyError:  'emb'

这个错误,
目测原因是在运行到demo.py中,

EMB = ckpt['emb'].cuda()
with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
print('NULL:', NULL.shape)
The error traceback is as follows:

这里无法找到模型中的'emb'参数。

我是在A100上完整的跑完了训练,数据集使用pix2pix给的,但是只用了一部分数据集。
环境配置和他的略有不同,但是至少可以完整跑完前面代码。

希望有成功本地部署并且运行这个项目的朋友可以解决一下,有偿。
我接受远程指导,再次强调,有偿。

  • 写回答

19条回答 默认 最新

  • 阿齐Archie 嵌入式领域优质创作者 2024-10-21 22:31
    关注

    以下参考免费微信小程序:皆我百晓生 作答,并由阿齐Archie整理回复,若还有疑问可评论或留言。

    根据你提供的错误信息,KeyError: 'emb' 表明在加载模型检查点时,检查点文件中没有 emb 这个键。这可能是由于以下几个原因之一:

    1. 检查点文件不完整:你在训练过程中可能没有正确保存所有需要的参数。
    2. 检查点文件格式不匹配:你使用的检查点文件与代码期望的格式不一致。
    3. 代码逻辑问题:代码中加载检查点的方式可能有问题。

    解决方案

    1. 检查检查点文件

    首先,确保你的检查点文件是完整的,并且包含所有需要的参数。你可以通过以下方式检查检查点文件的内容:

    import torch
    
    # 加载检查点文件
    ckpt = torch.load('path_to_your_checkpoint_file.pth')
    
    # 查看检查点文件中的键
    print(ckpt.keys())
    

    如果 emb 不在键列表中,那么你需要检查训练过程中是否正确保存了 emb 参数。

    2. 修改训练代码

    确保在训练过程中正确保存了 emb 参数。通常,保存模型和参数的代码类似于以下形式:

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'emb': model.emb  # 确保这里包含了 emb 参数
    }, 'path_to_your_checkpoint_file.pth')
    

    3. 修改加载代码

    确保在加载检查点时正确加载了 emb 参数。你可以修改加载代码如下:

    import torch as T
    
    # 加载检查点文件
    ckpt = T.load('path_to_your_checkpoint_file.pth')
    
    # 检查 emb 是否在检查点文件中
    if 'emb' in ckpt:
        EMB = ckpt['emb'].cuda()
    else:
        raise KeyError("emb not found in checkpoint file")
    
    # 继续执行后续代码
    with T.inference_mode():
        NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
    print('NULL:', NULL.shape)
    

    4. 检查模型定义

    确保模型定义中确实有 emb 参数。你可以查看模型的定义,确保 emb 参数被正确初始化并包含在模型的状态字典中。

    class YourModelClass(nn.Module):
        def __init__(self):
            super(YourModelClass, self).__init__()
            # 其他参数
            self.emb = nn.Parameter(torch.randn(...))  # 确保这里初始化了 emb 参数
    
        def forward(self, x):
            # 前向传播逻辑
            pass
    

    远程指导

    如果你需要更详细的远程指导,可以通过以下方式联系我:

    1. 电子邮件:发送邮件至 [your_email@example.com],详细描述你的问题和当前进展。
    2. 即时通讯工具:添加我的 [QQ/微信/Telegram] 账号 [your_contact_info],我们可以实时沟通。

    希望这些步骤能帮助你解决问题。如果有任何进一步的问题,请随时联系我。

    评论

报告相同问题?

问题事件

  • 创建了问题 10月21日

悬赏问题

  • ¥15 电脑软件运行次数如何统计
  • ¥15 同一个浏览器打开两个窗口怎么区分会话
  • ¥100 如何编写自己的emmc镜像
  • ¥15 starccm线性内聚力模型
  • ¥15 点云四边形凸包确定顶点
  • ¥15 关于redhat虚拟机系统新建卷的问题
  • ¥50 WRFDA读取风云四号A 星的GIIRS数据
  • ¥15 C# 爬虫融通金网址实时银价
  • ¥15 热敏电阻NTC,温控不同颜色的LED的亮与灭,PCB
  • ¥20 ESP32使用MicroPyhon开发,怎么获取485温湿度的值,温湿度计使用的鞋子是Modbus RTU