半路转计算机的女孩 2024-04-22 15:20 采纳率: 0%
浏览 6

测试结果只有原论文的一半这个是为什么?

我找的一篇论文,他用的是四张A6000进行训练的,而我的是四张2080Ti,batch size只能调成他的一半,测试结果也只有他的一半。但是我用他给的pth文件直接测试,结果与他的一样。
然后我想是不是batch size的问题,于是我就用tiny作为backbone训练,bs与他调成一样,但是测试结果依旧是他的一半。是不是四张卡不同导致的问题,但是这不应该啊?能不能有个朋友帮我解决一下,万分感谢!

  • 写回答

5条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-04-22 15:20
    关注

    引自免费微信小程序:皆我百晓生

    看起来你在尝试复现一篇深度学习论文的结果,并且已经注意到硬件配置(GPU型号)的变化可能对实验结果产生影响。不过,从你描述的情况来看,还有其他几个可能的原因导致结果差异:

    1. 数据预处理:确保你的数据预处理步骤与论文中的完全一致。包括图像的缩放、归一化、数据增强等操作。
    2. 模型结构:虽然你使用了不同的backbone(Tiny),但确认是否在其他层面上保持了一致,比如网络头的设置、损失函数、优化器的选择以及其参数等。
    3. 训练细节:检查训练循环中的关键参数,例如学习率、衰减策略、批标准化、权重初始化方法、正则化(Dropout、L2惩罚)、早停等。
    4. 随机性的影响:由于深度学习中的随机性(如权重初始化、数据加载顺序等),每次训练可能会得到略微不同的结果。为了消除这个影响,可以运行多轮训练并取平均结果。
    5. 模型保存和加载:确认你正确地加载了论文作者提供的模型权重,并且在评估时没有重新开始训练。
    6. 验证集和测试集划分:确保你的验证集和测试集划分方式与论文相同,以减少偏差。
    7. 精度计算:核对你们的性能指标(如准确率、AP等)计算方法是否一致。

    针对这些点,你可以采取以下步骤来排查问题:

    1. 数据预处理: 确保你的数据预处理代码与论文中的一致,可以将代码贴出来,我会帮你审查。

    2. 模型结构: 你可以使用torch.nn.Module.dump_patches()方法输出模型的结构信息,然后与论文中的对比。

    import torch
    
    def print_model(model):
        print(model)
        torch.save(model.state_dict(), 'model.pth')
        with open('model.txt', 'w') as f:
            model.dump_patches(f)
    
    # 将你的模型替换为'model'
    print_model(model)
    
    1. 训练细节: 检查你的train.py或类似脚本中的所有相关参数,确保它们与论文中的值匹配。

    2. 随机性: 在代码中加入固定的随机种子,以确保可重复性。

    import numpy as np
    import random
    import torch
    
    np.random.seed(42)
    random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    
    1. 模型加载: 确保加载模型的代码正确无误,不包含任何额外的训练或微调。
    import torch
    
    def load_pretrained_model(model, path):
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict)
        return model
    
    # 将你的模型替换为'model',并将路径替换成论文作者提供的pth文件路径
    model = load_pretrained_model(model, 'path_to_paper_model.pth')
    
    1. 验证集和测试集划分: 使用相同的划分方法,或者如果论文未提供具体信息,则采用常见的划分比例(如80/10/10分训练集、验证集和测试集)。

    2. 精度计算: 如果可能,参考论文中提到的精确度计算公式,或者使用开源库(如pycocotools)提供的函数进行计算。

    通过以上步骤,你应该能逐步找到问题所在。如果你在执行过程中遇到具体代码问题或有更详细的信息,欢迎继续提问。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月22日

悬赏问题

  • ¥15 react-diff-viewer组件,如何解决数据量过大卡顿问题
  • ¥20 遥感植被物候指数空间分布图制作
  • ¥15 安装了xlrd库但是import不了…
  • ¥20 Github上传代码没有contribution和activity记录
  • ¥20 SNETCracker
  • ¥15 数学建模大赛交通流量控制
  • ¥15 为什么我安装了open3d但是在调用的时候没有报错但是什么都没有发生呢
  • ¥50 paddleocr最下面一行似乎无法识别
  • ¥15 求某类社交网络数据集
  • ¥15 靶向捕获探针方法/参考文献