我找的一篇论文,他用的是四张A6000进行训练的,而我的是四张2080Ti,batch size只能调成他的一半,测试结果也只有他的一半。但是我用他给的pth文件直接测试,结果与他的一样。
然后我想是不是batch size的问题,于是我就用tiny作为backbone训练,bs与他调成一样,但是测试结果依旧是他的一半。是不是四张卡不同导致的问题,但是这不应该啊?能不能有个朋友帮我解决一下,万分感谢!
测试结果只有原论文的一半这个是为什么?
- 写回答
- 好问题 0 提建议
- 追加酬金
- 关注问题
- 邀请回答
-
5条回答 默认 最新
关注 引自免费微信小程序:皆我百晓生
看起来你在尝试复现一篇深度学习论文的结果,并且已经注意到硬件配置(GPU型号)的变化可能对实验结果产生影响。不过,从你描述的情况来看,还有其他几个可能的原因导致结果差异:
- 数据预处理:确保你的数据预处理步骤与论文中的完全一致。包括图像的缩放、归一化、数据增强等操作。
- 模型结构:虽然你使用了不同的backbone(Tiny),但确认是否在其他层面上保持了一致,比如网络头的设置、损失函数、优化器的选择以及其参数等。
- 训练细节:检查训练循环中的关键参数,例如学习率、衰减策略、批标准化、权重初始化方法、正则化(Dropout、L2惩罚)、早停等。
- 随机性的影响:由于深度学习中的随机性(如权重初始化、数据加载顺序等),每次训练可能会得到略微不同的结果。为了消除这个影响,可以运行多轮训练并取平均结果。
- 模型保存和加载:确认你正确地加载了论文作者提供的模型权重,并且在评估时没有重新开始训练。
- 验证集和测试集划分:确保你的验证集和测试集划分方式与论文相同,以减少偏差。
- 精度计算:核对你们的性能指标(如准确率、AP等)计算方法是否一致。
针对这些点,你可以采取以下步骤来排查问题:
-
数据预处理: 确保你的数据预处理代码与论文中的一致,可以将代码贴出来,我会帮你审查。
-
模型结构: 你可以使用
torch.nn.Module.dump_patches()
方法输出模型的结构信息,然后与论文中的对比。
import torch def print_model(model): print(model) torch.save(model.state_dict(), 'model.pth') with open('model.txt', 'w') as f: model.dump_patches(f) # 将你的模型替换为'model' print_model(model)
-
训练细节: 检查你的
train.py
或类似脚本中的所有相关参数,确保它们与论文中的值匹配。 -
随机性: 在代码中加入固定的随机种子,以确保可重复性。
import numpy as np import random import torch np.random.seed(42) random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed_all(42)
- 模型加载: 确保加载模型的代码正确无误,不包含任何额外的训练或微调。
import torch def load_pretrained_model(model, path): state_dict = torch.load(path, map_location='cpu') model.load_state_dict(state_dict) return model # 将你的模型替换为'model',并将路径替换成论文作者提供的pth文件路径 model = load_pretrained_model(model, 'path_to_paper_model.pth')
-
验证集和测试集划分: 使用相同的划分方法,或者如果论文未提供具体信息,则采用常见的划分比例(如80/10/10分训练集、验证集和测试集)。
-
精度计算: 如果可能,参考论文中提到的精确度计算公式,或者使用开源库(如
pycocotools
)提供的函数进行计算。
通过以上步骤,你应该能逐步找到问题所在。如果你在执行过程中遇到具体代码问题或有更详细的信息,欢迎继续提问。
解决 无用评论 打赏 举报 编辑记录
悬赏问题
- ¥15 react-diff-viewer组件,如何解决数据量过大卡顿问题
- ¥20 遥感植被物候指数空间分布图制作
- ¥15 安装了xlrd库但是import不了…
- ¥20 Github上传代码没有contribution和activity记录
- ¥20 SNETCracker
- ¥15 数学建模大赛交通流量控制
- ¥15 为什么我安装了open3d但是在调用的时候没有报错但是什么都没有发生呢
- ¥50 paddleocr最下面一行似乎无法识别
- ¥15 求某类社交网络数据集
- ¥15 靶向捕获探针方法/参考文献