用pytorch加载别人训练的模型时候,
RuntimeError: Error(s) in loading state_dict for PSMNet:
Missing key(s) in state_dict: "feature_extraction.firstconv.0.0.weight", "feature_extraction.firstconv.0.1.weight",。。。。。
Unexpected key(s) in state_dict: "module.feature_extraction.firstconv.0.0.weight", "module.feature_extraction.firstconv.0.1.weight",
发现每个在模型里的参数都多个了module.
保存的格式是
torch.save({
'epoch': epoch,
'state_dict': model.state_dict(),
'train_loss': total_train_loss / len(TrainImgLoader),
}, savefilename)
读取的代码是
state_dict = torch.load(basic.loadmodel)
model.load_state_dict(state_dict['state_dict'])
请问能不能做到将模型的key值从
module.feature_extraction.firstconv.0.1.weight
修改到
feature_extraction.firstconv.0.1.weight